Я написал модель языка RNN, используя TensorFlow. Модель реализована как класс RNN
. Структура графа встроена в конструктор, а методы RNN.train
и RNN.test
запускают его.
Я хочу иметь возможность reset состояния RNN при переходе к новому документу в наборе обучения или когда я хочу запустить проверку, установленную во время обучения. Я делаю это, управляя состоянием внутри цикла обучения, передавая его в график через словарь фида.
В конструкторе я определяю RNN следующим образом
cell = tf.nn.rnn_cell.LSTMCell(hidden_units)
rnn_layers = tf.nn.rnn_cell.MultiRNNCell([cell] * layers)
self.reset_state = rnn_layers.zero_state(batch_size, dtype=tf.float32)
self.state = tf.placeholder(tf.float32, self.reset_state.get_shape(), "state")
self.outputs, self.next_state = tf.nn.dynamic_rnn(rnn_layers, self.embedded_input, time_major=True,
initial_state=self.state)
Цикл обучения выглядит следующим образом:
for document in document:
state = session.run(self.reset_state)
for x, y in document:
_, state = session.run([self.train_step, self.next_state],
feed_dict={self.x:x, self.y:y, self.state:state})
x
и y
- это пакеты данных обучения в документе. Идея состоит в том, что я передаю последнее состояние после каждой партии, за исключением случаев, когда я запускаю новый документ, когда я обнуляю состояние, запустив self.reset_state
.
Все это работает. Теперь я хочу изменить свой RNN, чтобы использовать рекомендуемый state_is_tuple=True
. Однако я не знаю, как передать более сложный объект состояния LSTM через словарь фида. Также я не знаю, какие аргументы передаются в строку self.state = tf.placeholder(...)
в моем конструкторе.
Какая здесь правильная стратегия? Для dynamic_rnn
доступно еще немного кода или документации для примера.
Проблемы с TensorFlow 2695 и 2838 отображаются соответствующие.
A сообщение в блоге на WILDML решает эти проблемы, но прямо не объясняет ответ.
См. также TensorFlow: запомните состояние LSTM для следующей партии (с сохранением состояния LSTM).