TensorFlowのモデルをsave/loadする

(2018-06-22)

SavedModelBuilderで モデルを言語に依存しないSavedModelのprotobufにして保存できる。 SavedModelにはSaverによって出力されるCheckpointを共有する一つ以上のMetaGraphDef含む

import tensorflow as tf

def build_signature(signature_inputs, signature_outputs):
    return tf.saved_model.signature_def_utils.build_signature_def(
        signature_inputs, signature_outputs,
        tf.saved_model.signature_constants.REGRESS_METHOD_NAME)

def save(sess, export_dir, signature_def_map):
    builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
    builder.add_meta_graph_and_variables(
          sess, [tf.saved_model.tag_constants.SERVING],
          signature_def_map=signature_def_map
    )
    builder.save()

import shutil
import os.path
export_dir = "./saved_model"
if os.path.exists(export_dir):
    shutil.rmtree(export_dir)
    
with tf.Graph().as_default():
    a = tf.placeholder(tf.float32, name="a")
    b = tf.placeholder(tf.float32, name="b")
    c = tf.add(a, b, name="c")

    v = tf.placeholder(tf.float32, name="v")
    w = tf.Variable(0.0, name="w")
    x = w.assign(tf.add(v, w))
    
    sv = tf.train.Supervisor()
    with sv.managed_session() as sess:
        print(sess.run(c, feed_dict={a: 1, b: 2})) # 3.0
        print(sess.run(x, feed_dict={v: 2})) # 2.0
        print(sess.run(x, feed_dict={v: 3})) # 5.0
        # https://github.com/tensorflow/tensorflow/issues/11549
        sess.graph._unsafe_unfinalize()
        save(sess, export_dir, {
            "add": build_signature({
                "a": tf.saved_model.utils.build_tensor_info(a),
                "b":tf.saved_model.utils.build_tensor_info(b)
            }, {
                "c": tf.saved_model.utils.build_tensor_info(c)
            }),
             "accumulate": build_signature({
                "v": tf.saved_model.utils.build_tensor_info(v),
            }, {
                "x": tf.saved_model.utils.build_tensor_info(x)
            })
        })
$ ls saved_model/
saved_model.pb  variables

loadしてsess.runできる。variableの値も保存されている。

with tf.Graph().as_default():
    with tf.Session() as sess:
        meta_graph_def = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], export_dir)
        print(sess.run(
            meta_graph_def.signature_def["accumulate"].outputs["x"].name, # Assign:0
            feed_dict={
                meta_graph_def.signature_def["accumulate"].inputs["v"].name: 3, # v:0
            }
        )) # 8.0