TensorFlow: восстановление переменных из нескольких контрольных точек

У меня следующая ситуация:

  • У меня есть две модели, написанные в двух отдельных сценариях:

  • Модель A состоит из переменных a1, a2 и a3 и написана в A.py

  • Модель B состоит из переменных b1, b2 и b3 и написана на B.py

В каждом из A.py и B.py у меня есть tf.train.Saver, который сохраняет контрольную точку всех локальных переменных и позволяет вызвать файлы контрольной точки ckptA и ckptB соответственно.

Теперь я хочу создать модель C, которая использует a1 и b1. Я могу сделать так, чтобы одно и то же имя переменной для a1 использовалось как в A, так и в C с помощью var_scope (и то же самое для b1).

Вопрос в том, как я могу загрузить a1 и b1 из ckptA и ckptB в модель C? Например, будет ли работать следующее?

saver.restore(session, ckptA_location)
saver.restore(session, ckptB_location)

Будет ли возникать ошибка, если вы попытаетесь восстановить один и тот же сеанс дважды? Будет ли он жаловаться на отсутствие выделенных "слотов" для дополнительных переменных (b2, b3, a2, a3) или просто восстановить переменные, которые он может, и только жаловаться, если есть некоторые другие переменные в C, которые неинициализированы?

Я пытаюсь написать код для проверки этого сейчас, но мне бы хотелось увидеть канонический подход к этой проблеме, потому что это часто встречается при попытке повторно использовать некоторые предварительно подготовленные веса.

Спасибо!

Ответ 1

Вы получили бы tf.errors.NotFoundError, если бы попытались использовать заставку (по умолчанию представляющую все шесть переменных) для восстановления с контрольной точки, которая не содержит всех переменных, которые представляет заставка. (Обратите внимание, однако, что вы можете называть Saver.restore() несколько раз в одном сеансе для любого подмножества переменных, если все запрошенные переменные присутствуют в соответствующем файле.)

Канонический подход заключается в определении двух отдельных tf.train.Saver экземпляров, охватывающих все подмножества переменных, которые полностью содержатся в одной контрольной точке. Например:

saver_a = tf.train.Saver([a1])
saver_b = tf.train.Saver([b1])

saver_a.restore(session, ckptA_location)
saver_b.restore(session, ckptB_location)

В зависимости от того, как создается ваш код, если у вас есть указатели на tf.Variable объекты с именем a1 и b1 в локальной области, вы можете здесь остановиться.

С другой стороны, если переменные a1 и b1 определены в отдельных файлах, вам может понадобиться сделать что-то творческое для получения указателей на эти переменные. Хотя это и не идеально, обычно люди используют общий префикс, например, следующим образом (при условии, что имена переменных "a1:0" и "b1:0" соответственно):

saver_a = tf.train.Saver([v for v in tf.all_variables() if v.name == "a1:0"])
saver_b = tf.train.Saver([v for v in tf.all_variables() if v.name == "b1:0"])

Последнее замечание: вам не нужно предпринимать героические усилия, чтобы переменные имели одинаковые имена в и C. Вы можете передать словарь name-to- Variable в качестве первого аргумента в tf.train.Saver и, таким образом, переназначить имена в файле контрольной точки на Variable объекты в вашем коде. Это помогает, если A.py и B.py имеют аналогично названные переменные, или если в C.py вы хотите упорядочить код модели из этих файлов в tf.name_scope().