class sklearn.linear_model.
LinearRegression
(fit_intercept=True, normalize=False, copy_X=True, n_jobs=None)
参数:
fit_intercept 布尔值,可选,默认为真
是否计算模型的截距。如果设定为假,则不计算就截距。
normalize:布尔值,可选,默认为假
当fit_intercept设定为假的时候自动忽略这个参数。如果为真,X在回归将会进行标准化预处理。如果你希望将数据正则化,可以选用正则化方法事先处理数据并把normalize设定为False.
copy_X 布尔值,可选,默认为真
如果为真,X将被复制,否则,它可能重写。
n_jobs:线程数。
属性:
coef_:系数
rank_:矩阵X的秩
singular_ :
intercept_:截距
方法
|
训练模型 |
|
获得参数 |
|
预测 |
|
返回判定系数R**2 |
|
设定参数 |
一个简单例子
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn import metrics
from sklearn.linear_model import LinearRegression
plt.rcParams['font.sans-serif']=['SimHei'] #正常显示中文名称
plt.rcParams['axes.unicode_minus']=False #正常显示负号
#载入数据集
diabetes_X,diabetes_y=datasets.load_diabetes(return_X_y=True)
#选取其中一个特征
diabetes_X=diabetes_X[:,np.newaxis,2]
#划分训练集和测试集
X_train=diabetes_X[:-30]
X_test=diabetes_X[30:]
y_train=diabetes_y[:-30]
y_test=diabetes_y[30:]
#训练模型
reg=LinearRegression()
reg.fit(X_train,y_train)
y_pred=reg.predict(X_test)
#评估模型得分
# 系数
print('Coefficients: \n', reg.coef_)
# 均方误差
print('Mean squared error: %.2f'
% metrics.mean_squared_error(y_test, y_pred))
# 判定系数
print('Coefficient of determination: %.2f'
% metrics.r2_score(y_test, y_pred))
#画图
plt.figure()
plt.scatter(X_test,y_test,marker='.',color='blue')
plt.plot(X_test,y_pred,color='r')
plt.title('图形')
plt.xticks()
plt.yticks()
plt.show()
Coefficients:
[941.43097333]
Mean squared error: 3878.52
Coefficient of determination: 0.36
不同的系数对岭回归模型的影响
import numpy as np
from sklearn import linear_model
import matplotlib.pyplot as plt
#创建10*10的希尔伯特矩阵
X=1.0/(np.arange(1,11)+np.arange(0,10)[:,np.newaxis])
y=np.ones(10)
#计算路径
n_alphas=200
alphas=np.logspace(-10,-2,n_alphas)
coefs=[]
for alpha in alphas:
ridge=linear_model.Ridge(alpha=alpha,fit_intercept=False)
ridge.fit(X,y)
coefs.append(ridge.coef_)
#画图
ax=plt.gca()
ax.plot(alphas,coefs)
ax.set_xscale('log')
ax.set_xlim(ax.get_xlim()[::-1]) #翻转轴
plt.xlabel('alpha')
plt.ylabel('coef')
plt.title('Ridge coefficients as a function of the regularization')
plt.axis('tight')
plt.show()
来源:CSDN
作者:慢慢悠悠we
链接:https://blog.csdn.net/Graceguanguan/article/details/104576141