深度學習-LSTM算法實現(MNIST手寫數字識別)

MNIST數據集是機器學習入門的經典數據集,本文將以MNIST手寫數字數據集為例,使用深度學習方法,訓練手寫數字識別模型。對想進行深度學習的同學來說是非常好的練手例子,全文代碼關鍵點都有註釋,自行練習時可以嘗試修改其中的迭代次數和訓練精度,以感受其訓練過程。

MNIST數據集鏈接:http://yann.lecnn.com/exdb/mnist ,共包括4個(.gz)壓縮文件。

深度學習-LSTM算法實現(MNIST手寫數字識別)

MNIST官方網站內容

下載完之後在home或其他地方新建文件夾,

使用:gzip -d [filename] 指令依次解壓4個文件

代碼環境:Ubuntu18.04,Pycharm,TensorFolw2.0

深度學習-LSTM算法實現(MNIST手寫數字識別)

手寫數字圖片示例,表示第21429張,數字為1

接下來進入正題,檢查數據集,查看數據類型等,為訓練做準備。

第1步 查看數據集的內容及大小。

<code>from tensorflow.examples.tutorials.mnist import input_data/<code>

Tensorflow中對MNIST數據集專門的封裝,方便數據處理。

<code>data_dir = "/home/name/Desktop/mnist1"
mnist = input_data.read_data_sets(data_dir,one_hot=True)/<code>

分別為數據集文件路徑和讀取MNIST數據的函數。

<code>print(mnist.train.images.shape)  #訓練數據大小
print(mnist.train.labels.shape) #標籤
print(mnist.test.images.shape)
print(mnist.test.labels.shape)/<code>

如果以上配置正確則不會報錯,並輸出以下結果(注意運行中會調用Tensorflow出現警告等信息可以忽略):

(55000,784)

(55000,10)

(10000,784)

(10000,10)

以上數據表示訓練集手寫數字圖片有55000張,大小為784(28*28)個像素點,標籤為(0-9)10個數字,類型同樣於測試集。但需要注意的是他們都是將圖片展開後的一維向量。

第2步 查看數據集

如果想查看一下里面某張圖片的數字是多少?他的標籤是多少?這裡也有一小段代碼可以順序查看所有數據集中的圖片,並把圖片的編號和代表的數字顯示出來,這裡要用到matplotlib圖形庫函數和numpy函數庫(需注意參數給小一些,看幾張就可以了),代碼如下:

<code>from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
import numpy as np
import random
data_dir = "/home/quan/Desktop/mnist1"
mnist = input_data.read_data_sets(data_dir,one_hot=True)
n = 20
for i in range(n):
print("The number of picture is %i !"%i)
plt.imshow(mnist.train.images[i].reshape((28,28)),cmap='gray')
plt.title("%i"%np.argmax(mnist.train.labels[i]))
print(np.argmax(mnist.train.labels[i]))
time.sleep(1)
plt.show()
print("Finished!")/<code>

其中函數np.argmax()是取一組數據中的最大值。

第3步 使用RNN循環神經網絡訓練模型

由於循環神經網絡每個時刻讀取圖片中的1行,即每個時刻需要讀取的數據向量長度為28,那麼讀完整張圖片需要讀取28行。

LSTM結構搭建:

(1) 定義輸入、輸出placeholder:

<code>tf_x = tf.placeholder(tf.float32,[None,TIME_STEP*INPUT_SIZE])
image = tf.reshape(tf_x,[-1,TIME_STEP,INPUT_SIZE])
tf_y = tf.placeholder(tf.int32,[None,10])/<code>

(2) 定義LSTM結構:

<code>rnn_cell = tf.contrib.rnn.BasicLSTMCell(num_units=64)
outputs,(h_c,h_n) = tf.nn.dynamic_rnn(
rnn_cell,
image,
initial_state = None,
dtype = tf.float32,
time_major =False
)
output = tf.layers.dense(outputs[:,-1,:],10)/<code>

(3) 定義代價函數:

<code>loss = tf.losses.softmax_cross_entropy(onehot_labels=tf_y,logits=output)/<code>

(4) 定義訓練過程及訓練精度:

<code>LR =0.01  #定義學習效率
train_op = tf.train.AdamOptimizer(LR).minimize(loss)
tf.metrics.accuracy(labels=tf.argmax(tf_y,axis=1),predictions=tf.argmax(output,axis=1),)[1]/<code>

第4步 開始完整的訓練過程

<code>from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import time
import random
data_dir = "/home/quan/Desktop/mnist1"
mnist = input_data.read_data_sets(data_dir,one_hot=True)
print(time.asctime())
tf.set_random_seed(1)
np.random.seed(1)
#定義超參數
BATCH_SIZE = 64
TIME_STEP = 28
INPUT_SIZE =28
LR =0.01 #定義學習效率
a =0

#讀入數據
test_x = mnist.test.images[:2000]
test_y = mnist.test.labels[:2000]
print(mnist.train.images.shape)
print(mnist.train.labels.shape)
b = mnist.train.images.shape[0]
# print(b)
i = random.randint(0,b)
print("The picture is %i."%i)
plt.imshow(mnist.train.images[i].reshape((28,28)),cmap="gray")
plt.title("$The number is %i, num=%i$"%(i,np.argmax(mnist.train.labels[i])))
plt.show()
#定義表示x的向量的tensorflow placeholder
tf_x = tf.placeholder(tf.float32,[None,TIME_STEP*INPUT_SIZE])
image = tf.reshape(tf_x,[-1,TIME_STEP,INPUT_SIZE])
tf_y = tf.placeholder(tf.int32,[None,10])
rnn_cell = tf.contrib.rnn.BasicLSTMCell(num_units=64)
outputs,(h_c,h_n) = tf.nn.dynamic_rnn(
rnn_cell,
image,
initial_state = None,
dtype = tf.float32,
time_major =False
)
output = tf.layers.dense(outputs[:,-1,:],10)
loss = tf.losses.softmax_cross_entropy(onehot_labels=tf_y,logits=output)
train_op = tf.train.AdamOptimizer(LR).minimize(loss)
accuracy = tf.metrics.accuracy(labels=tf.argmax(tf_y,axis=1),predictions=tf.argmax(output,axis=1),)[1]
sess = tf.Session()
init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
sess.run(init_op)
num = 0
for step in range(12000):
b_x,b_y = mnist.train.next_batch(BATCH_SIZE)
_,loss_ = sess.run([train_op,loss],{tf_x:b_x,tf_y:b_y})
# print("All steps is %step..."%step)
if step % 50 == 0:
num +=1
print("The steps is: %d"%num)
accuracy_ = sess.run(accuracy,{tf_x:test_x,tf_y:test_y})
print("train loss:%.6f"%loss_,"|test accuracy:%.6f" %accuracy_)
test_output = sess.run(output,{tf_x:test_x[:100]})
pred_y = np.argmax(test_output,1)
print(pred_y,"prediction number.")
print(np.argmax(test_y[:100],1),"real number.")/<code>

說明:其中在訓練開始時添加了時間戳,用到time時間模塊,並且訓練開始時從訓練集中隨機選取一張手寫數據圖片將其位置和代表的數字顯示出來,便於知道數據讀取是不是正常。


深度學習-LSTM算法實現(MNIST手寫數字識別)

程序運行時隨機顯示出來的圖片

代碼中將學習率設為0.01,訓練數據設置為12000個,沒有全部用主要耗費時間,將每50個為一組迭代完成後顯示迭代次數和當前的訓練的精度,精度保留了6位小數,最終訓練完成後精度達到0.965615,如下圖訓練中:迭代了350張圖片可以達到的精度0.744857。

深度學習-LSTM算法實現(MNIST手寫數字識別)

LSTM訓練過程中

選取了測試集數據中的100個進行了測試,非常的準確。訓練結果如下圖所示:

深度學習-LSTM算法實現(MNIST手寫數字識別)

最終訓練結果

有沒有get到,快去動手練習吧。


分享到:


相關文章: