TensorFlowのMonitoredSessionとSessionRunHookとsummaryのエラー

(2018-07-01)

MonitoredSession

deprecatedになったSupervisorの後継。

MonitoredTrainingSessionで学習用のMonitoredSessionを生成する。 このコンストラクタの引数でcheckpoint_dirを渡すと内部でCheckpointSaverHook追加されるようになっていて、restoreしたり指定したタイミングでsaveしたりしてくれる。

なので今回明示的に渡すhooksは 指定したstepに到達したら止めてくれるStopAtStepHookのみ。

should_stop()がTrueな状態でsession.run()しようとするとRun called even after should_stop requested.のエラーが出るため、学習後Accuracyを返すには新しいsessionにする必要がある。

全体のコードはここ

def train(self, learning_rate, variable_default_stddev, bias_default, last_step=800):
  test_images = self.images[:500]
  test_labels = self.labels[:500]
  train_batch = Batch(self.images[500:], self.labels[500:])

  with tf.Graph().as_default():
    global_step=tf.train.get_or_create_global_step()
    g = MNIST_CNN(learning_rate,  variable_default_stddev, bias_default).graph()
    saver = tf.train.Saver()
    savedir = './ckpt-{}-{}-{}'.format(learning_rate, variable_default_stddev, bias_default)
    hooks = [
      tf.train.StopAtStepHook(last_step=last_step)
    ]
    with tf.train.MonitoredTrainingSession(
      hooks=hooks,
      checkpoint_dir=savedir,
      save_checkpoint_secs = 300,
    ) as sess:
      sess.run(global_step)
      while not sess.should_stop():
        # step = sess.run(global_step)
        images, labels = train_batch.get_next(500)
        sess.run(g["op"]["train_step"], feed_dict={
          g["placeholder"]["x"]: list(images), 
          g["placeholder"]["y"]: list(labels),
        })
    with tf.Session() as sess:
      self._restore(sess, saver, savedir)
      return sess.run(g["op"]["accuracy"], feed_dict={
        g["placeholder"]["x"]: list(test_images), 
        g["placeholder"]["y"]: list(test_labels)
      }), savedir

hooksに渡すSessionRunHookは以下のメソッドからなる。 before_run()で返すSessionRunArgsのfeed_dictはrunで渡すfeed_dictとmergeされ、 fetchesはrunするたびに毎回評価される。

class MySessionRunHook:
  def __init__(self, feed_dict):
    self.feed_dict = feed_dict
    self.a = tf.placeholder(tf.float32, name="a")
      
  def begin(self):
    """Called once before using the session."""
    print("begin")

  def after_create_session(self, session, coord):
    """Called when new TensorFlow session is created."""
    print("after_create_session")
      
  def before_run(self, run_context):
    """Called before each call to run()."""
    return SessionRunArgs(fetches={"a": self.a}, feed_dict=self.feed_dict)
  
  def after_run(self, run_context, run_values):
    """Called after each call to run()."""
    print("after_run {} {}".format(run_values.results["a"], run_context.session.run(self.a, feed_dict={self.a: 10})))

  def end(self, session):
    """Called at the end of session."""
    print("end")

したがって、hooks内のfetchesにfeedする必要があるものが含まれる場合、 runする対象がfeedする必要がない場合でもYou must feed a value for placeholder tensorのエラーが出ることになる。 また、hooks内とrunの引数のfeed_dictが衝突した場合もエラーになる。

with tf.Graph().as_default():
  global_step = tf.train.get_or_create_global_step()
  x = tf.placeholder(tf.float32, name="x")
  y = tf.placeholder(tf.float32, name="y")
  z = x + y
  hook1 = MySessionRunHook({x: 2})
  hook2 = MySessionRunHook({y: 3})
  hooks=[hook1, hook2]
  with tf.train.MonitoredTrainingSession(hooks=hooks) as sess:
    try:
      print(sess.run(global_step))
    except tf.errors.InvalidArgumentError as err:
      print(err) # You must feed a value for placeholder tensor 'a_1' with dtype float ...
    print(sess.run([global_step, z], feed_dict={hook1.a: 2, hook2.a: 3})) # [0, 5.0]
    try:
      print(sess.run(y, feed_dict={x: 10}))
    except RuntimeError as err:
      print(err) # Same tensor is fed by a SessionRunHook and user. Conflict(s): [<tf.Tensor 'x:0' shape=<unknown> dtype=float32>]     

上で書いたようにMonitoredTrainingSessionは内部でCheckpointSaverHookを持っていて、 これがfetchesでsummaryを返しているので、一つでもsummaryを入れるとrunするときにそれに必要なfeed_dictを渡さないとエラーになる。

with tf.Graph().as_default():
  global_step = tf.train.get_or_create_global_step()
  x = tf.placeholder(tf.float32, name="x")
  y = tf.placeholder(tf.float32, name="y")
  z = x + y
  tf.summary.scalar("z", z)
  with tf.train.MonitoredTrainingSession(checkpoint_dir="./aaa", save_summaries_steps=1) as sess:
    # ng
    # sess.run(global_step) => You must feed a value for placeholder tensor 'y' with dtype float
    # sess.run(z, feed_dict={x: 10, y:20})
    step, zz = sess.run([global_step, z], feed_dict={x: 10, y:20}) # ok