周志华机器学习3_5线性判例分析 编程实现(python&matlab)

牧云@^-^@ 提交于 2020-01-06 21:35:29

作为小白一枚,在参考了大神们的代码(文末附链接)后首先利用python分别用两种方法进行了编程。编程结果与大神们的不一致,不知道是不是中间过程有问题。在阅读过程中如果看出问题,望能指出,谢谢大噶!(以下均将数据集中第15行数据去除后运行)
方法1:(利用sklearn库中的LDA函数)

import pandas as pd## 
import numpy as np
import matplotlib.pyplot as plt
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
#数据读取
dataset=pd.read_csv('watermelon1.csv')
dataset=dataset.values
X=dataset[:,1:3]
y=dataset[:,3]

#利用西瓜数据绘制散点图
fig=plt.figure()
plt.title('wtatermelon30_a_sklearnlda')
plt.xlabel('density')
plt.xlabel('sugur_ratio')
plt.scatter(X[y==0,0],X[y==0,1],marker='o',c='k',label='bad')
plt.scatter(X[y==1,0],X[y==1,1],marker='o',c='g',label='good')

#lda拟合并计算投影直线斜率
lda_model=LDA(solver='svd',shrinkage=None).fit(X,y)
w=lda_model.coef_
k=w[0,1]/w[0,0]
b=lda_model.intercept_
print(k)

#计算投影点
def projectedpoint(X,Y,k,b):
    k1=-1/k
    b1=Y-k1*X
    x=(b*np.ones(len(b1))-b1)/(k1-k)
    y=k*x+b*np.ones(len(b1))
    #print(x,y)
    return x,y

#绘制投影直线
x2 = np.linspace(-0.9, 0.9, 100)
y2=k*x2+b
plt.plot(x2,y2,color='orange')

x0,y0=projectedpoint(X[y==0,0],X[y==0,1],k,b)
x1,y1=projectedpoint(X[y==1,0],X[y==1,1],k,b)

#绘制垂直虚线
X00=X[y==0,0]
X01=X[y==0,1]
X10=X[y==1,0]
X11=X[y==1,1]
for i in range(len(x0)):
    plt.plot([x0[i],X00[i]],[y0[i],X01[i]],'--',c='k')
for j in range(len(x1)):
    plt.plot([x1[j],X10[j]],[y1[j],X11[j]],'--',c='g')

plt.xlim([-0.2,0.9])
plt.ylim([-0.2,0.9])
plt.show()
plt.legend()

其中有一个问题没有理解的是为什么LDA函数返回截距值,代码原理中只要过零点就可以了。后来仔细想想,截距不为零只是对投影直线进行了左右平移,并不会影响分类结果。(不知道这样想对不对/doge)

方法2:利用周老师书上LDA二分类原理编写(多余绘图、导入数据等代码不再copy)

#计算卷积
def Cov(X,u):
    S=0
    print('u',u)
    for i in range(len(X)):
        p=X[i,:]-u
        t=p.reshape((2,1))
        tt=p.reshape((1,2))
        S=S+np.dot(t,tt)
    return S

#计算Sw和w
def LDA(dataset1,dataset0):
    X1=dataset1[:,0:2]
    X0=dataset0[:,0:2]
    u1=np.array([np.mean(X1[:,0]),np.mean(X1[:,1])])
    u0=np.array([np.mean(X0[:,0]),np.mean(X0[:,1])])
    Cov1=Cov(X1,u1)
    Cov0 = Cov(X0, u0)
    Sw=Cov1+Cov0
    print(Sw)
    u=u0-u1
    p=u.reshape((len(u),1))
    pt=u.reshape((1,len(u)))
    Sb=np.dot(p,pt)

    u,s,v=np.linalg.svd(Sw,1,1) #对Sw奇异值分解
    s=np.diag(s)
    s=np.linalg.inv(s)
    Swinv=np.dot(v.transpose(),s)
    Swinv=np.dot(Swinv,u.transpose())

    w=np.dot(Swinv,u0-u1)
    print(w)
    return w

w=LDA(dataset1,dataset0)

k=w[0,1]/w[0,0] #计算斜率
print(k)

#计算投影点
def projectedpoint(X,Y,k):
    k1=-1/k
    b=Y-k1*X
    x=-b/(k1-k)
    y=k1*x+b
    #print(x,y)
    return x,y

上述python编程运行结果为:

Sw [[0.39934838 0.05480425]
 [0.05480425 0.111877  ]]
斜率: -37.61017098508275

python绘图运行结果:在这里插入图片描述

由于运行结果与大佬的不同,且相差较多,但反复检查实在找不出问题,于是又用matlab根据原理编程了一遍,具体代码如下
(1)主函数:

clear;
clc;
%读取数据
dataset=importdata('watermelon1.csv');
X1=dataset.data(1:8,2:3);
X0=dataset.data(9:16,2:3);
#计算Sw和w,斜率
[w,Sw]=Sw_w(X1,X0)
k=w(2,1)/w(1,1)
x=-0.1:0.01:0.9;
y=k*x;
#计算投影点及投影中心
[x11,x12]=projectedpoint(X1(:,1),X1(:,2),k);
[x01,x02]=projectedpoint(X0(:,1),X0(:,2),k);
[u11,u12]=projectedpoint(mean(X1(:,1)),mean(X1(:,2)),k);
[u01,u02]=projectedpoint(mean(X0(:,1)),mean(X0(:,2)),k);
#绘图
figure(1)
scatter(X1(:,1),X1(:,2),'go');
hold on;
scatter(u11,u12,'go');
hold on ;
scatter(X0(:,1),X0(:,2),'ko');
scatter(u01,u02,'ko');
hold on;
plot(x,y,'y-');
hold on;
for i=1:8
    plot([x11(i,1),X1(i,1)],[x12(i,1),X1(i,2)],'g--');
    hold on;
end
for i=1:8
    plot([x01(i,1),X0(i,1)],[x02(i,1),X0(i,2)],'k--');
    hold on;
end

(2)计算Sw和w

function [w,Sw]=Sw_w(X1,X0)
u0=mean(X0,1)
u1=mean(X1,1);
Sw=zeros(2);
for i=1:8
    Sw=Sw+(X0(i,:)-u0)'*(X0(i,:)-u0);
end
for i=1:8
    Sw=Sw+(X1(i,:)-u1)'*(X1(i,:)-u1);
end
[U,S,V]=svd(Sw);
Sw_inv=V'*inv(S)*U';
w=Sw_inv*(u0-u1)';
end

(3)计算投影点

function [x1,x2]=projectedpoint(X,Y,k)
if k==0
    x2=Y;
    x1=zeros(len(X));
else
    k1=-1/k;
    b=Y-k1*X;
    x1=b/(k-k1);
    x2=k*x1;
end
end

matlab运行结果:

Sw =

    0.3993    0.0548
    0.0548    0.1119

k =

  -37.6102

matlab绘图结果:
在这里插入图片描述
结果还是一样的。不知道是不是代码原理没有理解清楚,如果大噶发现问题,希望能在评论区一起探讨撒~

本文参考以下博主文章:
[1]:https://blog.csdn.net/Elvirangel/article/details/84237089
[2]:https://blog.csdn.net/A993852/article/details/80099258
[2]:https://blog.csdn.net/Snoopy_Yuan/article/details/64443841
以及
西瓜书P60-62

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!