python 手写体数字识别

给你一囗甜甜゛ 提交于 2020-01-16 10:16:32
from os import listdir
from numpy import *
import time
import operator
def classify(inputPoint,dataSet,labels,k):
    dataSetSize = dataSet.shape[0]     #已知分类的数据集(训练集)的行数
    #先tile函数将输入点拓展成与训练集相同维数的矩阵,再计算欧氏距离
    diffMat = tile(inputPoint,(dataSetSize,1))-dataSet  #样本与训练集的差值矩阵
    sqDiffMat = diffMat ** 2                    #差值矩阵平方
    sqDistances = sqDiffMat.sum(axis=1)         #计算每一行上元素的和
    distances = sqDistances ** 0.5              #开方得到欧拉距离矩阵
    sortedDistIndicies = distances.argsort()    #按distances中元素进行升序排序后得到的对应下标的列表
    #选择距离最小的k个点
    classCount = {}
    for i in range(k):
        voteIlabel = labels[ sortedDistIndicies ]
        classCount[voteIlabel] = classCount.get(voteIlabel,0)+1
    #按classCount字典的第2个元素(即类别出现的次数)从大到小排序
    sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True)
    return sortedClassCount[0][0]
#文本向量化 32x32 -> 1x1024
def img2vecor(filename):
    returnVect=[]
    fr=open(filename)
    for i in range(32):
        lineStr=fr.readline()
        for j in range(32):
            returnVect.append(int(lineStr[j]))
    return returnVect
#从文件名中解析分类数字
def classnumCut(filename):
    fileStr=filename.split('.')[0]
    classNumStr=int(fileStr.split('_')[0])
    return classNumStr
#构建训练集数据向量,及对应分类标签向量
def trainingDataSet(file):
    hwLabels=[]
    trainingFileList=listdir(file) #获取目录内容
    m=len(trainingFileList)
    trainingMat=zeros((m,1024))  #获取m维向量的训练集
    for i in range(m):
        fileNameStr=trainingFileList
        hwLabels.append(classnumCut(fileNameStr))
        trainingMat[i,:]=img2vecor(file+'/%s' % fileNameStr)
    return hwLabels,trainingMat

def handwritingTest(file):
    hwLabels,trainingMat = trainingDataSet('/root/python_test/data/data/trainingDigits')    #构建训练集
    testFileList = listdir(file)        #获取测试集
    errorCount = 0.0                            #错误数
    mTest = len(testFileList)                   #测试集总样本数
    t1 = time.time()
    for i in range(mTest):
        fileNameStr = testFileList
        classNumStr = classnumCut(fileNameStr)
        vectorUnderTest=img2vecor(file+'/%s' % fileNameStr)
        classifierResult =classify(vectorUnderTest,trainingMat,hwLabels, 3)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
        if (classifierResult != classNumStr):
            errorCount += 1.0
        print ("\nthe total number of tests is: %d" % mTest)               #输出测试总样本数
        print ("the total number of errors is: %d" % errorCount)           #输出测试错误样本数
        print("the total error rate is: %f" % (errorCount/float(mTest)))  #输出错误率
    t2 = time.time()
    print("Cost time: %.2fmin, %.4fs."%((t2-t1)//60,(t2-t1)%60) )     #测试耗时

if __name__ == "__main__":
    handwritingTest('/root/python_test/data/data/testDigits')
--------------------- 
作者:wx_411180165 
来源:CSDN 
原文:https://blog.csdn.net/qq_24726509/article/details/84923274 
版权声明:本文为博主原创文章,转载请附上博文链接!
标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!