Как я могу ускорить это вычисление Keras Attention?

Я написал собственный слой keras для AttentiveLSTMCell и AttentiveLSTM(RNN) в соответствии с новым подходом keras к AttentiveLSTM(RNN). Этот механизм внимания описан Bahdanau, где в модели кодировщика/декодера создается "контекстный" вектор из всех выходов кодировщика и скрытого состояния декодера. Затем я добавляю вектор контекста на каждый временной интервал к входу.

Модель используется для создания агента Dialog, но очень похожа на модели NMT в архитектуре (аналогичные задачи).

Однако, добавив этот механизм внимания, я несколько раз сократил обучение своей сети, и мне очень хотелось бы знать, как я мог бы написать часть кода, которая замедляет ее настолько эффективно.

Основная задача вычисления выполняется здесь:

h_tm1 = states[0]  # previous memory state
c_tm1 = states[1]  # previous carry state

# attention mechanism

# repeat the hidden state to the length of the sequence
_stm = K.repeat(h_tm1, self.annotation_timesteps)

# multiplty the weight matrix with the repeated (current) hidden state
_Wxstm = K.dot(_stm, self.kernel_w)

# calculate the attention probabilities
# self._uh is of shape (batch, timestep, self.units)
et = K.dot(activations.tanh(_Wxstm + self._uh), K.expand_dims(self.kernel_v))

at = K.exp(et)
at_sum = K.sum(at, axis=1)
at_sum_repeated = K.repeat(at_sum, self.annotation_timesteps)
at /= at_sum_repeated  # vector of size (batchsize, timesteps, 1)

# calculate the context vector
context = K.squeeze(K.batch_dot(at, self.annotations, axes=1), axis=1)

# append the context vector to the inputs
inputs = K.concatenate([inputs, context])

в call методе в AttentiveLSTMCell (один временный шаг).

Полный код можно найти здесь. Если необходимо предоставить некоторые данные и способы взаимодействия с моделью, то я могу это сделать.

Есть идеи? Я, конечно же, тренируюсь на GPU, если здесь есть что-то умное.

Ответ 1

Я бы рекомендовал тренировать вашу модель, используя relu, а не tanh, так как эта операция значительно быстрее вычисляется. Это позволит сэкономить время вычислений по порядку ваших учебных примеров. * Средняя длина последовательности на пример * количество эпох.

Кроме того, я бы оценил улучшение производительности при добавлении вектора контекста, имея в виду, что это замедлит ваш цикл итерации по другим параметрам. Если это не даст вам большого улучшения, возможно, стоит попробовать другие подходы.

Ответ 2

Вы изменили класс LSTM, который хорош для вычислений ЦП, но вы упомянули, что вы тренируетесь на GPU.

Я рекомендую заглянуть в рекуррентную реализацию cudnn или далее в используемую часть tf. Может быть, вы можете расширить код там.