Retraining after Cross Validation with libsvm

一生所求 2020-11-29 23:20

I know that cross-validation is used for selecting good parameters. After finding them, I need to re-train on the whole data set without the -v option.

But the problem i fa

2 Answers
  • 2020-11-29 23:58

    If you use your entire dataset to determine your parameters and then train on that same dataset, you are going to overfit your data. Ideally, you would divide the dataset, do the parameter search on one portion (with CV), then use the other portion to train and test with CV. Will you get better results if you use the whole dataset for both? Of course, but your model is likely not to generalize well. If you want to determine the true performance of your model, you need to do parameter selection separately.
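
    The split described above could look roughly like the following sketch. It assumes the libsvm MATLAB interface (libsvmread, svmtrain) is on the path and reuses the heart_scale file from the other answer; the 50/50 split, the 5 folds, and the names searchIdx, evalIdx and est_acc are illustrative choices, not something the answer prescribes.

    ```matlab
    % Sketch: select parameters on one portion, estimate accuracy on the other.
    [labels, data] = libsvmread('./heart_scale');

    % Random 50/50 split (the ratio is an arbitrary choice for illustration)
    n = numel(labels);
    perm = randperm(n);
    nSearch = round(0.5 * n);
    searchIdx = perm(1:nSearch);        % used only for the parameter search
    evalIdx   = perm(nSearch+1:end);    % used only for the performance estimate

    % Parameter search with 5-fold CV on the search portion
    best_acc = -Inf;
    for log2c = -5:2:15
        for log2g = -15:2:3
            opts = sprintf('-c %g -g %g -v 5 -q', 2^log2c, 2^log2g);
            acc = svmtrain(labels(searchIdx), data(searchIdx,:), opts);
            if acc > best_acc
                best_acc = acc;
                best_C = 2^log2c;
                best_gamma = 2^log2g;
            end
        end
    end

    % 5-fold CV on the other portion with the parameters now fixed
    opts = sprintf('-c %g -g %g -v 5 -q', best_C, best_gamma);
    est_acc = svmtrain(labels(evalIdx), data(evalIdx,:), opts);
    fprintf('search CV accuracy %.2f%%, estimate on other portion %.2f%%\n', ...
            best_acc, est_acc);
    ```

    Because the second portion played no part in choosing C and gamma, est_acc is usually the less optimistic of the two numbers.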

  • 2020-11-30 00:04

    The -v option here is really meant to be used as a way to avoid the overfitting problem (instead of using the whole data for training, perform an N-fold cross-validation training on N-1 folds and testing on the remaining fold, one at-a-time, then report the average accuracy). Thus it only returns the cross-validation accuracy (assuming you have a classification problem, otherwise mean-squared error for regression) as a scalar number instead of an actual SVM model.
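
    To make the mechanism concrete, here is a hand-written sketch of roughly what -v N computes. It assumes the libsvm MATLAB interface is on the path; the values C = 1 and gamma = 0.07 are placeholders, and the random fold assignment is a simplification, since libsvm's own internal fold assignment may differ.

    ```matlab
    % Sketch: manual N-fold cross-validation for one (C, gamma) pair.
    [labels, data] = libsvmread('./heart_scale');
    N = 5;                               % number of folds
    n = numel(labels);
    fold = mod(randperm(n), N) + 1;      % random fold id in 1..N per sample

    correct = 0;
    for k = 1:N
        testMask = (fold == k);          % fold k is held out
        % train on the other N-1 folds (placeholder parameter values)
        model = svmtrain(labels(~testMask), data(~testMask,:), '-c 1 -g 0.07 -q');
        % predict the held-out fold
        [pred, ~, ~] = svmpredict(labels(testMask), data(testMask,:), model, '-q');
        correct = correct + sum(pred == labels(testMask));
    end

    % percentage of samples predicted correctly while they were held out
    cv_acc = 100 * correct / n;
    fprintf('Manual %d-fold CV accuracy = %.2f%%\n', N, cv_acc);
    ```

    Note that N models are trained and discarded inside the loop, which is why no single model comes back from a -v run.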

    If you want to perform model selection, you have to implement a grid search using cross-validation (similar to the grid.py helper Python script) to find the best values of C and gamma.

    This shouldn't be hard to implement: create a grid of values using MESHGRID, iterate over all pairs (C,gamma), train an SVM model for each with, say, 5-fold cross-validation, and choose the pair with the best CV accuracy.

    Example:

    %# read some training data
    [labels,data] = libsvmread('./heart_scale');
    
    %# grid of parameters
    folds = 5;
    [C,gamma] = meshgrid(-5:2:15, -15:2:3);
    
    %# grid search, and cross-validation
    cv_acc = zeros(numel(C),1);
    for i=1:numel(C)
        %# %g avoids rounding small values (e.g. 2^-15) to 0.000031 as %f would
        cv_acc(i) = svmtrain(labels, data, ...
                        sprintf('-c %g -g %g -v %d', 2^C(i), 2^gamma(i), folds));
    end
    
    %# pair (C,gamma) with best accuracy
    [~,idx] = max(cv_acc);
    
    %# contour plot of parameter selection
    contour(C, gamma, reshape(cv_acc,size(C))), colorbar
    hold on
    plot(C(idx), gamma(idx), 'rx')
    text(C(idx), gamma(idx), sprintf('Acc = %.2f %%',cv_acc(idx)), ...
        'HorizontalAlign','left', 'VerticalAlign','top')
    hold off
    xlabel('log_2(C)'), ylabel('log_2(\gamma)'), title('Cross-Validation Accuracy')
    
    %# now you can train your model using best_C and best_gamma
    best_C = 2^C(idx);
    best_gamma = 2^gamma(idx);
    %# ...
    

    [Figure: contour plot of cross-validation accuracy over log_2(C) and log_2(\gamma)]
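
    The final training step left open in the code could be sketched as follows. It assumes the libsvm MATLAB interface is on the path; the best_C and best_gamma values below are placeholders standing in for whatever the grid search selected.

    ```matlab
    % Sketch: retrain on all the data with the selected parameters.
    [labels, data] = libsvmread('./heart_scale');
    best_C = 2^1;        % placeholder, substitute the grid-search result
    best_gamma = 2^-3;   % placeholder, substitute the grid-search result

    % Without '-v', svmtrain returns a model struct instead of a scalar accuracy
    model = svmtrain(labels, data, sprintf('-c %g -g %g', best_C, best_gamma));

    % The model can then be applied to new samples. The training data is reused
    % here only to show the call, so this accuracy is optimistic.
    [predicted, accuracy, ~] = svmpredict(labels, data, model);
    ```

    The cross-validation accuracy from the grid search remains the more honest performance estimate; the retrained model is what you would keep for predicting unseen data.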
