機器學習之向前傳播神經網絡:手寫數字識別

機器學習之向前傳播神經網絡:手寫數字識別

  • 簡述
  • 神經網絡
  • 訓練數據部分手寫數字展示
  • 邏輯函數
  • 神經網絡多分類結果預測

簡述

  • 上一篇用邏輯迴歸實現多分類,該篇將用神經網絡進行手寫數字的識別。
  • 該練習將以octave 作為工具進行實驗,邏輯迴歸數學公式及講解文檔在這裡,可點擊訪問。
  • 本篇將使用已經訓練好的Theta1Theta2,實現向前傳播算法。下一篇將會使用反向傳播算法學習出神經網絡參數。
機器學習之向前傳播神經網絡:手寫數字識別

神經網絡

以下為神經網絡簡單示意圖。包含輸入層、隱藏層、輸出層。下面訓練數據中給到的數據中將會是:輸入層400,隱藏層25,輸出層10

機器學習之向前傳播神經網絡:手寫數字識別

訓練數據部分手寫數字展示

手寫數字的數據:ex3data1.mat,(https://github.com/peedeep/Coursera/blob/master/ex3/ex3data1.mat)複製鏈接下載,是一個(m*n) m=5000, n=400矩陣數據,m表示訓練數據樣本數,n表示每個數據的特徵維度。且提供了displayData.m 函數來對訓練數據中隨機10*10=100張手寫數字顯示:

機器學習之向前傳播神經網絡:手寫數字識別

邏輯函數

sigmoid.m 函數將所有實數映射到(0, 1)範圍。

機器學習之向前傳播神經網絡:手寫數字識別

%% Sigmoid function
function g = sigmoid(z)
g = zeros(size(z));
g = 1.0 ./ (1.0 + exp(-z));
endfunction
機器學習之向前傳播神經網絡:手寫數字識別

神經網絡多分類結果預測

predict.m 用於對輸入數據進行預測,預測值最大的即為正確的分類。

%% Neural network prediction function
function p = predict(Theta1, Theta2, X)
m = size(X, 1);
k = size(Theta2, 1);
p = zeros(m, 1);
X = [ones(m, 1) X];
a2 = sigmoid(X * Theta1');
a2 = [ones(m, 1) a2]
a3 = sigmoid(a2 * Theta2');

[a, p] = max(a3, [], 2);
endfunction

ex3weights.mat數據可以在這裡進行下載,加載數據後將會得到已經訓練好的Theta1Theta2,以及訓練數據(X, y)

結果預測,準確率達到97.52%

%% =========== 2.Loading Pameters ============
load('ex3weights.mat');
pred = predict(Theta1, Theta2, X);
fprintf('Train data Accuracy: %f\\n', mean(double(pred == y)) * 100);


分享到:


相關文章: