原作 Margaret Maynard-Reid
王小新 編譯自 TensorFlow的Medium
這篇教程會介紹如何用TensorFlow裡的tf.keras函數,對Fashion-MNIST數據集進行圖像分類。
只需幾行代碼,就可以定義和訓練模型,甚至不需要太多優化,在該數據集上的分類準確率能輕鬆超過90%。
在進入正題之前,我們先介紹一下上面提到的兩個名詞:
Fashion-MNIST,是去年8月底德國研究機構Zalando Research發佈的一個數據集,其中訓練集包含60000個樣本,測試集包含10000個樣本,分為10類。樣本都來自日常穿著的衣褲鞋包,每一個都是28×28的灰度圖。
這個數據集致力於成為手寫數字數據集MNIST的替代品,可用作機器學習算法的基準測試,也同樣適合新手入門。
連LeCun都推薦的Fashion-MNIST數據集,是這位華人博士的成果
或者去GitHub:
https://github.com/zalandoresearch/fashion-mnist
tf.keras是用來在TensorFlow中導入Keras的函數。Keras是個容易上手且深受歡迎的深度學習高級庫,是一個獨立開源項目。在TensorFlow中,可以使用tf.keras函數來編寫Keras程序,這樣就能充分利用動態圖機制eager execution和tf.data函數。
下面可能還會遇到其他深度學習名詞,我們就不提前介紹啦。進入正題,教你用tf.keras完成Fashion-MNIST數據集的圖像分類~
運行環境
無需設置,只要使用Colab直接打開這個Jupyter Notebook鏈接,就能找到所有代碼。
https://colab.research.google.com/github/margaretmz/deep-learning/blob/master/fashion_mnist_keras.ipynb
數據處理
Fashion-MNIST數據集中有十類樣本,標籤分別是:
- T恤 0
- 褲子 1
- 套頭衫 2
- 裙子 3
- 外套 4
- 涼鞋 5
- 襯衫 6
- 運動鞋 7
- 包 8
- 踝靴 9
數據集導入
下面是數據集導入,為後面的訓練、驗證和測試做準備。
只需一行代碼,就能用keras.datasets接口來加載fashion_mnist數據,再用另一行代碼來載入訓練集和測試集。
數據可視化
我喜歡用Jupyter Notebook來可視化,你也可以用matplotlib庫中imshow函數來可視化訓練集中的圖像。要注意,每個圖片都是大小為28x28的灰度圖。
數據歸一化
接著,進行數據歸一化,使得樣本值都處於0到1之間。
數據集劃分
這個數據集一共包含60000個訓練樣本和10000個測試樣本,我們會把訓練樣本進一步劃分為訓練集和驗證集。下面是深度學習中三種數據的作用:
- 訓練數據,用來訓練模型;
- 驗證數據,用來調整超參數和評估模型;
- 測試數據,用來衡量最優模型的性能。
模型構建
下面是定義和訓練模型。
模型結構
在Keras中,有兩種模型定義方法,分別是序貫模型和功能函數。
在本教程中,我們使用序貫模型構建一個簡單CNN模型,用了兩個卷積層、兩個池化層和一個Dropout層。
要注意,第一層要定義輸入數據維度。最後一層為分類層,使用Softmax函數來分類這10種數據。
模型編譯
在訓練模型前,我們用model.compile函數來配置學習過程。在這裡,要選擇損失函數、優化器和訓練測試時的評估指標。
模型訓練
訓練模型時,Batch Size設為64,Epoch設為10。
測試性能
訓練得到的模型在測試集上的準確率超過了90%。
預測可視化
我們通過datasetmodel.predict(x_test)函數,用訓練好的模型對測試集進行預測並可視化預測結果。當標籤為紅色,則說明預測錯誤;當標籤為綠色,則說明預測正確。下圖為15個測試樣本的預測結果。
相關鏈接
最後,在這篇普通的入門教程基礎上,還有一些提升之路:
如果想深入瞭解本文使用的Google Colab,可以看這份官方介紹:
https://medium.com/tensorflow/colab-an-easy-way-to-learn-and-use-tensorflow-d74d1686e309
如果你是深度學習初學者,MNIST也應該瞭解一下。之前TensorFlow有一篇MNIST教程,可以拿來和本文比較一下,你就會發現,深度學習現在已經變得簡單了很多:
https://www.tensorflow.org/versions/r1.1/get_started/mnist/beginners
本文用到的是Keras裡的序貫模型,如果對功能函數感興趣,可查看這篇用Keras功能函數和TensorFlow來預測葡萄酒價格的博文:
https://medium.com/tensorflow/predicting-the-price-of-wine-with-the-keras-functional-api-and-tensorflow-a95d1c2c1b03
— 完 —
誠摯招聘
վ'ᴗ' ի 追蹤AI技術和產品新動態
閱讀更多 量子位 的文章