機器學習之向前傳播神經網絡:手寫數字識別
- 簡述
- 神經網絡
- 訓練數據部分手寫數字展示
- 邏輯函數
- 神經網絡多分類結果預測
簡述
- 上一篇用邏輯迴歸實現多分類,該篇將用神經網絡進行手寫數字的識別。
- 該練習將以octave 作為工具進行實驗,邏輯迴歸數學公式及講解文檔在這裡,可點擊訪問。
- 本篇將使用已經訓練好的Theta1、Theta2,實現向前傳播算法。下一篇將會使用反向傳播算法學習出神經網絡參數。
神經網絡
以下為神經網絡簡單示意圖。包含輸入層、隱藏層、輸出層。下面訓練數據中給到的數據中將會是:輸入層400,隱藏層25,輸出層10;
訓練數據部分手寫數字展示
手寫數字的數據:ex3data1.mat,(https://github.com/peedeep/Coursera/blob/master/ex3/ex3data1.mat)複製鏈接下載,是一個(m*n) m=5000, n=400矩陣數據,m表示訓練數據樣本數,n表示每個數據的特徵維度。且提供了displayData.m 函數來對訓練數據中隨機10*10=100張手寫數字顯示:
邏輯函數
sigmoid.m 函數將所有實數映射到(0, 1)範圍。
%% Sigmoid function
function g = sigmoid(z)
g = zeros(size(z));
g = 1.0 ./ (1.0 + exp(-z));
endfunction
神經網絡多分類結果預測
predict.m 用於對輸入數據進行預測,預測值最大的即為正確的分類。
%% Neural network prediction function
function p = predict(Theta1, Theta2, X)
m = size(X, 1);
k = size(Theta2, 1);
p = zeros(m, 1);
X = [ones(m, 1) X];
a2 = sigmoid(X * Theta1');
a2 = [ones(m, 1) a2]
a3 = sigmoid(a2 * Theta2');
[a, p] = max(a3, [], 2);
endfunction
ex3weights.mat數據可以在這裡進行下載,加載數據後將會得到已經訓練好的Theta1、Theta2,以及訓練數據(X, y)。
結果預測,準確率達到97.52%
%% =========== 2.Loading Pameters ============
load('ex3weights.mat');
pred = predict(Theta1, Theta2, X);
fprintf('Train data Accuracy: %f\\n', mean(double(pred == y)) * 100);
閱讀更多 無名開發者 的文章