TensorFlow中如何优雅给模型批量输入训练样本

在使用TensorFlow进行模型训练的时候,我们一般不会在每一步训练的时候输入所有训练样本数据,而是通过batch的方式,每一步都随机输入少量的样本数据,这样可以防止过拟合。

所以,对训练样本的shuffle和batch是很常用的操作。

这里再说明一点,为什么需要打乱训练样本即shuffle呢?

举个例子:比如我们在做一个分类模型,前面部分的样本的标签都是A,后面部分的样本的标签全是B,那你如果不打乱样本顺序的话,就会出现前面训练出来的模型,在预测的时候会偏向于输出A,因为模型一直在标签A的方向拟合,而后面的模型,会偏向于预测B

直接看代码例子,有详细注释!!

<code>

import

tensorflow as tf

import

numpy as np

d

=

np.arange(0,60).reshape([6, 10])

data

=

tf.data.Dataset.from_tensor_slices(d)

data

=

data.shuffle(buffer_size=3)

data

=

data.batch(4)

data

=

data.repeat(2)

iters

=

data.make_one_shot_iterator()

batch

=

iters.get_next()

sess

=

tf.Session()

sess.run(batch)

/<code>
<code>

In

[21]:

d

Out[21]:

array([[

0

,

1

,

2

,

3

,

4

,

5

,

6

,

7

,

8

,

9

],

     

[10,

11

,

12

,

13

,

14

,

15

,

16

,

17

,

18

,

19

],

     

[20,

21

,

22

,

23

,

24

,

25

,

26

,

27

,

28

,

29

],

     

[30,

31

,

32

,

33

,

34

,

35

,

36

,

37

,

38

,

39

],

     

[40,

41

,

42

,

43

,

44

,

45

,

46

,

47

,

48

,

49

],

     

[50,

51

,

52

,

53

,

54

,

55

,

56

,

57

,

58

,

59

]])

In

[22]:

sess.run(batch)

Out[22]:

array([[

0

,

1

,

2

,

3

,

4

,

5

,

6

,

7

,

8

,

9

],

     

[30,

31

,

32

,

33

,

34

,

35

,

36

,

37

,

38

,

39

],

     

[20,

21

,

22

,

23

,

24

,

25

,

26

,

27

,

28

,

29

],

     

[10,

11

,

12

,

13

,

14

,

15

,

16

,

17

,

18

,

19

]])

In

[23]:

sess.run(batch)

Out[23]:

array([[40,

41

,

42

,

43

,

44

,

45

,

46

,

47

,

48

,

49

],

     

[50,

51

,

52

,

53

,

54

,

55

,

56

,

57

,

58

,

59

]])

/<code>

从输出结果可以看出:

  1. shuffle是按顺序将数据放入buffer里面的;
  2. 当repeat函数在shuffle之后的话,是将一个epoch的数据集抽取完毕,再进行下一个epoch的。

那么,当repeat函数在shuffle之前会怎么样呢?如下:

<code>

data

=

data

.repeat(

2

) ​

data

=

data

.shuffle(buffer_size=

3

) ​

data

=

data

.batch(

4

)/<code>
<code>

In

[25]:

sess.run(batch)

Out[25]:

array([[10,

11

,

12

,

13

,

14

,

15

,

16

,

17

,

18

,

19

],

     

[20,

21

,

22

,

23

,

24

,

25

,

26

,

27

,

28

,

29

],

     

[

0

,

1

,

2

,

3

,

4

,

5

,

6

,

7

,

8

,

9

],

     

[40,

41

,

42

,

43

,

44

,

45

,

46

,

47

,

48

,

49

]])

In

[26]:

sess.run(batch)

Out[26]:

array([[50,

51

,

52

,

53

,

54

,

55

,

56

,

57

,

58

,

59

],

     

[

0

,

1

,

2

,

3

,

4

,

5

,

6

,

7

,

8

,

9

],

     

[30,

31

,

32

,

33

,

34

,

35

,

36

,

37

,

38

,

39

],

     

[30,

31

,

32

,

33

,

34

,

35

,

36

,

37

,

38

,

39

]])

In

[27]:

sess.run(batch)

Out[27]:

array([[10,

11

,

12

,

13

,

14

,

15

,

16

,

17

,

18

,

19

],

     

[50,

51

,

52

,

53

,

54

,

55

,

56

,

57

,

58

,

59

],

     

[20,

21

,

22

,

23

,

24

,

25

,

26

,

27

,

28

,

29

],

     

[40,

41

,

42

,

43

,

44

,

45

,

46

,

47

,

48

,

49

]])

/<code>

可以看出,其实它就是先将数据集复制一遍,然后把两个epoch当成同一个新的数据集,一直shuffle和batch下去。


分享到:


相關文章: