Batch Normalization的TensorFlow實現

tf.nn.moments函數

函數定義如下:

def moments(x, axes, name=None, keep_dims=False)

1.函數的輸入

x: 輸入數據,格式一般為:[batchsize, height, width, kernels]

axes: List,在哪個維度上計算,比如:[0, 1, 2]

name: 操作的名稱

keep_dims: 是否保持維度

2.函數的輸出

mean: 均值

variance: 方差

3.使用舉例

img = tf.Variable(tf.random_normal([128, 32, 32, 64]))

axis = list(range(len(img.get_shape()) - 1))

mean, variance = tf.nn.moments(img, axis)

tf.nn.batch_normalization函數

函數定義如下:

def batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name=None)

在使用batch_normalization的時候,需要去除網絡中的bias。

1.函數的輸入

x: 輸入的Tensor數據

mean: Tensor的均值

variance: Tensor的方差

offset: offset Tensor, 一般初始化為0,可訓練

scale: scale Tensor,一般初始化為1,可訓練

variance_epsilon: 一個小的浮點數,避免除數為0,一般取值0.001

name: 操作的名稱

2.算法原理

Batch Normalization的TensorFlow實現

李宏毅深度學習筆記

使用示例

Batch Normalization的TensorFlow實現


分享到:


相關文章: