После прочтения API DOC я также не могу понять использование SessionRunHook. Например, какова последовательность функций участника SessionRunHook? Это after_create_session → before_run → begin → after_run → end? И я не могу найти учебник с подробными примерами, есть ли более подробное объяснение?
Какова последовательность вызываемой функции-члена SessionRunHook?
Ответ 1
Вы можете найти учебник здесь, немного длинный, но вы можете перейти к части построения сети. Или вы можете прочитать мое краткое изложение ниже, основываясь на моем опыте.
Во-первых, вместо обычного Session следует использовать MonitoredSession.
SessionRunHook расширяет
session.run()вызовы дляMonitoredSession.
Тогда некоторые общие классы SessionRunHook можно найти здесь. Простым является LoggingTensorHook, но вы можете добавить следующую строку после импорта для просмотра журналов во время работы:
tf.logging.set_verbosity(tf.logging.INFO)
Или у вас есть возможность реализовать свой собственный класс SessionRunHook. Простой из учебника cifar10
class _LoggerHook(tf.train.SessionRunHook):
"""Logs loss and runtime."""
def begin(self):
self._step = -1
self._start_time = time.time()
def before_run(self, run_context):
self._step += 1
return tf.train.SessionRunArgs(loss) # Asks for loss value.
def after_run(self, run_context, run_values):
if self._step % FLAGS.log_frequency == 0:
current_time = time.time()
duration = current_time - self._start_time
self._start_time = current_time
loss_value = run_values.results
examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
sec_per_batch = float(duration / FLAGS.log_frequency)
format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
'sec/batch)')
print (format_str % (datetime.now(), self._step, loss_value,
examples_per_sec, sec_per_batch))
где loss определяется вне класса. Этот _LoggerHook использует print для печати информации, в то время как LoggingTensorHook использует tf.logging.INFO.
Наконец, для лучшего понимания, как это работает, порядок выполнения представлен псевдокодом с MonitoredSession здесь:
call hooks.begin()
sess = tf.Session()
call hooks.after_create_session()
while not stop is requested: # py code: while not mon_sess.should_stop():
call hooks.before_run()
try:
results = sess.run(merged_fetches, feed_dict=merged_feeds)
except (errors.OutOfRangeError, StopIteration):
break
call hooks.after_run()
call hooks.end()
sess.close()
Надеюсь, это поможет.
Ответ 2
tf.SessionRunHook позволяет вам добавлять свой код в течение каждой команды запуска сеанса, выполняемой в коде. Чтобы понять это, я создал простой пример ниже:
- Мы хотим напечатать значения потерь после каждого обновления параметров.
- Для этого мы будем использовать
SessionRunHook.
Создать график тензорного потока
import tensorflow as tf
import numpy as np
x = tf.placeholder(shape=(10, 2), dtype=tf.float32)
w = tf.Variable(initial_value=[[10.], [10.]])
w0 = [[1], [1.]]
y = tf.matmul(x, w0)
loss = tf.reduce_mean((tf.matmul(x, w) - y) ** 2)
optimizer = tf.train.AdamOptimizer(0.001).minimize(loss)
Создание крюка
class _Hook(tf.train.SessionRunHook):
def __init__(self, loss):
self.loss = loss
def begin(self):
pass
def before_run(self, run_context):
return tf.train.SessionRunArgs(self.loss)
def after_run(self, run_context, run_values):
loss_value = run_values.results
print("loss value:", loss_value)
Создание контролируемого сеанса с помощью hook
sess = tf.train.MonitoredSession(hooks=[_Hook(loss)])
поезд
for _ in range(10):
x_ = np.random.random((10, 2))
sess.run(optimizer, {x: x_})
# Output
loss value: 21.244701
loss value: 19.39169
loss value: 16.02665
loss value: 16.717144
loss value: 15.389178
loss value: 16.23935
loss value: 14.299083
loss value: 9.624525
loss value: 5.654896
loss value: 10.689494