Distributed TensorFlow
クラスタを組んでGraphを分散実行する。
クラスタは
- master: sessionを作成し、workerを制御する
- worker: 計算を行う
- ps(parameter server): 変数の値を持ち、更新する
のjobからなり、gRPCの
でやり取りする。
TensorFlow serverを立てる
各jobとURLのmapをClusterSpecにして jobとindexと併せてServerDefを作って Serverを立てる。
{
"master": [
"check-tf-config-master-34z8-0:2222"
],
"ps": [
"check-tf-config-ps-34z8-0:2222",
"check-tf-config-ps-34z8-1:2222"
],
"worker": [
"check-tf-config-worker-34z8-0:2222",
"check-tf-config-worker-34z8-1:2222"
]
}
cluster_spec_object = tf.train.ClusterSpec(cluster_spec)
server_def = tf.train.ServerDef(
cluster=cluster_spec_object.as_cluster_def(),
protocol="grpc",
job_name=job_name, # worker, master, ps
task_index=0)
server = tf.train.Server(server_def)
psのjobではserver.join()
して待ち構える。
if job_name == "ps":
server.join()
else:
# build model
WorkerにGraphを割り当てる
workerのdeviceにGraphを割り当てる。
deviceは/job:worker/replica:0/task:0/device:GPU:0
のようなフォーマットで表される。
Graphの持ち方には一つのGraphの異なる計算箇所をそれぞれのworkerが持つIn-graph replicationと、 それぞれGraphを持つBetween-graph replicationがある。これは後者の例で、replica_device_setterによってラウンドロビンで各psに変数を配置する。
with tf.device(tf.train.replica_device_setter(
cluster=cluster_spec,
worker_device=device
)):
# graph
SyncReplicasOptimizerのhookを追加
同期して変数を更新する場合、SyncReplicasOptimizerを使い、make_session_run_hookで作られるhookをMonitoredTrainingSessionに渡す。
TensorFlowのMonitoredSessionとSessionRunHookとsummaryのエラー - sambaiz-net
train_op = tf.train.SyncReplicasOptimizer(
tf.train.AdamOptimizer(self.learning_rate),
replicas_to_aggregate=self.worker_num,
total_num_replicas=self.worker_num)
hooks = [
tf.train.StopAtStepHook(last_step=args.last_step),
tf.train.CheckpointSaverHook(
'./ckpt',
save_steps=args.save_steps,
saver=saver),
train_op.make_session_run_hook(self.is_chief)
]
with tf.train.MonitoredTrainingSession(
is_chief=is_chief,
master=master,
hooks=hooks
) as sess:
while not sess.should_stop():
# sess.run()
SavedModelの出力
実行後、SavedModelを出力する。
TensorFlowのモデルをsave/loadする - sambaiz-net
MonitoredTrainingSessionのsessはshould_stop()
がtrueになったあとは使えなくなるため、
新しいsessionを作るかhookのendでする必要がある。
新しいsessionでする例。そのままrestoreするとすでにworkerが終了している場合に、そのdeviceの変数もrestoreしようとして失敗するので import_meta_graphのclear_devicesをTrueにしている。
with tf.Graph().as_default():
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state("ckpt")
saver = tf.train.import_meta_graph(
'{}.meta'.format(ckpt.model_checkpoint_path),
clear_devices=True)
saver.restore(sess, ckpt.model_checkpoint_path)
save.save(sess, "saved", signature_def_map)
hookでする例。restoreする必要がなくて良い。
class SavedModelBuilderHook(session_run_hook.SessionRunHook):
def __init__(self, export_dir, signature_def_map, tags):
self.export_dir = export_dir
self.signature_def_map = signature_def_map
self.tags = tags
def end(self, session):
session.graph._unsafe_unfinalize()
builder = tf.saved_model.builder.SavedModelBuilder(self.export_dir)
builder.add_meta_graph_and_variables(
session,
self.tags,
signature_def_map=self.signature_def_map
)
builder.save()