Destributed TensorFlowの流れとSavedModelの出力

tensorflowmachinelearning

Distributed TensorFlow

クラスタを組んでGraphを分散実行する。

クラスタは

  • master: sessionを作成し、workerを制御する
  • worker: 計算を行う
  • ps(parameter server): 変数の値を持ち、更新する

のjobからなり、gRPCの

でやり取りする。

TensorFlow serverを立てる

各jobとURLのmapをClusterSpecにして jobとindexと併せてServerDefを作って Serverを立てる。

{
    "master": [
        "check-tf-config-master-34z8-0:2222"
    ],
    "ps": [
        "check-tf-config-ps-34z8-0:2222",
        "check-tf-config-ps-34z8-1:2222"
    ],
    "worker": [
        "check-tf-config-worker-34z8-0:2222",
        "check-tf-config-worker-34z8-1:2222"
    ]
}
cluster_spec_object = tf.train.ClusterSpec(cluster_spec)
server_def = tf.train.ServerDef(
    cluster=cluster_spec_object.as_cluster_def(),
    protocol="grpc",
    job_name=job_name, # worker, master, ps 
    task_index=0)
server = tf.train.Server(server_def)

psのjobではserver.join()して待ち構える。

if job_name == "ps":
    server.join()
else:
    # build model

WorkerにGraphを割り当てる

workerのdeviceにGraphを割り当てる。 deviceは/job:worker/replica:0/task:0/device:GPU:0 のようなフォーマットで表される。

Graphの持ち方には一つのGraphの異なる計算箇所をそれぞれのworkerが持つIn-graph replicationと、 それぞれGraphを持つBetween-graph replicationがある。これは後者の例で、replica_device_setterによってラウンドロビンで各psに変数を配置する。

with tf.device(tf.train.replica_device_setter(
    cluster=cluster_spec,
    worker_device=device
)):
    # graph

SyncReplicasOptimizerのhookを追加

同期して変数を更新する場合、SyncReplicasOptimizerを使い、make_session_run_hookで作られるhookをMonitoredTrainingSessionに渡す。

TensorFlowのMonitoredSessionとSessionRunHookとsummaryのエラー - sambaiz-net

train_op = tf.train.SyncReplicasOptimizer(
    tf.train.AdamOptimizer(self.learning_rate),
    replicas_to_aggregate=self.worker_num,
    total_num_replicas=self.worker_num)

hooks = [
    tf.train.StopAtStepHook(last_step=args.last_step),
    tf.train.CheckpointSaverHook(
        './ckpt',
        save_steps=args.save_steps,
        saver=saver),
    train_op.make_session_run_hook(self.is_chief)
]
with tf.train.MonitoredTrainingSession(
    is_chief=is_chief,
    master=master,
    hooks=hooks
) as sess:
    while not sess.should_stop():
    # sess.run()

SavedModelの出力

実行後、SavedModelを出力する。

TensorFlowのモデルをsave/loadする - sambaiz-net

MonitoredTrainingSessionのsessはshould_stop()がtrueになったあとは使えなくなるため、 新しいsessionを作るかhookのendでする必要がある。

新しいsessionでする例。そのままrestoreするとすでにworkerが終了している場合に、そのdeviceの変数もrestoreしようとして失敗するので import_meta_graphのclear_devicesをTrueにしている。

with tf.Graph().as_default():
    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state("ckpt")
        saver = tf.train.import_meta_graph(
            '{}.meta'.format(ckpt.model_checkpoint_path),
            clear_devices=True)
        saver.restore(sess, ckpt.model_checkpoint_path)
        save.save(sess, "saved", signature_def_map)

hookでする例。restoreする必要がなくて良い。

class SavedModelBuilderHook(session_run_hook.SessionRunHook):
    def __init__(self, export_dir, signature_def_map, tags):
        self.export_dir = export_dir
        self.signature_def_map = signature_def_map
        self.tags = tags

    def end(self, session):
        session.graph._unsafe_unfinalize()
        builder = tf.saved_model.builder.SavedModelBuilder(self.export_dir)
        builder.add_meta_graph_and_variables(
            session,
            self.tags,
            signature_def_map=self.signature_def_map
        )
        builder.save()