TensorFlow: получение всех состояний из RNN

Как вы получаете все скрытые состояния от tf.nn.rnn() или tf.nn.dynamic_rnn() в TensorFlow? API дает мне только конечное состояние.

Первым вариантом было бы написать цикл при построении модели, которая работает непосредственно на RNNCell. Однако количество временных меток не фиксировано для меня и зависит от входящей партии.

Некоторые параметры - использовать GRU или написать собственный RNNCell, который объединяет состояние с выходом. Первый выбор не является достаточно общим, и последний кажется слишком хриплым.

Другой вариант - сделать что-то вроде ответов в этом вопросе, получив все переменные из RNN. Однако я не уверен, как здесь отделять скрытые состояния от других переменных стандартным образом.

Есть ли хороший способ получить все скрытые состояния из RNN, все еще используя API RNN, предоставляемые библиотекой?

Ответ 1

tf.nn.dynamic_rnn (также tf.nn.static_rnn) имеет два возвращаемых значения; "выходы" , "состояние" (https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn)

Как вы сказали, "состояние" является конечным состоянием RNN, но "выходы" - это все скрытые состояния RNN (какая форма [batch_size, max_time, cell.output_size])

Вы можете использовать "выходы" в качестве скрытых состояний RNN, потому что в большинстве предоставляемых библиотекой RNNCell "вывод" и "состояние" одинаковы. (кроме LSTMCell)

Ответ 2

Я уже создал PR здесь, и это может помочь вам справиться с простыми случаями

Позвольте мне кратко объяснить мою реализацию, поэтому вы можете написать свою версию, если вам нужно. Основная часть - это модификация функции _time_step:

def _time_step(time, output_ta_t, state, *args):

Параметры остаются неизменными, за исключением того, что передается дополнительная *args. Но почему args? Это потому, что я хочу поддерживать привычное поведение тензорного потока. Вы можете вернуть конечное состояние, просто проигнорировав параметр args:

if states_ta is not None:
    # If you want to return all states, set `args` to be `states_ta`
    loop_vars = (time, output_ta, state, states_ta)
else:
    # If you want the final state only, ignore `args`
    loop_vars = (time, output_ta, state)

Как его использовать?

if args:
    args = tuple(
        ta.write(time, out) for ta, out in zip(args[0], [new_state])
    )

На самом деле это всего лишь модификация следующих (оригинальных) кодов:

output_ta_t = tuple(
    ta.write(time, out) for ta, out in zip(output_ta_t, output)
)

Теперь args должен содержать все состояния, которые вы хотите.

После всех выполненных выше работ вы можете выбрать состояния (или конечное состояние) со следующими кодами:

_, output_final_ta, *state_info = control_flow_ops.while_loop( ...

и

if states_ta is not None:
    final_state, states_final_ta = state_info
else:
    final_state, states_final_ta = state_info[0], None

Хотя я не тестировал его в сложных случаях, он должен работать в "простом" состоянии (вот мои тестовые примеры)