在使用TensorFlow进行模型训练的时候,我们一般不会在每一步训练的时候输入所有训练样本数据,而是通过batch的方式,每一步都随机输入少量的样本数据,这样可以防止过拟合。
所以,对训练样本的shuffle和batch是很常用的操作。
这里再说明一点,为什么需要打乱训练样本即shuffle呢?
举个例子:比如我们在做一个分类模型,前面部分的样本的标签都是A,后面部分的样本的标签全是B,那你如果不打乱样本顺序的话,就会出现前面训练出来的模型,在预测的时候会偏向于输出A,因为模型一直在标签A的方向拟合,而后面的模型,会偏向于预测B
直接看代码例子,有详细注释!!
<code>import
tensorflow as tf
import
numpy as np
d
=np.arange(0,60).reshape([6, 10])
data
=tf.data.Dataset.from_tensor_slices(d)
data
=data.shuffle(buffer_size=3)
data
=data.batch(4)
data
=data.repeat(2)
iters
=data.make_one_shot_iterator()
batch
=iters.get_next()
sess
=tf.Session()
sess.run(batch)
/<code>
<code>In
[21]:
d
Out[21]:
array([[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
],
[10,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
],
[20,
21
,
22
,
23
,
24
,
25
,
26
,
27
,
28
,
29
],
[30,
31
,
32
,
33
,
34
,
35
,
36
,
37
,
38
,
39
],
[40,
41
,
42
,
43
,
44
,
45
,
46
,
47
,
48
,
49
],
[50,
51
,
52
,
53
,
54
,
55
,
56
,
57
,
58
,
59
]])
In
[22]:
sess.run(batch)
Out[22]:
array([[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
],
[30,
31
,
32
,
33
,
34
,
35
,
36
,
37
,
38
,
39
],
[20,
21
,
22
,
23
,
24
,
25
,
26
,
27
,
28
,
29
],
[10,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
]])
In
[23]:
sess.run(batch)
Out[23]:
array([[40,
41
,
42
,
43
,
44
,
45
,
46
,
47
,
48
,
49
],
[50,
51
,
52
,
53
,
54
,
55
,
56
,
57
,
58
,
59
]])
/<code>
从输出结果可以看出:
- shuffle是按顺序将数据放入buffer里面的;
- 当repeat函数在shuffle之后的话,是将一个epoch的数据集抽取完毕,再进行下一个epoch的。
那么,当repeat函数在shuffle之前会怎么样呢?如下:
<code>data
=data
.repeat(2
) data
=data
.shuffle(buffer_size=3
) data
=data
.batch(4
)/<code>
<code>In
[25]:
sess.run(batch)
Out[25]:
array([[10,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
],
[20,
21
,
22
,
23
,
24
,
25
,
26
,
27
,
28
,
29
],
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
],
[40,
41
,
42
,
43
,
44
,
45
,
46
,
47
,
48
,
49
]])
In
[26]:
sess.run(batch)
Out[26]:
array([[50,
51
,
52
,
53
,
54
,
55
,
56
,
57
,
58
,
59
],
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
],
[30,
31
,
32
,
33
,
34
,
35
,
36
,
37
,
38
,
39
],
[30,
31
,
32
,
33
,
34
,
35
,
36
,
37
,
38
,
39
]])
In
[27]:
sess.run(batch)
Out[27]:
array([[10,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
],
[50,
51
,
52
,
53
,
54
,
55
,
56
,
57
,
58
,
59
],
[20,
21
,
22
,
23
,
24
,
25
,
26
,
27
,
28
,
29
],
[40,
41
,
42
,
43
,
44
,
45
,
46
,
47
,
48
,
49
]])
/<code>
可以看出,其实它就是先将数据集复制一遍,然后把两个epoch当成同一个新的数据集,一直shuffle和batch下去。