__author__ = '糖衣豆豆'
from numpy import *
from os import listdir
import operator
#从列方向扩展
#tile(a,(size,1))
#实现KNN算法,需要指定k,需要测试数据集,需要训练数据集,类别名(标签),
def knn(k,testdata,traindata,labels):
#通过shape获得行数
traindatasize=traindata.shape[0]
#扩展testdata的维数,tile函数可以扩展testdata和traindata相同的行数,然后和traindata的向量相减计算测试机和训练集的差值
dif=tile(testdata,(traindatasize,1))-traindata
#计算差值的平方
sqdif=dif**2
#计算平方和,每一行的各列求和,axis=1每一行的各列求和
sumsqdif=sqdif.sum(axis=1)
#开方
distance=sumsqdif**0.5
#排序
sortdistance=distance.argsort()
#空字典
count={}
#选择距离最短的k
for i in range(0,k):
#获取类别,下标决定属于哪一类
vote=labels[sortdistance[i]]
#整理为一定格式,得到类别vote,每出现一次统计一次
count[vote]=count.get(vote,0)+1
#取出最多的类别,reverse=True表示降序
sortcount=sorted(count.items(),key=operator.itemgetter(1),reverse=True)
return sortcount[0][0]
#图片处理
#先将图片转为固定宽高,比如32*32,然后再转为文本
'''
from PIL import Image
im=Image.open("~/Downloads/123.png")
fh=open("~/Downloads/123_txt","a")
width=im.size[0]
height=im.size[1]
#k=im.getpixel((1,9))
#print(k)
for i in range(0,width):
for j in range(0,height):
cl=im.getpixel((i,j))
clall=cl[0]+cl[1]+cl[2]
if(clall==0):
#黑色
fh.write("1")
else:
fh.write("0")
fh.write("\n")
fh.close()
'''
#加载数据
#将数据转为数组
def datatoarray(fname):
arr=[]
fh=open(fname)
#图片是32*32的横轴每次读取32
for i in range(0,32):
thisline=fh.readline()
#读每一行
for j in range(0,32):
#读入到数组里
arr.append(int(thisline[j]))
return arr
arr1=datatoarray("~/coding/python/data/testandtraindata/testdata/0_74.txt")
#print(arr1)
#建立一个函数,取文件的前缀
def seplabel(fname):
filestr=fname.split(".")[0]
label=int(filestr.split("_")[0])
return label
#建立训练数据
def traindata():
#存储类别
labels=[]
#得到训练目录下所有的文件
trainfile=listdir("~/coding/python/data/testandtraindata/traindata")
#取当前文件有多少个
num=len(trainfile)
#生成一个多少行多少列的向量,行的长度应该是32*32=1024(列),每一行存储一个文件
#用一个数组存储所有训练数据,行:文件总数,列:1024
trainarr=zeros((num,1024))
#第一层循环文件
for i in range(0,num):
thisfname=trainfile[i]
#调用seplabel函数
thislabel=seplabel(thisfname)
#存到数组里
labels.append(thislabel)
#调用datatoarray函数,i,:处理重复读取
trainarr[i,:]=datatoarray("~/coding/python/data/testandtraindata/traindata/"+thisfname)
return trainarr,labels
#用测试数据条用KNN算法去测试,看是否能够准确识别
def datatest():
trainarr,labels=traindata()
testlist=listdir("~/coding/python/data/testandtraindata/testdata")
tnum=len(testlist)
for i in range(0,tnum):
thistestfile=testlist[i]
testarr=datatoarray("~/coding/python/data/testandtraindata/testdata/"+thistestfile)
rknn=knn(3,testarr,trainarr,labels)
print(rknn)
#a=datatest()
#print(a)
#抽某一个文件测试文件出来进行验证
trainarr,labels=traindata()
thistestfile="8_15.txt"
testarr=datatoarray("~/coding/python/data/testandtraindata/testdata/"+thistestfile)
rknn=knn(3,testarr,trainarr,labels)
print(rknn)
来源:oschina
链接:https://my.oschina.net/u/4355030/blog/3875360