I'm new to Machine Learning and am currently stuck on this. First I used linear regression to fit the training set but got a very large RMSE. Then I tried using polynomial regr
You should provide the data for X/Y next time, or some dummy data; it will be faster and will get you a more specific solution. For now I've created a dummy equation of the form y = X**4 + X**3 + X + 1.
There are many ways you can improve on this, but a quick iteration to find the best degree is to simply fit your data on each degree and pick the degree with the best performance (e.g., lowest RMSE).
You can also play with how you decide to hold out your train/test/validation data.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

# Dummy data: a degree-4 polynomial in X
X = np.arange(100).reshape(100, 1)
y = X**4 + X**3 + X + 1

x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.3)

rmses = []
degrees = np.arange(1, 10)
min_rmse, min_deg = np.inf, 0

for deg in degrees:
    # Train features
    poly_features = PolynomialFeatures(degree=deg, include_bias=False)
    x_poly_train = poly_features.fit_transform(x_train)

    # Linear regression
    poly_reg = LinearRegression()
    poly_reg.fit(x_poly_train, y_train)

    # Compare with test data (transform only; the transformer was fit on the training set)
    x_poly_test = poly_features.transform(x_test)
    poly_predict = poly_reg.predict(x_poly_test)
    poly_mse = mean_squared_error(y_test, poly_predict)
    poly_rmse = np.sqrt(poly_mse)
    rmses.append(poly_rmse)

    # Keep track of the degree with the lowest test RMSE
    if min_rmse > poly_rmse:
        min_rmse = poly_rmse
        min_deg = deg

# Plot and present results
print('Best degree {} with RMSE {}'.format(min_deg, min_rmse))

fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(degrees, rmses)
ax.set_yscale('log')
ax.set_xlabel('Degree')
ax.set_ylabel('RMSE')
plt.show()
This will print something like the following (the exact RMSE varies with the random train/test split):
Best degree 4 with RMSE 1.27689038706e-08
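On the point about how you hold out data: one common variation is a three-way split, where the degree is chosen on a validation set and the test set is only touched once at the end. The following is a rough sketch under my own assumptions, not the only way to do it: the 60/20/20 proportions, the fixed `random_state=0`, and the helper name `fit_and_score` are all illustrative choices that do not come from the question.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import PolynomialFeatures

# Same dummy data as above
X = np.arange(100).reshape(100, 1)
y = X**4 + X**3 + X + 1

# Assumed split: 60% train, 20% validation, 20% test
x_train, x_rest, y_train, y_rest = train_test_split(
    X, y, test_size=0.4, random_state=0)
x_val, x_test, y_val, y_test = train_test_split(
    x_rest, y_rest, test_size=0.5, random_state=0)


def fit_and_score(deg, x_fit, y_fit, x_eval, y_eval):
    """Hypothetical helper: fit a degree-`deg` polynomial, return RMSE on the eval set."""
    poly = PolynomialFeatures(degree=deg, include_bias=False)
    model = LinearRegression().fit(poly.fit_transform(x_fit), y_fit)
    pred = model.predict(poly.transform(x_eval))
    return np.sqrt(mean_squared_error(y_eval, pred))


# Pick the degree using the validation set only
val_rmses = {deg: fit_and_score(deg, x_train, y_train, x_val, y_val)
             for deg in range(1, 10)}
best_deg = min(val_rmses, key=val_rmses.get)

# Report the final error on the untouched test set
test_rmse = fit_and_score(best_deg, x_train, y_train, x_test, y_test)
print('Best degree {} with test RMSE {}'.format(best_deg, test_rmse))
```

The benefit is that the reported RMSE is not biased by the degree search, since the test set played no part in choosing the degree.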
Alternatively, you could also build a new class that carries out Polynomial fitting, and pass that to GridSearchCV with a set of parameters.
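As a rough sketch of the GridSearchCV idea: instead of writing a new class, I've swapped in a scikit-learn `Pipeline` that chains `PolynomialFeatures` and `LinearRegression`, which is usually enough for this case. The step names `poly` and `reg`, the 5-fold shuffled split, and the `random_state` are my own assumptions. The `neg_root_mean_squared_error` scorer needs a reasonably recent scikit-learn (0.22 or later).

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures

# Same dummy data as above; y is flattened to 1-D for convenience
X = np.arange(100).reshape(100, 1)
y = (X**4 + X**3 + X + 1).ravel()

# Pipeline standing in for a custom "polynomial regression" estimator
pipeline = Pipeline([
    ('poly', PolynomialFeatures(include_bias=False)),
    ('reg', LinearRegression()),
])

# Search over the polynomial degree via the step-name prefix "poly__"
param_grid = {'poly__degree': list(range(1, 10))}

# Shuffle the folds, since X is sorted and unshuffled folds would test extrapolation
cv = KFold(n_splits=5, shuffle=True, random_state=0)

search = GridSearchCV(
    pipeline,
    param_grid,
    scoring='neg_root_mean_squared_error',
    cv=cv,
)
search.fit(X, y)

best_degree = search.best_params_['poly__degree']
best_rmse = -search.best_score_  # the scorer returns negated RMSE
print('Best degree {} with CV RMSE {}'.format(best_degree, best_rmse))
```

Compared with the manual loop, this averages the error over several folds rather than relying on a single train/test split, and `search.best_estimator_` gives you a model refit on all of the data with the chosen degree.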