Deep LearningのBatch Normalizationの効果をTensorFlowで確認する

(2018-11-14)

Batch Normalizationとは

Deep Learningでは各層の学習を同時に行うため、前の層の変更によって各層の入力の分布が変わってしまうinternal covariate shiftという現象が起こり、そのためにパラメータの初期化をうまくやる必要があったり、学習率を大きくできず多くのステップを要する。 以下の論文で発表されたBatch Normalization(BN)は各層の入力を正規化して分布を固定することでこれを解決するというもの。 画像認識のコンテストILSVRC 2015で1位を取ったResNet(Residual Network)でも使われている。

Sergey Ioffe, Christian Szegedy (2015) Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

具体的にはWx+bと活性化関数の間にBNの層を入れる。μ、σ^2は入力xの平均と分散。 単に正規化するだけでは表現力が下がってしまうのでγとβでスケールやシフトできるようにする。これらの変数は他のパラメータと同様に学習させる。

BN層の演算

TensorFlowでの確認

TensorFlowではbatch_normalization()がすでに実装されているのでこれを使う。

以下のCNNで学習率を高めに設定しBNありなしの結果を比較する。学習データはmnist。MonitoredSessionでcostをsummaryとして出力しTensorBoardで見られるようにしている。

TensorBoardでsummaryやグラフを見る - sambaiz-net

TensorFlowのMonitoredSessionとSessionRunHookとsummaryのエラー - sambaiz-net

import pandas as pd
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from sklearn.metrics import accuracy_score

train = pd.read_csv('./train.csv')
(x_train, x_valid ,y_train, y_valid) = train_test_split(
    train.drop('label', axis=1).values.reshape((-1, 28, 28, 1)), 
    np.identity(10)[train['label']], 
    test_size = 0.1, random_state = 100)

print("data shape: {}, label shape {}".format(x_train.shape, y_train.shape))

tf.reset_default_graph()

class ConvBnRelu:
    def __init__(self, filters, kernel_size):
        self.filters = filters
        self.kernel_size= kernel_size
        
    def __call__(self, x, use_bn, is_training):
        h = tf.layers.Conv2D(filters=self.filters, kernel_size=self.kernel_size)(x)
        h = tf.cond(
            use_bn,
            true_fn=lambda: tf.layers.batch_normalization(h, training=is_training),
            false_fn=lambda: h
        )
        return tf.nn.relu(h)

is_training = tf.placeholder(tf.bool, shape=())
use_bn = tf.placeholder(tf.bool, shape=())
x = tf.placeholder(shape=[None, 28, 28, 1], dtype=tf.float32)
t = tf.placeholder(tf.float32, [None, 10])

with tf.name_scope("Conv1"):
    h = ConvBnRelu(filters=32, kernel_size= [3, 3])(x, use_bn, is_training)
    h = tf.layers.MaxPooling2D(pool_size=[2, 2], strides=2)(h)

with tf.name_scope("Conv2"):
    h = ConvBnRelu(filters= 64, kernel_size= [3, 3])(h, use_bn, is_training)
    h = tf.layers.MaxPooling2D(pool_size=[2, 2], strides=2)(h)

h = tf.layers.Flatten()(h)
y = tf.layers.Dense(units=10, activation=tf.nn.softmax)(h)

global_step=tf.train.get_or_create_global_step()
cost = - tf.reduce_mean(tf.reduce_sum(t * tf.log(tf.clip_by_value(y, 1e-10, y)), axis=1))
summary_cost = tf.summary.scalar('cost', cost)
optimizer = tf.train.AdamOptimizer(0.01).minimize(cost, global_step=global_step)

EPOCH_NUM = 5
BATCH_SIZE = 100

hooks = [
    tf.train.StopAtStepHook(last_step=EPOCH_NUM*(len(x_train) // BATCH_SIZE))
]

init = tf.global_variables_initializer()

for bn in [False, True]:
    epoch = -1
    print("--- use binary norm: {} ---".format(bn))
    with tf.train.MonitoredTrainingSession(
        hooks=hooks, 
        summary_dir="/home/jovyan/summary",
        save_summaries_steps=100) as sess:    
        sess.run(init, feed_dict={x: x_valid, t: y_valid, is_training: False, use_bn: bn})
        while not sess.should_stop():
            epoch += 1
            y_pred, cost_valid, _ = sess.run([y, cost, summary_cost], feed_dict={x: x_valid, t: y_valid, is_training: False, use_bn: bn})
            print("epoch: {:2d}, cost: {:.4f}, accuracy: {:.4f}".format(
                epoch, cost_valid, accuracy_score(y_pred.argmax(axis=1), y_valid.argmax(axis=1))))
            x_train, y_train = shuffle(x_train, y_train, random_state=100)
            for batch in range(len(x_train) // BATCH_SIZE):
                start = batch * BATCH_SIZE
                end = start + BATCH_SIZE
                sess.run(optimizer, feed_dict={x: x_train[start:end], t: y_train[start:end], is_training: True, use_bn: bn})

BNを行わなかったときの結果。コストを下げられていない。ちなみに学習率を0.01から0.001にしたら下げられるようになった。

BNしない場合のコスト

一方、BNを行ったときの結果がこれ。学習率はそのままで順調にコストを下げることができている。

BNした場合のコスト