deprecatedになったSupervisorの後継。
MonitoredTrainingSessionで学習用のMonitoredSessionを生成する。 このコンストラクタの引数でcheckpoint_dirを渡すと内部でCheckpointSaverHookが 追加されるようになっていて、restoreしたり指定したタイミングでsaveしたりしてくれる。
なので今回明示的に渡すhooksは 指定したstepに到達したら止めてくれる、StopAtStepHookのみ。
should_stop()がTrueな状態でsession.run()
しようとするとRun called even after should_stop requested.
のエラーが出るため、
今回は新しいsessionを作ってAccuracyを返しているが、hookでやった方がrestoreする必要がないので良さそうだ。
Destributed TensorFlowの流れとSavedModelの出力 - sambaiz-net
全体のコードはここ。
def train(self, learning_rate, variable_default_stddev, bias_default, last_step=800):
test_images = self.images[:500]
test_labels = self.labels[:500]
train_batch = Batch(self.images[500:], self.labels[500:])
with tf.Graph().as_default():
global_step=tf.train.get_or_create_global_step()
g = MNIST_CNN(learning_rate, variable_default_stddev, bias_default).graph()
saver = tf.train.Saver()
savedir = './ckpt-{}-{}-{}'.format(learning_rate, variable_default_stddev, bias_default)
hooks = [
tf.train.StopAtStepHook(last_step=last_step)
]
with tf.train.MonitoredTrainingSession(
hooks=hooks,
checkpoint_dir=savedir,
save_checkpoint_secs = 300,
) as sess:
sess.run(global_step)
while not sess.should_stop():
# step = sess.run(global_step)
images, labels = train_batch.get_next(500)
sess.run(g["op"]["train_step"], feed_dict={
g["placeholder"]["x"]: list(images),
g["placeholder"]["y"]: list(labels),
})
with tf.Session() as sess:
self._restore(sess, saver, savedir)
return sess.run(g["op"]["accuracy"], feed_dict={
g["placeholder"]["x"]: list(test_images),
g["placeholder"]["y"]: list(test_labels)
}), savedir
hooksに渡すSessionRunHookは以下のメソッドからなる。
before_run()
で返すSessionRunArgsのfeed_dictはrunで渡すfeed_dictとmergeされ、
fetchesはrunするたびに毎回評価される。
class MySessionRunHook:
def __init__(self, feed_dict):
self.feed_dict = feed_dict
self.a = tf.placeholder(tf.float32, name="a")
def begin(self):
"""Called once before using the session."""
print("begin")
def after_create_session(self, session, coord):
"""Called when new TensorFlow session is created."""
print("after_create_session")
def before_run(self, run_context):
"""Called before each call to run()."""
return SessionRunArgs(fetches={"a": self.a}, feed_dict=self.feed_dict)
def after_run(self, run_context, run_values):
"""Called after each call to run()."""
print("after_run {} {}".format(run_values.results["a"], run_context.session.run(self.a, feed_dict={self.a: 10})))
def end(self, session):
"""Called at the end of session."""
print("end")
したがって、hooks内のfetchesにfeedする必要があるものが含まれる場合、
runする対象がfeedする必要がない場合でもYou must feed a value for placeholder tensor
のエラーが出ることになる。
また、hooks内とrunの引数のfeed_dictが衝突した場合もエラーになる。
with tf.Graph().as_default():
global_step = tf.train.get_or_create_global_step()
x = tf.placeholder(tf.float32, name="x")
y = tf.placeholder(tf.float32, name="y")
z = x + y
hook1 = MySessionRunHook({x: 2})
hook2 = MySessionRunHook({y: 3})
hooks=[hook1, hook2]
with tf.train.MonitoredTrainingSession(hooks=hooks) as sess:
try:
print(sess.run(global_step))
except tf.errors.InvalidArgumentError as err:
print(err) # You must feed a value for placeholder tensor 'a_1' with dtype float ...
print(sess.run([global_step, z], feed_dict={hook1.a: 2, hook2.a: 3})) # [0, 5.0]
try:
print(sess.run(y, feed_dict={x: 10}))
except RuntimeError as err:
print(err) # Same tensor is fed by a SessionRunHook and user. Conflict(s): [<tf.Tensor 'x:0' shape=<unknown> dtype=float32>]
上で書いたようにMonitoredTrainingSessionは内部でCheckpointSaverHookを持っていて、 これがfetchesでsummaryを返しているので、一つでもsummaryを入れるとrunするときにそれに必要なfeed_dictを渡さないとエラーになる。
with tf.Graph().as_default():
global_step = tf.train.get_or_create_global_step()
x = tf.placeholder(tf.float32, name="x")
y = tf.placeholder(tf.float32, name="y")
z = x + y
tf.summary.scalar("z", z)
with tf.train.MonitoredTrainingSession(checkpoint_dir="./aaa", save_summaries_steps=1) as sess:
# ng
# sess.run(global_step) => You must feed a value for placeholder tensor 'y' with dtype float
# sess.run(z, feed_dict={x: 10, y:20})
step, zz = sess.run([global_step, z], feed_dict={x: 10, y:20}) # ok