Возможно ли, чтобы обучаемая переменная не обучалась?

Я создал переменную обучаемый в области. Позже я вошел в ту же область действия, задал область reuse_variables и использовал get_variable для извлечения той же переменной. Однако я не могу установить переменное обучаемое свойство False. Моя строка get_variable похожа:

weight_var = tf.get_variable('weights', trainable = False)

Но переменная 'weights' все еще находится на выходе tf.trainable_variables.

Можно ли установить общий флаг trainable на False с помощью get_variable?

Причина, по которой я хочу сделать это, заключается в том, что я пытаюсь повторно использовать низкоуровневые фильтры, предварительно обученные из сети VGG в моей модели, и я хочу построить график, как раньше, получить переменную весов и назначить VGG для весовой переменной, а затем фиксировать их на следующем этапе обучения.

Ответ 1

После просмотра документации и кода я не смог найти способ удалить переменную из TRAINABLE_VARIABLES.

Вот что происходит:

  • При первом tf.get_variable('weights', trainable=True) переменная добавляется в список TRAINABLE_VARIABLES.
  • Во второй раз, когда вы вызываете tf.get_variable('weights', trainable=False), вы получаете ту же переменную, но аргумент trainable=False не действует, так как переменная уже присутствует в списке TRAINABLE_VARIABLESнет способа убери оттуда)

Первое решение

При вызове метода minimize оптимизатора (см . var_list=[...]) вы можете передать в качестве аргумента переменную var_list=[...] с переменными, которые вы хотите оптимизировать.

Например, если вы хотите заморозить все слои VGG, кроме двух последних, вы можете передать веса последних двух слоев в var_list.

Второе решение

Вы можете использовать tf.train.Saver() чтобы сохранить переменные и восстановить их позже (см. Это руководство).

  • Сначала вы тренируете всю свою модель VGG со всеми обучаемыми переменными. Вы сохраняете их в файле контрольных точек, вызывая saver.save(sess, "/path/to/dir/model.ckpt").
  • Затем (в другом файле) вы тренируете вторую версию с необучаемыми переменными. Вы загружаете переменные, ранее сохраненные с помощью saver.restore(sess, "/path/to/dir/model.ckpt").

При желании вы можете сохранить только некоторые переменные в вашем файле контрольных точек. Смотрите документ для получения дополнительной информации.

Ответ 2

Если вы хотите обучить или оптимизировать только определенные уровни предварительно обученной сети, это то, что вам нужно знать.

Метод minimize var_list принимает необязательный аргумент var_list, список переменных, которые необходимо скорректировать с помощью обратного распространения.

Если вы не укажете var_list, оптимизатор может отрегулировать любую переменную TF на графике. Когда вы указываете некоторые переменные в var_list, TF сохраняет все остальные переменные постоянными.

Вот пример сценария, который использовал Йонбрунер и его соавтор.

tvars = tf.trainable_variables()
g_vars = [var for var in tvars if 'g_' in var.name]
g_trainer = tf.train.AdamOptimizer(0.0001).minimize(g_loss, var_list=g_vars)

Он находит все переменные, которые они определили ранее и имеют "g_" в имени переменной, помещает их в список и запускает на них оптимизатор ADAM.

Вы можете найти соответствующие ответы здесь на Quora

Ответ 3

Чтобы удалить переменную из списка обучаемых переменных, вы можете сначала получить доступ к коллекции с помощью: trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES) Там, trainable_collection содержит ссылку на коллекцию обучаемых переменных. Если вы извлекаете элементы из этого списка, например, trainable_collection.pop(0), вы удалите соответствующую переменную из обучаемых переменных, и, таким образом, эта переменная не будет обучена.

Хотя это работает с pop, я все еще пытаюсь найти способ правильно использовать remove с правильным аргументом, поэтому мы не зависим от индекса переменных.

РЕДАКТИРОВАТЬ: Учитывая, что у вас есть имя переменных в графе (вы можете получить это, изучив протобуф графика или, что проще, используя Tensorboard), вы можете использовать его для циклического просмотра списка обучаемых переменных, а затем удалить переменные из обучаемой коллекции. Пример: скажем, что я хочу, чтобы переменные с именами "batch_normalization/gamma:0" и "batch_normalization/beta:0" НЕ обучались, но они уже добавлены в коллекцию TRAINABLE_VARIABLES. Что я могу сделать, это:

#gets a reference to the list containing the trainable variables
trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
variables_to_remove = list()
for vari in trainable_collection:
    #uses the attribute 'name' of the variable
    if vari.name=="batch_normalization/gamma:0" or vari.name=="batch_normalization/beta:0":
        variables_to_remove.append(vari)
for rem in variables_to_remove:
    trainable_collection.remove(rem)

"Это удалит две переменные из коллекции, и они больше не будут обучаться.

Ответ 4

Вы можете использовать tf.get_collection_ref, чтобы получить ссылку на коллекцию, а не tf.get_collection