多项式回归 Python sklearn库 LinearRegression(学习笔记)

。_饼干妹妹 提交于 2020-01-04 16:51:12
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures

font=FontProperties(fname=r"C:\Windows\Fonts\msyh.ttc",size=15)#设置中文字体

def runplt():
    plt.figure()
    plt.title('身高与体重一元关系',fontproperties=font)
    plt.xlabel('身高(米)',fontproperties=font)
    plt.ylabel('体重(千克)',fontproperties=font)
    plt.axis([0.5,2,5,85],fontproperties=font)
    plt.grid(True)
    return plt

#导入数据
X_train=[[0.86],[0.96],[1.12],[1.35],[1.55],[1.63],[1.71],[1.78]]
y_train=[[12],[15],[20],[35],[48],[51],[59],[66]]
X_test=[[0.75],[1.08],[1.26],[1.51],[1.6],[1.85]]
y_test=[[10],[17],[27],[41],[50],[75]]

#一元线性回归
plt=runplt()
model=LinearRegression()
model.fit(X_train,y_train)
xx=np.linspace(0,10,100)
yy=model.predict(xx.reshape(xx.shape[0],1))
plt.plot(X_train,y_train,'k.')
plt.plot(xx,yy)

#二次回归
qua_fearurizer=PolynomialFeatures(degree=2)
X_train_qua=qua_fearurizer.fit_transform(X_train)
X_test_qua=qua_fearurizer.transform(X_test)
model2=LinearRegression()
model2.fit(X_train_qua,y_train)
xx_qua=qua_fearurizer.transform(xx.reshape(xx.shape[0],1))
yy_qua=model2.predict(xx_qua)
plt.plot(xx,yy_qua,'r-')
plt.show()

#用测试集模型评估
print("一元线性回归 r^2:%.2f"%model.score(X_test,y_test))
print("二次回归 r^2:%.2f"%model2.score(X_test_qua,y_test))

在这里插入图片描述
一元线性回归 r^2:0.93
二次回归 r^2:0.99

#多次回归对比
k_range=range(2,10)
k_scores=[]
k_scores.append(model.score(X_test,y_test))
for k in k_range:
    k_fearurizer=PolynomialFeatures(degree=k)
    X_train_k=k_fearurizer.fit_transform(X_train)
    X_test_k=k_fearurizer.transform(X_test)
    modelk=LinearRegression()
    modelk.fit(X_train_k,y_train)
    k_scores.append(modelk.score(X_test_k,y_test))
    
for k in range(0,9):
    print("%d项式r^2是%.2f"%(k+1,k_scores[k]))
    
plt.plot([1,2,3,4,5,6,7,8,9],k_scores)
plt.show()

1项式r^2是0.93
2项式r^2是0.99
3项式r^2是0.99
4项式r^2是0.98
5项式r^2是0.97
6项式r^2是0.93
7项式r^2是-3.03
8项式r^2是-0.80
9项式r^2是0.08
在这里插入图片描述
可见二项式和三项式拟合效果最高,后面出现过度拟合。

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!