机器学习之线性回归

痞子三分冷 提交于 2020-01-18 08:11:21

主要是对吴恩达机器学习的视频来学习,学习了线性回归内容,总体进行复盘总结;

1、一元线性回归
回归模型是表示输入变量到输出变量之间映射的函数,回归问题等价于函数拟合,回归函数的求解最常用的代价函数是平方损失函数,平方损失函数可以用最小二乘法进行解决,本例中使用梯度下降法进行处理;
梯度下降法:优点:可以处理复杂的目标函数
缺点:如果代价函数不是凸函数,则容易得到局部最优解
学习率的选择,过小导致收敛速度慢,过大则会造成震荡

方法一:自己编写python代码实现

x = [1,2,3,4,5,6]
y = [7,8,9,10,11,12]
#x,y,数据,a 学习率,m迭代轮数
def linerRegre(x,y,a,m):
    theta0 = 0
    theta1 = 0
    diff = [0,0]
    l = len(x)
    while m >0:
        for i in range(l):
            diff[0] +=theta0+theta1*x[i] - y[i]
            diff[1] += theta0*x[i]+theta1*x[i]*x[i] - x[i]*y[i]
        theta0 = theta0- a/l*diff[0]
        theta1 = theta1 - a/l*diff[1]
        error = 0
        for j in range(l):
            error += (theta0+theta1*x[j] - y[j])**2
        if m%50 ==0:
            print(error)
        m-=1
    return [theta0,theta1]

print(linerRegre(x,y,0.001,2000))

每训练50轮输出当前的误差,将此模型训练2000轮中可以看到,误差是先变小,再变大,不是想象中一直减小的,尝试了不同的学习率,会有同样的问题,不知道是啥原因;
后续优化:加入一个误差界限值,当模型误差小于这个界限值时跳出循环;

方法二:使用现成的包
直接调用sklearn的linear_model模块的LinearRegression内容,如何使用linear_model这个模块可以直接看文档内容

https://scikit-learn.org/stable/modules/classes.html

from sklearn.linear_model import LinearRegression
import numpy as np 
x =np.array( [1,2,3,4,5,6])
y = np.array([7,8,9,10,11,12])
x = x.reshape(-1,1)
reg = LinearRegression().fit(x,y)
theta0 = reg.coef_
theta1 = reg.intercept_

2、逻辑斯蒂回归
逻辑斯蒂回归虽然名字中带有回归两字,但其实该模型是用来进行分类处理的,理论知识的推导部分可以看吴恩达的机器学习视频或者周志华的西瓜书,简单实用sklearn实现:

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

#X,y = 西瓜书3a数据,数据不需要标准化
#划分训练集和测试集
X_train,X_test,y_train,y_text = train_test_split(X,y,test_size= 0.3,random_state= 30)

clf = LogisticRegression(random_state=0).fit(X_train, y_train)
y_pre = clf.predict(X_test)
#检验预测的准确性
for i in range(len(y_pre)):
    if y_pre[i] == y_test:
        ans+=1
        

使用sklearn可以比较容易实现,如果想自己编码实现也是可以的,周志华的西瓜书里面每一步的推导及结果都有,直接把公式转化成代码即可,西瓜书中求解是使用极大似然函数;吴恩达的机器学习视频里面是使用代价函数,要注意的一点是在使用梯度下降法时,代价函数需要是凸函数,正是这点吴恩达的视频里构造了代价函数,此处构造的代价函数和西瓜书的极大似然的结果是一致的。

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!