Как установить состояние RNN TensorFlow, когда state_is_tuple = True?

Я написал модель языка RNN, используя TensorFlow. Модель реализована как класс RNN. Структура графа встроена в конструктор, а методы RNN.train и RNN.test запускают его.

Я хочу иметь возможность reset состояния RNN при переходе к новому документу в наборе обучения или когда я хочу запустить проверку, установленную во время обучения. Я делаю это, управляя состоянием внутри цикла обучения, передавая его в график через словарь фида.

В конструкторе я определяю RNN следующим образом

    cell = tf.nn.rnn_cell.LSTMCell(hidden_units)
    rnn_layers = tf.nn.rnn_cell.MultiRNNCell([cell] * layers)
    self.reset_state = rnn_layers.zero_state(batch_size, dtype=tf.float32)
    self.state = tf.placeholder(tf.float32, self.reset_state.get_shape(), "state")
    self.outputs, self.next_state = tf.nn.dynamic_rnn(rnn_layers, self.embedded_input, time_major=True,
                                                  initial_state=self.state)

Цикл обучения выглядит следующим образом:

 for document in document:
     state = session.run(self.reset_state)
     for x, y in document:
          _, state = session.run([self.train_step, self.next_state], 
                                 feed_dict={self.x:x, self.y:y, self.state:state})

x и y - это пакеты данных обучения в документе. Идея состоит в том, что я передаю последнее состояние после каждой партии, за исключением случаев, когда я запускаю новый документ, когда я обнуляю состояние, запустив self.reset_state.

Все это работает. Теперь я хочу изменить свой RNN, чтобы использовать рекомендуемый state_is_tuple=True. Однако я не знаю, как передать более сложный объект состояния LSTM через словарь фида. Также я не знаю, какие аргументы передаются в строку self.state = tf.placeholder(...) в моем конструкторе.

Какая здесь правильная стратегия? Для dynamic_rnn доступно еще немного кода или документации для примера.


Проблемы с TensorFlow 2695 и 2838 отображаются соответствующие.

A сообщение в блоге на WILDML решает эти проблемы, но прямо не объясняет ответ.

См. также TensorFlow: запомните состояние LSTM для следующей партии (с сохранением состояния LSTM).

Ответ 1

Одна проблема с заполнитель Tensorflow заключается в том, что вы можете подавать его только с помощью списка Python или массива Numpy (я думаю). Таким образом, вы не можете сохранить состояние между запусками в кортежах LSTMStateTuple.

Я решил это, сохранив состояние в тензоре, подобном этому

initial_state = np.zeros((num_layers, 2, batch_size, state_size))

У вас есть два компонента в слое LSTM, состояние ячейки и скрытое состояние, вот что происходит от "2". (эта статья замечательная: https://arxiv.org/pdf/1506.00019.pdf)

При создании графика вы распаковываете и создаете состояние кортежа следующим образом:

state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
l = tf.unpack(state_placeholder, axis=0)
rnn_tuple_state = tuple(
         [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0],l[idx][1])
          for idx in range(num_layers)]
)

Затем вы получаете новое состояние обычным способом

cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)

outputs, state = tf.nn.dynamic_rnn(cell, series_batch_input, initial_state=rnn_tuple_state)

Это не должно быть так... возможно, они работают над решением.

Ответ 2

Простой способ подачи в состоянии RNN состоит в том, чтобы просто загружать оба компонента кортежа состояний индивидуально.

# Constructing the graph
self.state = rnn_cell.zero_state(...)
self.output, self.next_state = tf.nn.dynamic_rnn(
    rnn_cell,
    self.input,
    initial_state=self.state)

# Running with initial state
output, state = sess.run([self.output, self.next_state], feed_dict={
    self.input: input
})

# Running with subsequent state:
output, state = sess.run([self.output, self.next_state], feed_dict={
    self.input: input,
    self.state[0]: state[0],
    self.state[1]: state[1]
})