What's the best way to test whether an sklearn model has been fitted?

误落风尘 2021-02-07 02:30

What's the most elegant way to check whether an sklearn model has been fitted, i.e. whether its fit() method has been called since it was instantiated?

4 answers
  • 2021-02-07 03:05

    I do this for classifiers:

    def check_fitted(clf):
        # classes_ is set by fit() on scikit-learn classifiers
        return hasattr(clf, "classes_")
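
    A quick sanity check of this approach, assuming scikit-learn is installed and using LogisticRegression and LinearRegression purely as examples. It also shows why this is limited to classifiers: regressors never get a classes_ attribute, so the check reports False even after fit().

```python
from sklearn.linear_model import LinearRegression, LogisticRegression

def check_fitted(clf):
    # Classifiers set classes_ inside fit(), so its presence signals a fitted model
    return hasattr(clf, "classes_")

X = [[0.0], [1.0], [2.0], [3.0]]
y = [0, 0, 1, 1]

clf = LogisticRegression()
before = check_fitted(clf)   # fit() has not been called yet
clf.fit(X, y)
after = check_fitted(clf)    # classes_ now exists

reg = LinearRegression().fit(X, y)
regressor_result = check_fitted(reg)  # regressors have no classes_ attribute

print(before, after, regressor_result)  # False True False
```

    So for regressors or transformers you would need a different attribute, or one of the more general checks below.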
    
  • 2021-02-07 03:08

    You can do something like:

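    from sklearn.exceptions import NotFittedError
    
    # `models` is your collection of estimators and `some_test_data`
    # a few input rows in the shape the models expect
    for model in models:
        try:
            model.predict(some_test_data)
        except NotFittedError as e:
            # predict() raises NotFittedError if fit() was never called
            print(repr(e))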
    

    Ideally you would also check the output of model.predict against expected results, but if all you want to know is whether the model is fitted, this should suffice.
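
    The same idea can be wrapped into a function that returns a boolean instead of printing. This is a minimal sketch: is_fitted_by_predict is a hypothetical helper name, and DecisionTreeClassifier with a four-row toy dataset is only an example. Only NotFittedError is caught, so any other error (for example, a feature-count mismatch in the sample) still propagates rather than being mistaken for "not fitted".

```python
import numpy as np
from sklearn.exceptions import NotFittedError
from sklearn.tree import DecisionTreeClassifier

def is_fitted_by_predict(model, sample):
    """Hypothetical helper: False if predict() raises NotFittedError, else True."""
    try:
        model.predict(sample)
    except NotFittedError:
        return False
    # Other exceptions are deliberately not caught: they say nothing
    # about whether fit() was called.
    return True

X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([0, 0, 1, 1])

model = DecisionTreeClassifier(random_state=0)
print(is_fitted_by_predict(model, X))  # False
model.fit(X, y)
print(is_fitted_by_predict(model, X))  # True
```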

    Update:

    Some commenters have suggested using check_is_fitted. I consider check_is_fitted an internal method. Most algorithms call check_is_fitted inside their predict method, which in turn raises NotFittedError if the model has not been fitted. The problem with using check_is_fitted directly is that it is model-specific, i.e. you need to know which attributes to check for each algorithm. For example:

    ╔════════════════╦════════════════════════════════════════════╗
    ║ Tree models    ║ check_is_fitted(self, 'tree_')             ║
    ║ Linear models  ║ check_is_fitted(self, 'coef_')             ║
    ║ KMeans         ║ check_is_fitted(self, 'cluster_centers_')  ║
    ║ SVM            ║ check_is_fitted(self, 'support_')          ║
    ╚════════════════╩════════════════════════════════════════════╝
    

    and so on. So in general I would recommend calling model.predict() and letting the specific algorithm handle the best way to check whether it is already fitted or not.
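
    For comparison, a sketch of calling check_is_fitted from sklearn.utils.validation directly, with KMeans as the example. Passing "cluster_centers_" matches the per-model attribute approach described above; the call without an attribute name assumes a recent scikit-learn release, in which the attributes argument is optional and the function looks for fitted attributes ending in an underscore instead. passes_check_is_fitted is a hypothetical helper name.

```python
from sklearn.cluster import KMeans
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted

def passes_check_is_fitted(model, attributes=None):
    """Hypothetical helper: turn check_is_fitted's exception into a boolean."""
    try:
        check_is_fitted(model, attributes)
    except NotFittedError:
        return False
    return True

km = KMeans(n_clusters=2, n_init=10, random_state=0)
print(passes_check_is_fitted(km, "cluster_centers_"))  # False

km.fit([[0.0], [0.1], [5.0], [5.1]])
print(passes_check_is_fitted(km, "cluster_centers_"))  # True
print(passes_check_is_fitted(km))  # True: no attribute name given
```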

  • 2021-02-07 03:12

    Cribbing directly from the scikit-learn source code for the check_is_fitted function (similar logic to @david-marx, but a little simpler):

    def is_fitted(model):
        '''
        Checks if a scikit-learn estimator/transformer has already been fit.

        Parameters
        ----------
        model : scikit-learn estimator (e.g. RandomForestClassifier)
            or transformer (e.g. MinMaxScaler) object

        Returns
        -------
        bool
            True if ``model`` has already been fit, False otherwise.
        '''
        # By scikit-learn convention, attributes learned in fit() end with "_";
        # dunder names such as __dict__ are excluded.
        attrs = [v for v in vars(model)
                 if v.endswith("_") and not v.startswith("__")]

        return len(attrs) != 0
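
    A short usage sketch, assuming scikit-learn is installed; MinMaxScaler is used only because the docstring mentions it, and the helper is repeated so the snippet runs on its own. Before fit() the instance holds only constructor parameters, none of which end in an underscore.

```python
from sklearn.preprocessing import MinMaxScaler

def is_fitted(model):
    # Learned attributes live in the instance __dict__ and end in "_"
    attrs = [v for v in vars(model) if v.endswith("_") and not v.startswith("__")]
    return len(attrs) != 0

scaler = MinMaxScaler()
print(is_fitted(scaler))  # False: only feature_range, copy, clip so far

scaler.fit([[1.0], [3.0]])
print(is_fitted(scaler))  # True

# The learned attributes that made the check succeed (exact list depends on version)
print(sorted(v for v in vars(scaler) if v.endswith("_")))
```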
    
  • 2021-02-07 03:15

    This is a somewhat greedy approach, but it should work for most, if not all, models. The only case where it might fail is a model that sets an attribute ending in an underscore before being fit, which I'm fairly sure would violate scikit-learn conventions, so this should be fine.

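    import inspect
    
    def is_fitted(model):
        """Checks if model object has any public attributes ending with an underscore"""
        # Skip every name starting with "_": scikit-learn's BaseEstimator defines
        # helpers such as _repr_mimebundle_ that end in "_" even before fit().
        return any(k.endswith('_') and not k.startswith('_')
                   for k, _ in inspect.getmembers(model))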
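
    Whether the name filter is strict enough can be checked on an estimator that has never been fitted. This sketch assumes a recent scikit-learn, which gives every estimator HTML-display helpers such as _repr_mimebundle_, and uses LogisticRegression only as an example: a filter that only excludes dunder names already matches those helpers before fit(), whereas skipping every name with a leading underscore matches only public learned attributes.

```python
import inspect
from sklearn.linear_model import LogisticRegression

model = LogisticRegression()  # never fitted

# Filter that only excludes dunder names
loose = [k for k, _ in inspect.getmembers(model)
         if k.endswith("_") and not k.startswith("__")]
print(loose)  # not empty: includes display helpers such as '_repr_mimebundle_'

# Filter that excludes every name with a leading underscore
strict = [k for k, _ in inspect.getmembers(model)
          if k.endswith("_") and not k.startswith("_")]
print(strict)  # []

model.fit([[0.0], [1.0], [2.0], [3.0]], [0, 0, 1, 1])
strict_after = [k for k, _ in inspect.getmembers(model)
                if k.endswith("_") and not k.startswith("_")]
print(strict_after)  # learned attributes such as 'classes_', 'coef_', 'intercept_'
```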
    