scikit-learn cross validation custom splits for time series data

不知归路 2021-01-31 12:12

I'd like to use scikit-learn's GridSearchCV to determine some hyperparameters for a random forest model. My data is time dependent and looks something like this:

…

  •  悲哀的现实
    2021-01-31 12:44

    You just have to pass an iterable of splits to GridSearchCV's cv parameter. The iterable should have the following format:

    [
     (split1_train_idxs, split1_test_idxs),
     (split2_train_idxs, split2_test_idxs),
     (split3_train_idxs, split3_test_idxs),
     ...
    ]
    

    To get the indices, for example to train on two consecutive years and test on the following year, you can do something like this:

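    groups = df.groupby(df.date.dt.year).groups
    # {2012: [0, 1], 2013: [2], 2014: [3], 2015: [4, 5]}
    # .groups maps each year to a pandas Index of row labels. Convert each one
    # to a plain list so that + concatenates (on a pandas Index, + is
    # arithmetic, not concatenation). Labels equal row positions only when df
    # has a default RangeIndex.
    sorted_groups = [list(value) for (key, value) in sorted(groups.items())]
    # [[0, 1], [2], [3], [4, 5]]
    
    # Train on two consecutive years, test on the year that follows them
    cv = [(sorted_groups[i] + sorted_groups[i+1], sorted_groups[i+2])
          for i in range(len(sorted_groups)-2)]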
    

    For the example groups above, cv looks like this:

    [([0, 1, 2], [3]),  # idxs of first split as (train, test) tuple
     ([2, 3], [4, 5])]  # idxs of second split as (train, test) tuple
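As a self-contained sketch of the same idea, the hypothetical helper rolling_year_splits below builds the splits from positional indices (np.flatnonzero on each row's year), so it does not depend on the DataFrame having a default RangeIndex. The toy dates are made up only to reproduce the year counts in the comments above, and n_train_years is an assumed parameter that generalizes the two-year training window.

```python
import numpy as np
import pandas as pd

def rolling_year_splits(dates, n_train_years=2):
    """Return a list of (train, test) positional index arrays: train on
    n_train_years consecutive years, test on the year that follows them.
    (Hypothetical helper, not part of scikit-learn or pandas.)"""
    years = pd.DatetimeIndex(dates).year.to_numpy()
    # Positional indices of the rows in each year, oldest year first
    by_year = [np.flatnonzero(years == yr) for yr in np.unique(years)]
    splits = []
    for i in range(len(by_year) - n_train_years):
        train = np.concatenate(by_year[i:i + n_train_years])
        test = by_year[i + n_train_years]
        splits.append((train, test))
    return splits

# Hypothetical data: 2 rows in 2012, 1 in 2013, 1 in 2014, 2 in 2015
df = pd.DataFrame({"date": pd.to_datetime([
    "2012-03-01", "2012-09-15", "2013-05-20",
    "2014-07-04", "2015-01-10", "2015-11-30",
])})

cv = rolling_year_splits(df["date"])
for train_idx, test_idx in cv:
    print(train_idx.tolist(), test_idx.tolist())
# [0, 1, 2] [3]
# [2, 3] [4, 5]
```

Using positions rather than labels matters because GridSearchCV indexes X and y by position; with a non-default index (for example the dates themselves), .groups labels would not be valid indices.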
    

    Then you can do:

    GridSearchCV(estimator, param_grid, cv=cv, ...)
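An end-to-end sketch with a random forest, as in the question. Because the question's data sample is not shown, the data here is synthetic (one row per day from 2012 through 2015, two random features); RandomForestRegressor and the parameter grid values are assumptions chosen for illustration (use RandomForestClassifier for a classification target). GridSearchCV accepts the list of (train, test) index arrays directly through cv.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV

rng = np.random.default_rng(0)

# Synthetic data: one row per day, two features, target mostly from feature 0
dates = pd.date_range("2012-01-01", "2015-12-31", freq="D")
X = rng.normal(size=(len(dates), 2))
y = 2.0 * X[:, 0] + rng.normal(scale=0.1, size=len(dates))

# Positional indices of the rows in each year, oldest year first
years = dates.year.to_numpy()
by_year = [np.flatnonzero(years == yr) for yr in np.unique(years)]

# Train on two consecutive years, test on the year that follows them
cv = [(np.concatenate(by_year[i:i + 2]), by_year[i + 2])
      for i in range(len(by_year) - 2)]

param_grid = {"n_estimators": [10, 30], "max_depth": [2, None]}
search = GridSearchCV(RandomForestRegressor(random_state=0), param_grid, cv=cv)
search.fit(X, y)
print(search.best_params_)
```

Each test year comes strictly after its training years, so no future rows leak into training, and search.n_splits_ equals the number of tuples in cv (two here).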
    
