机器学习之向前传播神经网络:手写数字识别
简述神经网络训练数据部分手写数字展示逻辑函数神经网络多分类结果预测简述
上一篇用逻辑回归实现多分类,该篇将用神经网络进行手写数字的识别。该练习将以octave 作为工具进行实验,逻辑回归数学公式及讲解文档在这里,可点击访问。本篇将使用已经训练好的Theta1、Theta2,实现向前传播算法。下一篇将会使用反向传播算法学习出神经网络参数。神经网络
以下为神经网络简单示意图。包含输入层、隐藏层、输出层。下面训练数据中给到的数据中将会是:输入层400,隐藏层25,输出层10;
训练数据部分手写数字展示
手写数字的数据:
逻辑函数
sigmoid.m 函数将所有实数映射到(0, 1)范围。
%% Sigmoid function
function g = sigmoid(z)
g = zeros(size(z));
g = 1.0 ./ (1.0 + exp(-z));
endfunction
神经网络多分类结果预测
predict.m 用于对输入数据进行预测,预测值最大的即为正确的分类。
%% Neural network prediction function
function p = predict(Theta1, Theta2, X)
m = size(X, 1);
k = size(Theta2, 1);
p = zeros(m, 1);
X = [ones(m, 1) X];
a2 = sigmoid(X * Theta1');
a2 = [ones(m, 1) a2]
a3 = sigmoid(a2 * Theta2');
[a, p] = max(a3, [], 2);
endfunction
ex3weights.mat数据可以在这里进行下载,加载数据后将会得到已经训练好的Theta1、Theta2,以及训练数据(X, y)。
结果预测,准确率达到97.52%
%% =========== 2.Loading Pameters ============
load('ex3weights.mat');
pred = predict(Theta1, Theta2, X);
fprintf('Train data Accuracy: %f\\n', mean(double(pred == y)) * 100);