# Source code below: mini-batch gradient descent
import numpy as np
# Setting a random seed, feel free to change it and see different solutions.
np.random.seed(42)
# TODO: Fill in code in the function below to implement a gradient descent
# step for linear regression, following a squared error rule. See the docstring
# for parameters and returned variables.
def MSEStep(X, y, W, b, learn_rate = 0.005):
    """
    This function implements the gradient descent step for squared error as a
    performance metric.

    Parameters
    X : array of predictor features
    y : array of outcome values
    W : predictor feature coefficients
    b : regression function intercept
    learn_rate : learning rate

    Returns
    W_new : predictor feature coefficients following gradient descent step
    b_new : intercept following gradient descent step
    """
    # predictions and error under the current parameters
    y_pred = np.matmul(X, W) + b
    print("np.matmul(X, W) =", np.matmul(X, W), "np.matmul(X, W).shape =", np.matmul(X, W).shape)
    error = y - y_pred
    # compute steps: move W and b in the direction that reduces the squared error
    W_new = W + learn_rate * np.matmul(error, X)
    print("np.matmul(error, X).shape =", np.matmul(error, X).shape, "W.shape =", W.shape, "error.shape =", error.shape, "X.shape =", X.shape)
    b_new = b + learn_rate * error.sum()
    return W_new, b_new
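# A quick sanity check of MSEStep (a sketch; X_demo and y_demo below are
# made-up values, not the course's data.csv). Starting from W = 0 and b = 0
# the error equals y, so one step gives
#   W_new = 0.005 * (1*2 + 2*4 + 3*6) = [0.14]
#   b_new = 0.005 * (2 + 4 + 6)       = 0.06
# Uncomment to try it:
# X_demo = np.array([[1.0], [2.0], [3.0]])   # shape (3, 1)
# y_demo = np.array([2.0, 4.0, 6.0])         # shape (3,)
# print(MSEStep(X_demo, y_demo, np.zeros(1), 0))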
# The parts of the script below will be run when you press the "Test Run"
# button. The gradient descent step will be performed multiple times on
# the provided dataset, and the returned list of regression coefficients
# will be plotted.
def miniBatchGD(X, y, batch_size = 20, learn_rate = 0.005, num_iter = 25):
    """
    This function performs mini-batch gradient descent on a given dataset.

    Parameters
    X : array of predictor features
    y : array of outcome values
    batch_size : how many data points will be sampled for each iteration
    learn_rate : learning rate
    num_iter : number of batches used

    Returns
    regression_coef : array of slopes and intercepts generated by gradient
                      descent procedure
    """
    n_points = X.shape[0]
    W = np.zeros(X.shape[1]) # coefficients
    b = 0 # intercept
    print("type(X) =", type(X), "type(W) =", type(W))
    print("type(y) =", type(y), "type(b) =", type(b))
    print("X.shape[0] =", X.shape[0])
    print("X =", X, "W =", W)
    # run iterations
    regression_coef = [np.hstack((W, b))]
    for i in range(num_iter):
        # sample batch_size row indices for this iteration
        batch = np.random.choice(range(n_points), batch_size)
        if i == 0:
            print("type(batch)", type(batch))
        X_batch = X[batch, :]
        y_batch = y[batch]
        W, b = MSEStep(X_batch, y_batch, W, b, learn_rate)
        regression_coef.append(np.hstack((W, b)))
    return regression_coef
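# Note: np.random.choice samples indices with replacement by default, so a
# batch may contain the same row more than once. Sampling without
# replacement (an alternative, not what this exercise uses) would be:
#   batch = np.random.choice(range(n_points), batch_size, replace=False)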
if __name__ == "__main__":
    # perform gradient descent
    data = np.loadtxt('data.csv', delimiter = ',')
    X = data[:,:-1]
    y = data[:,-1]
    regression_coef = miniBatchGD(X, y)

    # plot the results
    import matplotlib.pyplot as plt
    plt.figure()
    X_min = X.min()
    X_max = X.max()
    counter = len(regression_coef)
    for W, b in regression_coef:
        counter -= 1
        # fade from light gray (earliest line) to black (final line)
        color = [1 - 0.92 ** counter for _ in range(3)]
        plt.plot([X_min, X_max], [X_min * W + b, X_max * W + b], color = color)
    plt.scatter(X, y, zorder = 3)
    plt.show()
#$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
# A few points of confusion, clarified:
1. When np.matmul receives a 1-D array of shape (n,), it promotes it to a matrix for the product: as the first operand it is treated as a 1*n row vector, and the prepended dimension is dropped from the result again. The difference from an explicit (1, n) array is just one more level of brackets: wrapping the 1-D array in an outer [] yields an explicit shape of (1, n).
Note that np.zeros(n) returns a 1-D array of shape (n,), not a (1, n) matrix, although matmul treats it like a row vector when it comes first.
For example, in the run below: error.shape = (20,), X.shape = (20, 1), and np.matmul(error, X).shape = (1,).
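A minimal sketch of these promotion rules (the variable names mirror the ones in MSEStep; the values are arbitrary zeros, only the shapes matter):

import numpy as np
error = np.zeros(20)               # shape (20,), a 1-D array
X = np.zeros((20, 1))              # shape (20, 1)
print(np.matmul(error, X).shape)   # (1,): error is promoted to (1, 20) for the product, then the prepended 1 is dropped
print(np.array([error]).shape)     # (1, 20): the extra [] mentioned above makes the row vector explicit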
2. Adding a scalar constant b to an array is handled by broadcasting: b is added to every element. In the step
y_pred = np.matmul(X, W) + b
the shapes are X.shape = (20, 1), W.shape = (1,), and np.matmul(X, W).shape = (20,), so b is added to all 20 predictions.
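Broadcasting in one line (a standalone sketch, independent of the script above):

import numpy as np
v = np.array([1.0, 2.0, 3.0])  # shape (3,)
b = 0.5                        # plain Python scalar
print(v + b)                   # [1.5 2.5 3.5]: b is broadcast and added to every element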
Abridged output of a run (the full log repeats the same shape trace for every one of the 25 iterations; the 100-row dump of X is omitted):

Loading data...
Performing gradient descent (default params)...
type(X) = <class 'numpy.ndarray'> type(W) = <class 'numpy.ndarray'>
type(y) = <class 'numpy.ndarray'> type(b) = <class 'int'>
X.shape[0] = 100
type(batch) <class 'numpy.ndarray'>
np.matmul(X, W) = [0. 0. 0. ... 0.] np.matmul(X, W).shape = (20,)
np.matmul(error, X).shape = (1,) W.shape = (1,) error.shape = (20,) X.shape = (20, 1)
...
Plotting the results...
Regression lines start from the lightest line, with the darkest, black line as the last line. Do you see it getting closer to the data over each iteration?
Source: CSDN
Author: 腾云鹏A
Link: https://blog.csdn.net/studyvcmfc/article/details/104803774