卷積神經網絡概述及python實現

對於卷積神經網絡(CNN)而言,相信很多讀者並不陌生,該網絡近年來在大多數領域都表現優異,尤其是在計算機視覺領域中。但是很多工作人員可能直接調用相關的深度學習工具箱搭建卷積神經網絡模型,並不清楚其中具體的原理。本文將簡單介紹卷積神經網絡(CNN),方便讀者大體上了解其基本原理及實現過程,便於後續工作中的實際應用。本文將按以下順序展開:

  • 瞭解卷積操作
  • 瞭解神經網絡
  • 數據預處理
  • 瞭解CNN
  • 瞭解優化器
  • 理解 ImageDataGenerator
  • 進行預測並計算準確性
  • demo

什麼是卷積?

在數學(尤其是函數分析)中,卷積是對兩個函數(f和g)的數學運算,以產生第三個函數,該函數表示一個函數的形狀如何被另一個修改。

此操作在多個領域都有應用,如概率、統計、計算機視覺、自然語言處理、圖像和信號處理、工程和微分方程。該操作在數學上表示為:

卷積神經網絡概述及python實現

卷積操作

什麼是人工神經網絡?

人工神經網絡(ANN)或連接系統是由構成動物大腦的生物神經網絡模糊地啟發的計算系統。這些系統通過從示例中“學習”以執行任務,通常不需要使用用任何特定規則來編程。(來源:維基百科)

人工神經網絡是一個較小的處理單元集合,稱為人工神經元,它們與生物神經元相似。

生物神經迴路

卷積神經網絡概述及python實現

生物神經迴路

神經元之間的互聯構成了一個網絡模型

人工神經網絡

卷積神經網絡概述及python實現

人工神經網絡

現在,我們開始具體實現。

導入必要的數據包

加載數據集

數據集

此處使用的數據集手寫數據集
trainIamges.csv有1024列和13440行。每列表示圖像中的像素,每行表示一張單獨的灰度圖像。每個像素的取值範圍是0到255之間的值。

可視化數據集

訓練數據集

卷積神經網絡概述及python實現

訓練數據集

測試數據集

卷積神經網絡概述及python實現

測試數據集

數據預處理

編碼分類變量

什麼是分類變量?

在統計學中,分類變量是一個可以承擔限制變量之一的變量,基於某些定性屬性將每個個體或其他觀察單元分配給特定組或名義類別。

簡單來說,分類變量的值表示類別或類。

‘’

為什麼需要編碼分類變量?

直接對錶示類別的數字執行操作沒有意義。因此,需要對其進行分類編碼。
阿拉伯字母表中有28個字母。因此,數據集有28個類別。

標準化

什麼是標準化?

進行歸一化以使整個數據進入明確定義的範圍,一般選擇歸一化到0到1之間

在神經網絡中,不僅要對數據進行標準化,還要對其進行標量化,這樣處理的目的是能夠更快地接近錯誤表面的全局最小值。

對其進行變形操作使得每條數據表示一個平面圖像

按功能劃分的零中心將每個樣本的中心置零,並指定平均值。如果未指定,則對所有樣品評估平均值。

建立CNN

最大池化(Max Pooling)是什麼?

池化意味著組合一組數據,組合數據的過程中應該遵循一些規則。

根據定義,最大池化選取一組數據中的最大值作為其輸出值。

最大池還可以用於減小特徵維度,它還可以避免過擬合的發生。以便更好地瞭解Max Pooling。

什麼是Dropout?

Dropout是一種正則化技術,通過防止對訓練數據進行復雜的協同適應來減少神經網絡中的過擬合,這是神經網絡模型中十分有效的方法之一。“ 丟失”指的是在神經網絡中以某一個概率隨機地丟棄部分神經單元。

什麼是Flatten?

對特徵圖進行展平,以將多維數據轉換為一維特徵向量,以供下一層(密集層)使用

什麼是密集層?

密集層只是一層人工神經網絡,也被稱作全連接層。

CNN的優化方法

什麼是優化?

優化算法幫助我們最小化(或最大化)目標函數,目標函數只是一個數學函數,取決於模型內部可學習的參數。模型中使用預測變量集(X)計算目標值(Y)。例如,我們將神經網絡的權重(W)和偏差(b)值稱為其內部可學習參數,用於計算輸出值,並在最優解的方向上學習和更新這些參數,即最小化損失網絡。這就是神經網絡的訓練過程。

本文在這裡使用的優化器是RMSprop,點擊此處以瞭解有關RMSprop的更多信息。

什麼是ImageDataGenerator?

當你的數據集規模比較小時,你可能會應用到圖像數據生成器,它用於生成具有實時增強的批量張量圖像數據,擴大數據集規模。一般而言,當數據量增多時,模型性能會得更好。以下代碼用於批量加載圖像:

CNN擬合訓練數據

做出預測

生成混淆矩陣

什麼是混淆矩陣?

混淆矩陣是用於總結分類算法性能的一種技術。如果每個類別中的觀察數量不等,或者數據集中有兩個以上的類,單獨的分類準確性可能會產生誤導。計算混淆矩陣可以讓我們更好地瞭解分類模型的正確性以及它所犯的錯誤類型。

計算準確性

本文獲得了97%的準確度,感興趣的讀者可以自己嘗試下。

CNN手寫數字識別demo

可以實時查看CNN的工作情況,該demo顯示了CNN的工作過程,以及每層輸出的特徵圖。最後該CNN網絡經過訓練後能夠識別手寫數字。

卷積神經網絡概述及python實現

以上為譯文,由阿里云云棲社區組織翻譯。


分享到:


相關文章: