Tensorflow.train.import_meta_graph не работает?

Я пытаюсь просто сохранить и восстановить график, но самый простой пример не работает должным образом (это делается с использованием версии 0.9.0 или 0.10.0 на Linux 64 без CUDA с использованием python 2.7 или 3.5.2)

Сначала я сохраняю график следующим образом:

import tensorflow as tf
v1 = tf.placeholder('float32') 
v2 = tf.placeholder('float32')
v3 = tf.mul(v1,v2)
c1 = tf.constant(22.0)
v4 = tf.add(v3,c1)
sess = tf.Session()
result = sess.run(v4,feed_dict={v1:12.0, v2:3.3})
g1 = tf.train.export_meta_graph("file")
## alternately I also tried:
## g1 = tf.train.export_meta_graph("file",collection_list=["v4"])

Это создает файл "файл", который не является пустым, а также устанавливает g1 в то, что выглядит как правильное определение графа.

Затем я пытаюсь восстановить этот график:

import tensorflow as tf
g=tf.train.import_meta_graph("file")

Это работает без ошибок, но ничего не возвращает.

Может ли кто-нибудь предоставить необходимый код, просто просто сохранить график для "v4" и полностью восстановить его, чтобы запуск этого в новом сеансе принесет тот же результат?

Ответ 1

Чтобы повторно использовать MetaGraphDef, вам нужно будет записать имена интересных тензоров в исходном графике. Например, в первой программе установите явный аргумент name в определении v1, v2 и v4:

v1 = tf.placeholder(tf.float32, name="v1")
v2 = tf.placeholder(tf.float32, name="v2")
# ...
v4 = tf.add(v3, c1, name="v4")

Затем вы можете использовать имена строк тензоров в исходном графе при вызове sess.run(). Например, следующий фрагмент должен работать:

import tensorflow as tf
_ = tf.train.import_meta_graph("./file")

sess = tf.Session()
result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})

В качестве альтернативы вы можете использовать tf.get_default_graph().get_tensor_by_name(), чтобы получить объекты tf.Tensor для интересующих тензоров, которые затем можно передать на sess.run():

import tensorflow as tf
_ = tf.train.import_meta_graph("./file")
g = tf.get_default_graph()

v1 = g.get_tensor_by_name("v1:0")
v2 = g.get_tensor_by_name("v2:0")
v4 = g.get_tensor_by_name("v4:0")

sess = tf.Session()
result = sess.run(v4, feed_dict={v1: 12.0, v2: 3.3})

ОБНОВЛЕНИЕ. На основе обсуждения в комментариях здесь приведен полный пример сохранения и загрузки, включая сохранение содержимого переменной. Это иллюстрирует сохранение переменной путем удвоения значения переменной vx в отдельной операции.

Сохранение:

import tensorflow as tf
v1 = tf.placeholder(tf.float32, name="v1") 
v2 = tf.placeholder(tf.float32, name="v2")
v3 = tf.mul(v1, v2)
vx = tf.Variable(10.0, name="vx")
v4 = tf.add(v3, vx, name="v4")
saver = tf.train.Saver([vx])
sess = tf.Session()
sess.run(tf.initialize_all_variables())
sess.run(vx.assign(tf.add(vx, vx)))
result = sess.run(v4, feed_dict={v1:12.0, v2:3.3})
print(result)
saver.save(sess, "./model_ex1")

Восстановление:

import tensorflow as tf
saver = tf.train.import_meta_graph("./model_ex1.meta")
sess = tf.Session()
saver.restore(sess, "./model_ex1")
result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
print(result)

Суть в том, что для использования сохраненной модели вы должны помнить имена, по крайней мере, некоторых из узлов (например, учебный op, входной заполнитель, тензор оценки и т.д.). MetaGraphDef хранит список переменных, которые содержатся в модели, и помогает восстановить их с контрольной точки, но вам необходимо восстановить тензоры/операции, используемые для обучения/оценки модели самостоятельно.