TensorFlowのtf.data API
tensorflowデータを読み込み変換してイテレートする入力パイプラインを作るAPI。 通常、学習にGPU/TPUを使う場合CPU処理の間はアイドル状態となりボトルネックになるが、 パイプライン処理を行うことでCPUとGPU/TPUがなるべくアイドル状態にならないようになり、 学習時間が短縮される。
Datasetの作成
from_tensor_slices()でDatasetを作成する。
dataset = tf.data.Dataset.from_tensor_slices(
{"a": tf.random_uniform([4]),
"b": tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)})
print(dataset.output_types) # {'a': tf.float32, 'b': tf.int32}
print(dataset.output_shapes) # {'a': TensorShape([]), 'b': TensorShape([Dimension(100)])}
引数にnumpyのndarrayを渡すとtf.constant()で変換されてグラフに乗る。
dataset = tf.data.Dataset.from_tensor_slices(np.arange(9).reshape((3, 3)))
print(dataset.output_types) # <dtype: 'int64'>
print(dataset.output_shapes) # (3,)
データが1GBを超える場合グラフのシリアライズ上限を超えてしまうことがある。後述するinitializableイテレータの初期化時にndarrayを渡すとこれを避けられる。
tf.contrib.data.CsvDatasetでCSVからDatasetを作ることもできる。
$ cat file1.csv
a,b,c,d
1,2,3,4
2,3,4,5
6,7,8,9
filenames = ["file1.csv"]
record_defaults = [tf.float32] * 2 # Two required float columns
dataset = tf.contrib.data.CsvDataset(filenames, record_defaults, header=True, select_cols=[1,3])
print(dataset.output_types) # (tf.float32, tf.float32)
print(dataset.output_shapes) # (TensorShape([]), TensorShape([]))
前処理
Datasetは mapや flat_map、 filterで変換でき、 batchを作ったり repeatしたり shuffleできる。
dataset = tf.data.Dataset.range(5).map(lambda d: d*2) # 0 2 4 6 8
dataset = tf.data.Dataset.from_tensor_slices((np.arange(4).reshape((2,2)))).flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x * 2)) # 0 2 4 6
dataset = tf.data.Dataset.range(5).filter(lambda d: tf.equal(tf.mod(d, 2), 0)) # 0 2 4
dataset = tf.data.Dataset.range(10).batch(4) # [0 1 2 3] [4 5 6 7] [8 9]
dataset = tf.data.Dataset.range(3).repeat(3) # 0 1 2 0 1 2 0 1 2
dataset = tf.data.Dataset.range(5).shuffle(buffer_size=10) # 1 0 3 4 2
Iteratorの作成
get_next()で次の要素が取れる。
one-shot
make_one_shot_iterator()で作れるone-shotイテレータはDatasetを一周イテレートする基本的なイテレータ。
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
try:
while True:
print(sess.run(next_element)) # 0 1 2, ..., 99
except tf.errors.OutOfRangeError:
pass
initializable
make_initializable_iterator()で単一Datasetから作れるinitializableイテレータはDatasetのplaceholderを初期化できる。
max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
sess.run(iterator.initializer, feed_dict={max_value: 10})
try:
while True:
print(sess.run(next_element)) # 0 1 2, ..., 9
except tf.errors.OutOfRangeError:
pass
print("---")
sess.run(iterator.initializer, feed_dict={max_value: 20})
try:
while True:
print(sess.run(next_element)) # 0 1 2, ..., 19
except tf.errors.OutOfRangeError:
pass
reinitializable
Iterator.from_structure()でtypeとshapeから作れるreinitializableイテレータは複数のDatasetで初期化できる。
training_dataset = tf.data.Dataset.from_tensor_slices(np.arange(9).reshape((3, 3)))
validation_dataset = tf.data.Dataset.from_tensor_slices(np.arange(9).reshape((3, 3)) * 2)
iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
training_dataset.output_shapes)
next_element = iterator.get_next()
training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)
with tf.Session() as sess:
sess.run(training_init_op)
try:
while True:
print(sess.run(next_element)) # [0 1 2] [3 4 5] [6 7 8]
except tf.errors.OutOfRangeError:
pass
print("---")
sess.run(validation_init_op)
try:
while True:
print(sess.run(next_element)) # [0 2 4] [ 6 8 10] [12 14 16]
except tf.errors.OutOfRangeError:
pass
feedable
Iterator.from_string_handle()でplaceholderとtypeとshapeから作れるfeedableイテレータは reinitializableと同じく複数のDatasetを切り替えることができるが、 初期化はせずrunごとにDatasetのhandleをplaceholderの値として渡せる。
training_dataset = tf.data.Dataset.from_tensor_slices(np.arange(9).reshape((3, 3)))
validation_dataset = tf.data.Dataset.from_tensor_slices(np.arange(9).reshape((3, 3)) * 2)
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_one_shot_iterator()
with tf.Session() as sess:
sess.run(training_init_op)
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())
try:
while True:
print(sess.run(next_element, feed_dict={handle: training_handle})) # [0 1 2] [3 4 5] [6 7 8]
except tf.errors.OutOfRangeError:
pass
try:
while True:
print(sess.run(next_element, feed_dict={handle: validation_handle})) # [0 2 4] [ 6 8 10] [12 14 16]
except tf.errors.OutOfRangeError:
pass