TensorFlowのtf.data API

(2018-11-03)

Importing Data | TensorFlow

データを読み込み変換してイテレートする入力パイプラインを作るAPI。 通常、学習にGPU/TPUを使う場合CPU処理の間はアイドル状態となりボトルネックになるが、 パイプライン処理を行うことでCPUとGPU/TPUがなるべくアイドル状態にならないようになり、 学習時間が短縮される

Datasetの作成

from_tensor_slices()でDatasetを作成する。

dataset = tf.data.Dataset.from_tensor_slices(
   {"a": tf.random_uniform([4]),
    "b": tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)})
print(dataset.output_types) # {'a': tf.float32, 'b': tf.int32}
print(dataset.output_shapes) # {'a': TensorShape([]), 'b': TensorShape([Dimension(100)])}

引数にnumpyのndarrayを渡すとtf.constant()で変換されてグラフに乗る。

dataset = tf.data.Dataset.from_tensor_slices(np.arange(9).reshape((3, 3)))
print(dataset.output_types) # <dtype: 'int64'>
print(dataset.output_shapes) # (3,)

データが1GBを超える場合グラフのシリアライズ上限を超えてしまうことがある。後述するinitializableイテレータの初期化時にndarrayを渡すとこれを避けられる。

tf.contrib.data.CsvDatasetでCSVからDatasetを作ることもできる。

$ cat file1.csv
a,b,c,d
1,2,3,4
2,3,4,5
6,7,8,9
filenames = ["file1.csv"]
record_defaults = [tf.float32] * 2 # Two required float columns
dataset = tf.contrib.data.CsvDataset(filenames, record_defaults, header=True, select_cols=[1,3])
print(dataset.output_types) # (tf.float32, tf.float32)
print(dataset.output_shapes) # (TensorShape([]), TensorShape([]))

前処理

Datasetは mapflat_mapfilterで変換でき、 batchを作ったり repeatしたり shuffleできる。

dataset = tf.data.Dataset.range(5).map(lambda d: d*2) # 0 2 4 6 8
dataset = tf.data.Dataset.from_tensor_slices((np.arange(4).reshape((2,2)))).flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x * 2)) # 0 2 4 6
dataset = tf.data.Dataset.range(5).filter(lambda d: tf.equal(tf.mod(d, 2), 0)) # 0 2 4

dataset = tf.data.Dataset.range(10).batch(4) # [0 1 2 3] [4 5 6 7] [8 9]
dataset = tf.data.Dataset.range(3).repeat(3) # 0 1 2 0 1 2 0 1 2
dataset = tf.data.Dataset.range(5).shuffle(buffer_size=10) # 1 0 3 4 2

Iteratorの作成

get_next()で次の要素が取れる。

one-shot

make_one_shot_iterator()で作れるone-shotイテレータはDatasetを一周イテレートする基本的なイテレータ。

dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    try:
        while True:
            print(sess.run(next_element)) # 0 1 2, ..., 99
    except tf.errors.OutOfRangeError:
        pass

initializable

make_initializable_iterator()で単一Datasetから作れるinitializableイテレータはDatasetのplaceholderを初期化できる。

max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer, feed_dict={max_value: 10})
    try:
        while True:
            print(sess.run(next_element)) # 0 1 2, ..., 9
    except tf.errors.OutOfRangeError:
        pass
    print("---")
    sess.run(iterator.initializer, feed_dict={max_value: 20})
    try:
        while True:
            print(sess.run(next_element)) # 0 1 2, ..., 19
    except tf.errors.OutOfRangeError:
        pass

reinitializable

Iterator.from_structure()でtypeとshapeから作れるreinitializableイテレータは複数のDatasetで初期化できる。

training_dataset = tf.data.Dataset.from_tensor_slices(np.arange(9).reshape((3, 3)))
validation_dataset = tf.data.Dataset.from_tensor_slices(np.arange(9).reshape((3, 3)) * 2)

iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
                                           training_dataset.output_shapes)
next_element = iterator.get_next()
training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)

with tf.Session() as sess:
    sess.run(training_init_op)
    try:
        while True:
            print(sess.run(next_element)) # [0 1 2] [3 4 5] [6 7 8]
    except tf.errors.OutOfRangeError:
        pass
    print("---")
    sess.run(validation_init_op)
    try:
        while True:
            print(sess.run(next_element)) # [0 2 4] [ 6  8 10] [12 14 16]
    except tf.errors.OutOfRangeError:
        pass

feedable

Iterator.from_string_handle()でplaceholderとtypeとshapeから作れるfeedableイテレータは reinitializableと同じく複数のDatasetを切り替えることができるが、 初期化はせずrunごとにDatasetのhandleをplaceholderの値として渡せる。

training_dataset = tf.data.Dataset.from_tensor_slices(np.arange(9).reshape((3, 3)))
validation_dataset = tf.data.Dataset.from_tensor_slices(np.arange(9).reshape((3, 3)) * 2)

handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()

training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_one_shot_iterator()

with tf.Session() as sess:
    sess.run(training_init_op)
    training_handle = sess.run(training_iterator.string_handle())
    validation_handle = sess.run(validation_iterator.string_handle())
    try:
        while True:
            print(sess.run(next_element, feed_dict={handle: training_handle})) # [0 1 2] [3 4 5] [6 7 8]
    except tf.errors.OutOfRangeError:
        pass
    try:
        while True:
            print(sess.run(next_element, feed_dict={handle: validation_handle})) # [0 2 4] [ 6  8 10] [12 14 16]
    except tf.errors.OutOfRangeError:
        pass