Где next_batch в учебнике TensorFlow batch_xs, batch_ys = mnist.train.next_batch (100)?

Я тестирую учебник TensorFlow и не понимаю, откуда идет next_batch в этой строке?

 batch_xs, batch_ys = mnist.train.next_batch(100)

Я посмотрел

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

И не видел next_batch там.

Теперь, когда вы тестируете next_batch в моем собственном коде, я получаю

AttributeError: 'numpy.ndarray' object has no attribute 'next_batch'

Итак, я хотел бы понять, откуда происходит next_batch?

Ответ 1

next_batch - это метод класса DataSet (см. https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/datasets/mnist.py для получения дополнительной информации о том, что в классе).

Когда вы загружаете данные mnist и назначаете его переменной mnist с помощью:

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

Посмотрите на класс mnist.train. Вы можете увидеть его, набрав:

print mnist.train.__class__

Вы увидите следующее:

<class 'tensorflow.contrib.learn.python.learn.datasets.mnist.Dataset'>

Поскольку mnist.train является экземпляром класса DataSet, вы можете использовать функцию класса next_batch. Для получения дополнительной информации о классах ознакомьтесь с документацией.

Ответ 2

Просматривая репозиторий tensorflow, он, кажется, возникает здесь:

https://github.com/tensorflow/tensorflow/blob/9230423668770036179a72414482d45ddde40a3b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py#L905

Однако, если вы хотите реализовать его в своем собственном коде (для своего собственного набора данных), было бы гораздо проще записать его непосредственно в объекте набора данных, как и я. Насколько я понимаю, это способ перетасовать весь набор данных и вернуть $mini_batch_size количество выборок из перетасованного набора данных.

Здесь некоторый псевдокод:

shuffle data.x and data.y while retaining relation return [data.x[:mb_n], data.y[:mb_n]]

Ответ 3

Вы можете просто использовать справочную функцию:

help(tf.contrib.learn.datasets.mnist.DataSet.next_batch)

и получить документ функции next_batch