sklearn: User defined cross validation for time series data

执念已碎 2021-02-14 10:52

I'm trying to solve a machine learning problem. I have a specific dataset with a time-series element. For this problem I'm using the well-known Python library sklea

2 answers
  •  小鲜肉 (original poster)
    2021-02-14 11:46

    Since this question was asked, a time-series splitter (TimeSeriesSplit) has been added to the library: http://scikit-learn.org/stable/modules/cross_validation.html#time-series-split

    Example from the documentation:

    >>> import numpy as np
    >>> from sklearn.model_selection import TimeSeriesSplit
    >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [1, 2], [3, 4]])
    >>> y = np.array([1, 2, 3, 4, 5, 6])
    >>> tscv = TimeSeriesSplit(n_splits=3)
    >>> print(tscv)  
    TimeSeriesSplit(n_splits=3)
    >>> for train, test in tscv.split(X):
    ...     print("%s %s" % (train, test))
    [0 1 2] [3]
    [0 1 2 3] [4]
    [0 1 2 3 4] [5]
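
To connect the splitter to actual model evaluation, here is a minimal sketch of passing it as the `cv` argument of `cross_val_score`. The synthetic data and the choice of `LinearRegression` are assumptions made for illustration only; they are not taken from the question.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import TimeSeriesSplit, cross_val_score

# Synthetic, time-ordered data (hypothetical; stands in for the real dataset).
rng = np.random.RandomState(0)
X = np.arange(40, dtype=float).reshape(-1, 1)
y = 2.0 * X.ravel() + rng.normal(scale=0.5, size=40)

tscv = TimeSeriesSplit(n_splits=4)

# In every fold, all training indices come before all test indices,
# so the model is never trained on observations from the future.
splits = list(tscv.split(X))
for train_idx, test_idx in splits:
    assert train_idx.max() < test_idx.min()

# The splitter object can be handed directly to cross_val_score.
scores = cross_val_score(LinearRegression(), X, y, cv=tscv)
print(scores.shape)  # one score per fold
```

If the built-in splitter does not match the structure of the data, `cv` also accepts any iterable that yields (train, test) index arrays, so a hand-written generator can serve as a user-defined scheme.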
    
