一文了解Batch Normalization

數據歸一化、正則化是非常重要的步驟,用於重新縮放輸入的數值,確保在反向傳播期間更好的收斂。一般來說,採用的方法是減去平均值在除以標準差。如果不這樣做,某些特徵的量綱比較大,在cost函數中將得到更大的加權。數據歸一化使得所有特徵的權重相等,量綱相同。

對於網絡的輸入,我們用tanh這個激活函數來舉例。如果不做normalization,那麼當輸入的絕對值比較大的時候,tanh激活函數的輸入會比較大,那麼tanh就會處於飽和階段,此時的神經網絡對於輸入不在敏感,我們的網絡基本上已經學不到東西了,激活函數的輸出不是+1,就是-1.

一文了解Batch Normalization

tanh激活函數的導數:

一文了解Batch Normalization

當輸入比較大的時候,tanh處於飽和階段導數幾乎為0,神經網絡幾乎學習不到東西。

而且,更麻煩的是:這種問題不僅僅出現在輸入層,在隱藏層中同樣有這個問題

Batch Normalization

總結:BN實際上是用來解決反向傳播中的梯度問題的,克服神經網絡難以訓練的弊病。BN解決了在反向傳播過程中的梯度問題(梯度消失和爆炸),同時使得不同scale的w整體更新步調一致。

BN使得,雖然更新了W的值,但是在反向傳播的時候,乘以的數不再和w的尺度相關。即尺度較大的w將獲得一個較小的梯度,在同等學習速率下其獲得的更新更少,這樣使得整體w的更新更加穩健起來。

原理

訓練深層神經網絡是複雜的,因為每一層的輸入分佈在訓練期間隨著前一層參數的改變而改變。

BN的思想是:以這樣一種方式對每一層的輸入進行歸一化,

即他們具有0的平均輸出激活和1的標準差。這是針對每一層的每一個單獨的微批次進行的,即單獨計算該微批次的平均值和方差,然後歸一化。

這類似於網絡輸入的標準化。有什麼好處那?我們知道,對網絡輸入進行規範化有助於網絡學習。但是網絡只是一系列層,其中一層的輸出會作為下一層的輸入。這意味著我們可以把上一層的輸入看做後面一個較小網絡的的第一層。在應用激活函數之前,先對一個層的輸出進行歸一化,然後再將其輸入到下一層(子網絡).

另外一個角度:

大家知道,在機器學習中一個經典假設就是源空間source domain 和 目標空間target domain 的數據分佈是一致的。 如果不一致,那麼就出現了新的機器學習問題,比如transfer learning。

convariate shift就是源空間和目標空間分佈不一致假設下的一個分支問題:它指源空間和目標空間的條件概率是一致的,但是邊緣概率不同。也就是:

一文了解Batch Normalization

仔細想想,在神經網絡中是不是這樣的那?是的!對於神經網絡的各層輸出,由於他們經過了全連接+激活函數的作用,其分佈顯然和各層的輸入信號分佈不同,而且差異會隨著網絡的增加而增大,但是對應的label是沒有變的,這邊符合了convariate shift的定義。

google的論文中提到的另一個關鍵詞Internal Covariate Shift就是這個意思:每一層網絡的輸入都會因為前一層網絡參數的變化導致其分佈發生改變,這就要求我們必須使用一個很小的學習率和對參數很好的初始化,但是這樣麼做會讓訓練過程變得很慢而且複雜。 這種現象就被作者稱為Internal Covariate Shift。通過BN可以很好的解決這個問題,並且前提是用mini-batch來訓練神經網絡。

internal應該是指對於每一個隱藏層的分析。那麼通過mini-batch來固定所有層的均值和方差就行了嗎?google說這樣是可以解決問題的,實際上均值方差相同,分佈還真不一定相同。我想,ICS只是這個問題的一種概括的說法吧,屬於一種high-level demonstration.

BN添加位置

BN被添加在全連接層和激活函數之間

。 輸入x和w相乘,經過Batch Normalization,然後再輸入到激活函數。

效果

當然,google論文指出BN會帶來大約30%的額外計算開銷

對比添加batch normalization和不添加的網絡,激活函數的輸入分佈

一文了解Batch Normalization

可以發現,激活函數的輸入分佈集中在

敏感區域,對應激活函數的導數也比較大,非常有利於網絡的學習。

經過激活函數後,激活函數輸出分佈

一文了解Batch Normalization

可以發現,經過batch normalization之後,激活函數的輸出比較平緩,在飽和階段和激活階段都有很多值。對於後面每一層都做這樣的操作,那麼每一層的輸出都會有這樣一個分佈,更有效的利用tanh的非線性變換,更有利於神經網絡的學習。

對比來看,沒有使用BN的話,激活函數的輸出基本上只有兩端最極端的值,都處於飽和階段,要麼為+1,要麼為-1,這樣的神經網絡基本上已經學不到東西了。

最後,讓我們通過一個三層神經網絡來看看batch normalization後,各層激活函數輸出的分佈來直觀感受下BN的效果:

一文了解Batch Normalization

很明顯,BN是的每一層的數據輸入都是更加有意義的

,讓每一層的值都在一個範圍內更加有效的傳遞下去。

反 normalize

一文了解Batch Normalization

在BN的論文中,不僅僅有batch normalization還有一個步驟,我們稱之為反normalize,目的是用於抵消BN的操作

為什麼要抵消那?BN作者希望,神經網絡自己可以通過學習,調節gamma和belt來學習

出前面的BN到底有沒有起到優化的作用,如果沒有起到作用,那麼通過gamma和belt來進行一些抵消操作。

從另外一個角度來看:最後的scale and shift操作是為了讓因訓練而特意引入的BN能夠有可能還原最初的輸入。即當

一文了解Batch Normalization

從而保證整個network的capacity。

關於capacity的解釋:實際上BN是在原模型基礎上增加的新操作,這個新操作很大可能會改變某層原來的輸入,當然也可能不改變,不改變的時候就是還原原來的輸入。如此一來,既可以改變同時也可以不改變,那麼模型的容納能力(capacity)就提升了。

再從這個角度試著理解下:

BN算法強行將數據拉到均值為0,方差為1的比較標準的正態分佈上來。但是這樣導致了一個問題:只利用了線性區域而導致深層網絡無意義,使得模型的表達能力下降。為了保證非線性的獲得,所以用scale和shift來伸縮或移動數據

代碼實踐

數據分佈:

一文了解Batch Normalization

中間過程激活每一層激活函數的輸入

剛開始時的輸入:

一文了解Batch Normalization

運行一段時間之後各層的輸入:

一文了解Batch Normalization

可以看到如果沒有BN,那麼很快輸入就只有0了。如果使用了BN,那麼輸入分佈在0-1之間,分佈比較好。使用的激活函數是ReLU

損失函數:

當使用ReLU激活函數時,沒有BN的話,很快網絡就學不到東西了,自然也就沒了Loss。使用BN後發現loss還是比較低的。

一文了解Batch Normalization

我們把激活函數換成tanh再試試:

剛開始各個網絡層的輸入:

一文了解Batch Normalization

運行一段時間後的輸入:

一文了解Batch Normalization

可以發現,如果不使用BN,那麼運行一段時間之後,激活函數的輸出就只剩下-1和+1了。這樣的網絡基本上學不到什麼東西,已經死掉了。但是使用BN的,激活函數的輸出依舊是分佈在-1和+1之間,比較理想。

損失函數:

一文了解Batch Normalization

代碼

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
ACTIVATION = tf.nn.tanh

N_LAYERS = 7
N_HIDDEN_UNITS = 30
def fix_seed(seed=1):
np.random.seed(seed)
tf.set_random_seed(seed)
def plot_his(inputs, inputs_norm):
# plot histogram for the inputs of every layer
for j, all_inputs in enumerate([inputs, inputs_norm]):
for i, input in enumerate(all_inputs):
plt.subplot(2, len(all_inputs), j*len(all_inputs)+(i+1))
plt.cla()
if i == 0:
the_range = (-7, 10)
else:
the_range = (-1, 1)
plt.hist(input.ravel(), bins=15, range=the_range, color='#FF5733')
plt.yticks(())
if j == 1:
plt.xticks(the_range)
else:
plt.xticks(())
ax = plt.gca()
ax.spines['right'].set_color('none')
ax.spines['top'].set_color('none')
plt.title("%s normalizing" % ("Without" if j == 0 else "With"))
plt.draw()
plt.pause(0.01)
def built_net(xs, ys, norm):
def add_layer(inputs, in_size, out_size, activation_function=None, norm=False):
Weights = tf.Variable(tf.random_normal([in_size, out_size], mean=0.0, stddev=1.0))
biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
# Full Connect
Wx_plus_b = tf.matmul(inputs, Weights) + biases
# Batch Normalization
if norm:
fc_mean, fc_var = tf.nn.moments(Wx_plus_b, axes=[0])
scale = tf.Variable(tf.ones([out_size]))
shift = tf.Variable(tf.zeros([out_size]))
epsilon = 0.001
# apply moving average for mean and var when train on batch
ema = tf.train.ExponentialMovingAverage(decay=0.5)
def mean_var_with_update():
ema_apply_op = ema.apply([fc_mean, fc_var])
with tf.control_dependencies([ema_apply_op]):
return tf.identity(fc_mean), tf.identity(fc_var)
mean, var = mean_var_with_update()
# Wx_plus_b = (Wx_plus_b - fc_mean) / tf.sqrt(fc_var + 0.001)
# Wx_plus_b = Wx_plus_b * scale + shift
Wx_plus_b = tf.nn.batch_normalization(Wx_plus_b, mean, var, shift, scale, epsilon)
# Activation

if activation_function == None:
outputs = Wx_plus_b
else:
outputs = activation_function(Wx_plus_b)
return outputs
if norm:
# BN for first layer
fc_mean, fc_var = tf.nn.moments(xs, axes=[0])
scale = tf.Variable(tf.ones([1]))
shift = tf.Variable(tf.zeros([1]))
epsilon = 0.001
ema = tf.train.ExponentialMovingAverage(decay=0.5)
def mean_var_with_update():
ema_apply_op = ema.apply([fc_mean, fc_var])
with tf.control_dependencies([ema_apply_op]):
return tf.identity(fc_mean), tf.identity(fc_var)
mean, var = mean_var_with_update()
xs = tf.nn.batch_normalization(xs, mean, var, shift, scale, epsilon)
# record inputs for every layer
layers_inputs = [xs]
# build hidden layer
for layer_idx in range(N_LAYERS):
layer_input = layers_inputs[layer_idx]
in_size = layer_input.get_shape()[1].value
output = add_layer(
inputs = layer_input,
in_size = in_size,
out_size = N_HIDDEN_UNITS,
activation_function = ACTIVATION,
norm = norm
)
layers_inputs.append(output)
# build output layer
prediction = add_layer(layers_inputs[-1], N_HIDDEN_UNITS, 1, activation_function=None)
cost = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), reduction_indices=[1]))
train_op = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(loss=cost)
return train_op, cost, layers_inputs
# main progress
if __name__ == '__main__':
fix_seed(2018)
# make up data
x_data = np.linspace(start=-7, stop=10, num=2500)[:, np.newaxis]
np.random.shuffle(x_data)
noise = np.random.normal(loc=0, scale=8, size=x_data.shape)
y_data = np.square(x_data) - 5 + noise
# plot input data
plt.scatter(x=x_data, y=y_data)
plt.show()
# prepare tf
xs = tf.placeholder(tf.float32, [None, 1])

ys = tf.placeholder(tf.float32, [None, 1])
train_op, cost, layers_inputs = built_net(xs, ys, norm=False)
train_op_norm, cost_norm, layers_inputs_norm = built_net(xs, ys, norm=True)
# init tf
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
# record cost
cost_his = []
cost_his_norm = []
plt.ion() # 打開交互模式
plt.figure(figsize=(7,3))
for i in range(250): #[0,249]
print(i)
if i % 50 == 0: #plot histgram
all_inputs, all_inputs_norm = sess.run([layers_inputs, layers_inputs_norm], feed_dict={xs: x_data, ys:y_data})
plot_his(all_inputs, all_inputs_norm)
# train on batch
sess.run([train_op, train_op_norm], feed_dict={xs:x_data, ys:y_data})
# record cost
cost_his.append(sess.run(cost, feed_dict={xs:x_data, ys:y_data}))
cost_his_norm.append(sess.run(cost_norm, feed_dict={xs:x_data, ys:y_data}))
plt.ioff()
plt.figure()
plt.plot(np.arange(len(cost_his)), np.array(cost_his), label='no BN')
plt.plot(np.arange(len(cost_his)), np.array(cost_his_norm), label='BN')
plt.legend()
plt.show()


分享到:


相關文章: