神经网络实现手写数字分类matlab

落花浮王杯 提交于 2020-02-07 17:33:08

1 实验结果

有点糊,将就看一下,一个手写数字的自动识别,识别的准确率大概为94%在这里插入图片描述

2、数据集Minist

下载地址:http://yann.lecun.com/exdb/mnist/在这里插入图片描述
四个文件分别为训练集数据、训练集标签、测试集数据、测试集标签。官方介绍,训练集数据有60000张,测试集数据有10000张。(说明:下载后电脑会自动解压成.ubyte.gz格式),这四个文件不是标准的图片格式,因此我们需要建一个.m文件实现对数据的读取。每张图片都是2828,因此每次读取2828大小为一张图片。

2.1 读取数据转化为向量

将图片向量化为7841,将训练集所有图片向量存在x_train中,大小为78460000,标签存放在y_train中,大小为1*60000(测试集同理,分别为x_test,y_test)
为了后面找到最合适的网络参数,因此直接将读取的训练集、测试集数据存在文件中,后面直接载入文件调用即可。读取训练集、测试集为同一函数,为了将读取到的数据区别开分别存放在train和test文件中,在函数中定义一个描述字符,表示该文件是训练文件还是测试文件。

function  build_dataset(image_file,label_file,describe)
%读取训练集图片文件
images = fopen(image_file,'r'); 
%读取文件说明信息,前十六字节的数据是说明信息
a = fread(images,16,'uint8'); 
% 说明信息是用32位整型表述
%MagicNum = ((a(1)*256+a(2))*256+a(3))*256+a(4);
%前四字节是idx文件格式说明
ImageNum = ((a(5)*256+a(6))*256+a(7))*256+a(8);%图片数量
rows = ((a(9)*256+a(10))*256+a(11))*256+a(12);%行数:28
cols = ((a(13)*256+a(14))*256+a(15))*256+a(16);%列数:28
%读取训练集标签文件
labels =  fopen(label_file,'r'); 
%读取文件说明信息,前八个字节
a1 = fread(labels,8,'uint8'); 
%MagicNum1 = ((a1(1)*256+a1(2))*256+a1(3))*256+a1(4);%idx文件格式说明
ImageNum1 = ((a1(5)*256+a1(6))*256+a1(7))*256+a1(8);%图像数量(四字节,用32位整型表述)
%读取数据信息
Label = zeros(1,ImageNum1);
data = [];
for i=1:ImageNum    
    im = im2double(fread(images,rows*cols,'uint8'));  
    %每次读取一个图片大小的数据,28*28,为向量    
    label= fread(labels,1,'uint8');%对应的标签,该值表示图片中的数字是0~9的哪个数字
    Label(i) = label;
    im_arr = im';%将图片转化为按行存储
    im_vec = reshape(im_arr,rows*cols,1); %转化为矩阵    
    data = [data,im_vec];
end
%根据描述符,存储数据
if describe == "train"
    x_train = data;
    y_train = Label;
    save("e:/minist/train_data","x_train","y_train");
end
if describe == "test"
    x_test = data;
    y_test = Label;
    save("e:/minist/test_data","x_test","y_test");
end
end

2.2 读取部分数据转化为图片

直接载入测试集数据,x_test中每一列表示一张图片,将其reshape为28*28,转化为uint8格式,因为图片中是按列存储,所以将其转置后写入图片。

function build()
load("e:/minist/test_data");
[m,n] = size(x_test);
for k = 1:n
    x = x_test(:,k);
    x = reshape(x,28,28);
    x = uint8(x);
    imwrite(x',"e:/minist/image/"+k+".bmp");
end
end

2、人工神经网络

神经网络选择最简单的三层网络:输入层、隐藏层、输出层。神经网络中涉及两个部分:正向传播和反向传播。

主函数

%获取训练集、测试集
load("e:/minist/train_data");
load("e:/minist/test_data");
%load("e:/minist/data.mat");
%数据归一化到0-1之间
x_train = mapminmax(x_train,0,1);
x_test = mapminmax(x_test,0,1);
a = 0.03;
step = 100;%迭代次数
hid = 27;%隐藏层个数
[w,b,w_h,b_h] = train(x_train,y_train,step,a,hid);
acc = test(x_test,y_test,w,b,w_h,b_h);
fprintf("隐藏层为为:"+hid+" 测试集准确率为: "+acc+"\n");

网络训练

%训练神经网络
function [w,b,w_hid,b_hid] = train(x_train,y_train,step,a,hid)
%神经网络由三层组成,分别为输入层input、隐藏层hidden、输出层output
in = 784;%输入层个数
out = 10;%输出层神经元个数(0-9)
w = randn(hid,in);%输入层--隐藏层权重
b = randn(hid,1);%输入层--隐藏层偏置
w_hid = randn(out,hid);%隐藏层--输入层权重
b_hid = randn(out,1);%隐藏层--输出层偏置
%定义10个数字的正确输出
Y = [1,0,0,0,0,0,0,0,0,0;
     0,1,0,0,0,0,0,0,0,0;
     0,0,1,0,0,0,0,0,0,0;
     0,0,0,1,0,0,0,0,0,0;
     0,0,0,0,1,0,0,0,0,0;
     0,0,0,0,0,1,0,0,0,0;
     0,0,0,0,0,0,1,0,0,0;
     0,0,0,0,0,0,0,1,0,0;
     0,0,0,0,0,0,0,0,1,0;
     0,0,0,0,0,0,0,0,0,1;];
for i=1:step
    %打乱顺序
    r = randperm(60000);
     for j = 1:60000
        %取第j张图片
        x = x_train(:,j);
        y = y_train(1,j);
        %正向传播计算每一层输出
        out_hid = Layerout(w,b,x);%隐藏层输出
        out_y = Layerout(w_hid,b_hid,out_hid);%输出层输出
        %反向传播
        out_update = (Y(:,y+1)-out_y).*out_y.*(1-out_y);%输出层--隐藏层
        hid_update = ((w_hid')*out_update).*out_hid.*(1-out_hid);%隐藏层--输入层
        %更新参数
        w_hid = w_hid + a*(out_update*(out_hid'));
        b_hid = b_hid + a*out_update;
        w = w + a*(hid_update*(x'));
        b = b + a*hid_update;
    end
end
save("e:/minist/data.mat","w","b","w_hid","b_hid");
end

神经元激活函数输出

%每一层激活函数输出
function [y] = Layerout(w,b,x)
y = w*x+b;
n = length(y);
for i =1:n
    y(i) = 1.0/(1+exp(-y(i)));
end
end

预测

function [index] = predict(path)
%UNTITLED3 此处显示有关此函数的摘要
%存放测试图片属于每个数字的概率
load("e:/minist/data.mat");
im = im2double(imread(path));%得到28*28j矩阵ss
im = im';
x = reshape(im,784,1);
hid = Layerout(w,b,x);
res = Layerout(w_hid,b_hid,hid);
[t,index] = max(res);
index = index-1;
end

3、GUI设计

打开界面
在这里插入图片描述
按下start键,开始进行手写数字识别。定义一个可编辑文本框img,用其显示每张图片的路径,同时定义一个文本框img_index表示图片是第几张(方便进行读取图片),并设置其为不可见。三个按钮,start、stop、continue,start键进行初始化,设置img_index=1,并在坐标轴中显示第一张图片。定义一个timer定时器,每过一秒令img_index加1,即每过一秒读取下一张图片,进行识别,将识别结果显示在result文本框中。

在这里插入代码片
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!