Получить последний выход dynamic_rnn в тензорном потоке?

Я использую dynamic_rnn для обработки данных MNIST:

# LSTM Cell
lstm = rnn_cell.LSTMCell(num_units=200,
                         forget_bias=1.0,
                         initializer=tf.random_normal)

# Initial state
istate = lstm.zero_state(batch_size, "float")

# Get lstm cell output
output, states = rnn.dynamic_rnn(lstm, X, initial_state=istate)

# Output at last time point T
output_at_T = output[:, 27, :]

Полный код: http://pastebin.com/bhf9MgMe

Вход в lstm равен (batch_size, sequence_length, input_size)

В результате размеры output_at_T равны (batch_size, sequence_length, num_units), где num_units=200.

Мне нужно получить последний результат по размеру sequence_length. В приведенном выше коде это жестко закодировано как 27. Тем не менее, я не знаю sequence_length заранее, так как он может меняться от партии к партии в моем приложении.

Я пробовал:

output_at_T = output[:, -1, :]

но он говорит, что отрицательная индексация еще не реализована, и я попытался использовать переменную-заполнителя, а также константу (в которую я мог бы идеально комбинировать sequence_length для определенной партии); ни работали.

Любой способ реализовать что-то подобное в тензорном потоке atm?

Ответ 1

Вы заметили, что есть два выхода из dynamic_rnn?

  • Выход 1, позвоните ему h, имеет все выходы на каждом шаге (т.е. h_1, h_2 и т.д.),
  • Выход 2, final_state, имеет два элемента: cell_state и последний вывод для каждого элемента пакета (до тех пор, пока вы вводите длину последовательности в dynamic_rnn).

Итак, из:

h, final_state= tf.dynamic_rnn( ..., sequence_length=[batch_size_vector], ... )

последнее состояние для каждого элемента в партии:

final_state.h

Обратите внимание, что это включает случай, когда длина последовательности отличается для каждого элемента пакета, поскольку мы используем аргумент sequence_length.

Ответ 2

Я новичок в Stackoverflow и не могу комментировать, поэтому я пишу этот новый ответ. @VM_AI, последний индекс tf.shape(output)[1] - 1. Итак, повторное использование вашего ответа:

# Let first fetch the last index of seq length
# last_index would have a scalar value
last_index = tf.shape(output)[1] - 1
# Then let reshape the output to [sequence_length,batch_size,num_units]
# for convenience
output_rs = tf.transpose(output,[1,0,2])
# Last state of all batches
last_state = tf.nn.embedding_lookup(output_rs,last_index)

Это работает для меня.

Ответ 3

output[:, -1, :]

работает с Tensorflow 1.x сейчас!

Ответ 4

Это то, что gather_nd для!

def extract_axis_1(data, ind):
    """
    Get specified elements along the first axis of tensor.
    :param data: Tensorflow tensor that will be subsetted.
    :param ind: Indices to take (one for each element along axis 0 of data).
    :return: Subsetted tensor.
    """

    batch_range = tf.range(tf.shape(data)[0])
    indices = tf.stack([batch_range, ind], axis=1)
    res = tf.gather_nd(data, indices)

    return res

В вашем случае (предполагая, что sequence_length является одномерным тензором с длиной каждого элемента оси 0):

output = extract_axis_1(output, sequence_length - 1)

Теперь вывод представляет собой тензор размера [batch_size, num_cells].

Ответ 5

Вы должны иметь доступ к форме тензора output, используя tf.shape(output). Функция tf.shape() вернет 1-й тензор, содержащий размеры тензора output. В вашем примере это будет (batch_size, sequence_length, num_units)

Затем вы можете извлечь значение output_at_T как output[:, tf.shape(output)[1], :]

Ответ 6

В TensorFlow tf.shape есть функция, позволяющая получить символическую интерпретацию формы, а не None, возвращаемую output._shape[1]. И после получения последнего индекса вы можете искать с помощью tf.nn.embedding_lookup, который рекомендуется особенно когда данные, которые будут выбраны, высоки, так как это делает параллельный поиск 32 by default.

# Let first fetch the last index of seq length
# last_index would have a scalar value
last_index = tf.shape(output)[1] 
# Then let reshape the output to [sequence_length,batch_size,num_units]
# for convenience
output_rs = tf.transpose(output,[1,0,2])
# Last state of all batches
last_state = tf.nn.embedding_lookup(output_rs,last_index)

Это должно работать.

Просто чтобы уточнить, что сказал @Benoit Steiner. Его решение не будет работать, так как tf.shape вернет символическую интерпретацию значения формы, и это не может быть использовано для нарезки тензоров, т.е. Прямой индексации