手把手教圖像分類入門,輕鬆實現90%準確率

原作 Margaret Maynard-Reid

王小新 編譯自 TensorFlow的Medium

這篇教程會介紹如何用TensorFlow裡的tf.keras函數,對Fashion-MNIST數據集進行圖像分類。

只需幾行代碼,就可以定義和訓練模型,甚至不需要太多優化,在該數據集上的分類準確率能輕鬆超過90%。

手把手教圖像分類入門,輕鬆實現90%準確率

在進入正題之前,我們先介紹一下上面提到的兩個名詞:

Fashion-MNIST,是去年8月底德國研究機構Zalando Research發佈的一個數據集,其中訓練集包含60000個樣本,測試集包含10000個樣本,分為10類。樣本都來自日常穿著的衣褲鞋包,每一個都是28×28的灰度圖。

這個數據集致力於成為手寫數字數據集MNIST的替代品,可用作機器學習算法的基準測試,也同樣適合新手入門。

連LeCun都推薦的Fashion-MNIST數據集,是這位華人博士的成果

或者去GitHub:

https://github.com/zalandoresearch/fashion-mnist

tf.keras是用來在TensorFlow中導入Keras的函數。Keras是個容易上手且深受歡迎的深度學習高級庫,是一個獨立開源項目。在TensorFlow中,可以使用tf.keras函數來編寫Keras程序,這樣就能充分利用動態圖機制eager execution和tf.data函數。

下面可能還會遇到其他深度學習名詞,我們就不提前介紹啦。進入正題,教你用tf.keras完成Fashion-MNIST數據集的圖像分類~

運行環境

無需設置,只要使用Colab直接打開這個Jupyter Notebook鏈接,就能找到所有代碼。

https://colab.research.google.com/github/margaretmz/deep-learning/blob/master/fashion_mnist_keras.ipynb

數據處理

Fashion-MNIST數據集中有十類樣本,標籤分別是:

  • T恤 0
  • 褲子 1
  • 套頭衫 2
  • 裙子 3
  • 外套 4
  • 涼鞋 5
  • 襯衫 6
  • 運動鞋 7
  • 包 8
  • 踝靴 9

數據集導入

下面是數據集導入,為後面的訓練、驗證和測試做準備。

只需一行代碼,就能用keras.datasets接口來加載fashion_mnist數據,再用另一行代碼來載入訓練集和測試集。

數據可視化

我喜歡用Jupyter Notebook來可視化,你也可以用matplotlib庫中imshow函數來可視化訓練集中的圖像。要注意,每個圖片都是大小為28x28的灰度圖。

手把手教圖像分類入門,輕鬆實現90%準確率

數據歸一化

接著,進行數據歸一化,使得樣本值都處於0到1之間。

數據集劃分

這個數據集一共包含60000個訓練樣本和10000個測試樣本,我們會把訓練樣本進一步劃分為訓練集和驗證集。下面是深度學習中三種數據的作用:

  • 訓練數據,用來訓練模型;
  • 驗證數據,用來調整超參數和評估模型;
  • 測試數據,用來衡量最優模型的性能。

模型構建

下面是定義和訓練模型。

模型結構

在Keras中,有兩種模型定義方法,分別是序貫模型和功能函數。

在本教程中,我們使用序貫模型構建一個簡單CNN模型,用了兩個卷積層、兩個池化層和一個Dropout層。

要注意,第一層要定義輸入數據維度。最後一層為分類層,使用Softmax函數來分類這10種數據。

模型編譯

在訓練模型前,我們用model.compile函數來配置學習過程。在這裡,要選擇損失函數、優化器和訓練測試時的評估指標。

模型訓練

訓練模型時,Batch Size設為64,Epoch設為10。

測試性能

訓練得到的模型在測試集上的準確率超過了90%。

預測可視化

我們通過datasetmodel.predict(x_test)函數,用訓練好的模型對測試集進行預測並可視化預測結果。當標籤為紅色,則說明預測錯誤;當標籤為綠色,則說明預測正確。下圖為15個測試樣本的預測結果。

手把手教圖像分類入門,輕鬆實現90%準確率

相關鏈接

最後,在這篇普通的入門教程基礎上,還有一些提升之路:

如果想深入瞭解本文使用的Google Colab,可以看這份官方介紹:

https://medium.com/tensorflow/colab-an-easy-way-to-learn-and-use-tensorflow-d74d1686e309

如果你是深度學習初學者,MNIST也應該瞭解一下。之前TensorFlow有一篇MNIST教程,可以拿來和本文比較一下,你就會發現,深度學習現在已經變得簡單了很多:

https://www.tensorflow.org/versions/r1.1/get_started/mnist/beginners

本文用到的是Keras裡的序貫模型,如果對功能函數感興趣,可查看這篇用Keras功能函數和TensorFlow來預測葡萄酒價格的博文:

https://medium.com/tensorflow/predicting-the-price-of-wine-with-the-keras-functional-api-and-tensorflow-a95d1c2c1b03

誠摯招聘

վ'ᴗ' ի 追蹤AI技術和產品新動態


分享到:


相關文章: