Question
I followed this tutorial to create a simple image classification script:
https://blog.hyperiondev.com/index.php/2019/02/18/machine-learning/
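import time
import scipy.io
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle

# load the .mat file into a dictionary object
train_data = scipy.io.loadmat('extra_32x32.mat')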
# extract the images and labels from the dictionary object
X = train_data['X']
y = train_data['y']
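# flatten the first three axes of each image into one row: (n_samples, n_features)
X = X.reshape(X.shape[0]*X.shape[1]*X.shape[2],X.shape[3]).T
# flatten the labels into a 1-D vector
y = y.reshape(y.shape[0],)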
X, y = shuffle(X, y, random_state=42)
....
clf = RandomForestClassifier()
print(clf)
# Output of print(clf):
# RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
# max_depth=None, max_features='auto', max_leaf_nodes=None,
# min_impurity_split=1e-07, min_samples_leaf=1,
# min_samples_split=2, min_weight_fraction_leaf=0.0,
# n_estimators=10, n_jobs=1, oob_score=False, random_state=None,
# verbose=0, warm_start=False)
start_time = time.time()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
clf.fit(X_train, y_train)
preds = clf.predict(X_test)
print("Accuracy:", accuracy_score(y_test,preds))
It gave me an accuracy of approximately 0.7.
Is there some way to visualize or show where/when/if the model is overfitting? I believe this can be shown by training the model until the training accuracy keeps increasing while the validation accuracy starts decreasing. But how can I do that in code?
Answer 1:
There are multiple ways to test for overfitting and underfitting. If you want to look specifically at train and test scores and compare them, you can do this with sklearn's cross_validate [https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_validate.html#sklearn.model_selection.cross_validate]. According to the documentation, it returns a dictionary with the test scores for the metrics you supply via the scoring parameter, plus the train scores if you pass return_train_score=True.
Sample code:
from sklearn.model_selection import cross_validate
model = RandomForestClassifier(n_estimators=1000, random_state=1, criterion='entropy', bootstrap=True, oob_score=True, verbose=1)
cv_dict = cross_validate(model, X, y, return_train_score=True)
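To read the result, compare the mean of the 'train_score' entries with the mean of the 'test_score' entries; a training score far above the cross-validated score suggests overfitting. Here is a minimal sketch of that comparison. It assumes a small synthetic dataset from make_classification as a stand-in for the question's images, and a smaller forest (n_estimators=50) so it runs quickly; the names X_demo and y_demo are made up for illustration.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_validate

# Assumption: synthetic stand-in data, not the images from the question.
X_demo, y_demo = make_classification(
    n_samples=500, n_features=20, n_informative=5, random_state=0
)

model = RandomForestClassifier(n_estimators=50, random_state=1)

# return_train_score=True adds a 'train_score' array next to 'test_score'.
cv_dict = cross_validate(model, X_demo, y_demo, cv=5, return_train_score=True)

train_mean = cv_dict['train_score'].mean()
test_mean = cv_dict['test_score'].mean()
gap = train_mean - test_mean

print("Mean train accuracy:  ", train_mean)
print("Mean CV test accuracy:", test_mean)
print("Gap (train - test):   ", gap)
```

How large a gap counts as "too large" depends on the problem; the comparison only tells you how much better the model does on data it has already seen.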
You can also simply create a hold-out test set with train_test_split and compare the score on the training data with the score on the held-out test data.
Source: https://stackoverflow.com/questions/64525145/show-overfitting-with-sklearn-random-forest
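The hold-out comparison could look roughly like the sketch below. It again uses synthetic data from make_classification rather than the question's images. Repeating the comparison for a few max_depth values is one possible way to see where the gap opens up; the depth values and the names X_demo, y_demo and results are arbitrary choices for illustration.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Assumption: synthetic stand-in data, not the images from the question.
X_demo, y_demo = make_classification(
    n_samples=500, n_features=20, n_informative=5, random_state=0
)
X_train, X_test, y_train, y_test = train_test_split(
    X_demo, y_demo, test_size=0.2, random_state=42
)

results = {}
for depth in (2, 5, None):  # None lets every tree grow until its leaves are pure
    clf = RandomForestClassifier(n_estimators=50, max_depth=depth, random_state=42)
    clf.fit(X_train, y_train)
    # Score the same fitted model on data it has seen and on data it has not.
    train_acc = accuracy_score(y_train, clf.predict(X_train))
    test_acc = accuracy_score(y_test, clf.predict(X_test))
    results[depth] = (train_acc, test_acc)
    print("max_depth =", depth, "| train:", train_acc, "| test:", test_acc)
```

If the training accuracy keeps climbing as the trees get deeper while the test accuracy stalls or drops, the extra capacity is being spent on memorizing the training set.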