Batch Normalization的TensorFlow实现

tf.nn.moments函数

函数定义如下:

def moments(x, axes, name=None, keep_dims=False)

1.函数的输入

x: 输入数据,格式一般为:[batchsize, height, width, kernels]

axes: List,在哪个维度上计算,比如:[0, 1, 2]

name: 操作的名称

keep_dims: 是否保持维度

2.函数的输出

mean: 均值

variance: 方差

3.使用举例

img = tf.Variable(tf.random_normal([128, 32, 32, 64]))

axis = list(range(len(img.get_shape()) - 1))

mean, variance = tf.nn.moments(img, axis)

tf.nn.batch_normalization函数

函数定义如下:

def batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name=None)

在使用batch_normalization的时候,需要去除网络中的bias。

1.函数的输入

x: 输入的Tensor数据

mean: Tensor的均值

variance: Tensor的方差

offset: offset Tensor, 一般初始化为0,可训练

scale: scale Tensor,一般初始化为1,可训练

variance_epsilon: 一个小的浮点数,避免除数为0,一般取值0.001

name: 操作的名称

2.算法原理

Batch Normalization的TensorFlow实现

李宏毅深度学习笔记

使用示例

Batch Normalization的TensorFlow实现


分享到:


相關文章: