京东金融大数据竞赛猪脸识别(5)- 识别方法之一

余生长醉 提交于 2019-11-26 02:08:28

自编码器是早期的神经网络方法之一了,为便于了解各方法识别性能,我们首先用它进行识别。代码如下:

clear
load('JDPig_mlhmslbp_spyr.mat');
m = numel(classe_name);
n = length(y);
label = []
%one-hot编码,将每幅图像的数字类别标签变为只有1个值为0的向量,向量维数与类别个数相同
for i=1:n
    label(:,i) = zeros(m,1);
    label(y(i),i) = 1;
end
testImg  = load('JDTest_mlhmslbp_spyr.mat');
hiddenSize = 10;
%训练自编码器
if ~exist('autoenc.mat')      
    autoenc1 = trainAutoencoder(X,hiddenSize,...
        'L2WeightRegularization',0.001,...
        'SparsityRegularization',4,...
        'SparsityProportion',0.05,...
        'DecoderTransferFunction','purelin');
    features1 = encode(autoenc1,X);
    fprintf('saving features1\n');
    %存储自编码器结果
    save('autoenc.mat','features1','autoenc1');
else
    load('autoenc.mat','features1','autoenc1');
    fprintf('loading features1\n');
end
%训练第二层自编码器
if ~exist('deepnetAutoenc.mat')   
    hiddenSize = 10;
    autoenc2 = trainAutoencoder(features1,hiddenSize,...
        'L2WeightRegularization',0.001,...
        'SparsityRegularization',4,...
        'SparsityProportion',0.05,...
        'DecoderTransferFunction','purelin',...
        'ScaleData',false);
    features2 = encode(autoenc2,features1);
    %训练softmax分类层
    softnet = trainSoftmaxLayer(features2,label,'LossFunction','crossentropy');
    %构造深度网络
    deepnet = stack(autoenc1,autoenc2,softnet);
    deepnet = train(deepnet,X,label);
    %存储第二层自编码器和训练出的深度网络
    save('deepnetAutoenc.mat','autoenc2','features2','softnet','deepnet');
    fprintf('Saving features2\n');
else
    load('deepnetAutoenc.mat','autoenc2','features2','softnet','deepnet'); 
    fprintf('Loading fatures2\n');
end
%计算每幅图像对于各个类的得分
scores = deepnet(testImg.X);
fprintf('Testing images!\n');
load('testName.mat','imgName');
for i=1:length(scores)
    for j=1:m
        indImg((i-1)*m+j) = imgName(i);
        plabel((i-1)*m+j)  = j;
        prob((i-1)*m+j)   = scores(j,i);
    end
end
%创建各图像属于每个类的概率表
T = table(indImg',plabel',prob');
%将概率表存为csv文件,用以网站上传并计算最终得分
writetable(T,'resAutoenc.csv');
fprintf('Image recognition finished!\n');

最后写入磁盘的csv文件在上传至京东金融网站后,网站会计算出得分。该网络结构简单,准确率一般,印象中得分在10左右。不过该方法在图像集上测试的结果,准确率也在95%以上。

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!