Sklearn - How to predict probability for all target labels

孤城傲影 2021-01-05 07:47

I have a data set with a target variable that can have 7 different labels. Each sample in my training set has only one label for the target variable.

For each sampl

3 answers
  • 2021-01-05 08:22

    You can do that by simply removing the OneVsRestClassifier and calling the predict_proba method of the DecisionTreeClassifier directly. You can do the following:

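    from sklearn.tree import DecisionTreeClassifier

    clf = DecisionTreeClassifier()
    clf.fit(X_train, y_train)
    # Array of shape (n_samples, n_classes); columns follow clf.classes_
    pred = clf.predict_proba(X_test)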
    

    This gives you an array of shape (n_samples, 7): one probability for each of your 7 possible classes per test sample, with the columns ordered as in clf.classes_.
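    As a quick check of what predict_proba returns, here is a minimal sketch on synthetic data. make_classification is only a stand-in for your real dataset (an assumption, since the original data isn't shown), set up with 7 classes and one label per sample; the 600/100 train/test split is arbitrary.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for the real data: 7 classes, exactly one label per sample
X, y = make_classification(
    n_samples=700, n_features=10, n_informative=5,
    n_classes=7, n_clusters_per_class=1, random_state=0,
)
X_train, X_test = X[:600], X[600:]
y_train = y[:600]

# Limiting depth lets leaves stay mixed, so probabilities need not be just 0 or 1
clf = DecisionTreeClassifier(max_depth=5, random_state=0)
clf.fit(X_train, y_train)

proba = clf.predict_proba(X_test)
print(proba.shape)   # (100, 7): one row per test sample, one column per class
print(clf.classes_)  # [0 1 2 3 4 5 6]: the label behind each column

# Each row is a probability distribution over the 7 classes
print(np.allclose(proba.sum(axis=1), 1.0))  # True

# The highest-probability column matches what predict() returns
print((clf.classes_[proba.argmax(axis=1)] == clf.predict(X_test)).all())  # True
```

    max_depth=5 is an illustrative choice. With the default settings a tree is grown until its leaves are pure, so predict_proba then usually returns rows of 0s with a single 1; limiting the depth, or using an ensemble such as RandomForestClassifier, gives more graded probabilities.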

    Hope that helps!

  • 2021-01-05 08:36

    You can try scikit-multilearn, an extension of sklearn that handles multi-label classification. If your labels are not strongly correlated, you can train one classifier per label and get all predictions at once. Try this (after running pip install scikit-multilearn):

    from sklearn.tree import DecisionTreeClassifier
    from skmultilearn.problem_transform import BinaryRelevance

    # One independent binary classifier per label column
    classifier = BinaryRelevance(classifier=DecisionTreeClassifier())
    
    # train
    classifier.fit(X_train, y_train)
    
    # predict
    predictions = classifier.predict(X_test)
    

    predictions will be a sparse matrix of shape (n_samples, n_labels), where n_labels = 7 in your case; each column holds the predictions for one label across all samples.

    In case your labels are correlated you might need more sophisticated methods for multi-label classification.

    Disclaimer: I'm the author of scikit-multilearn, feel free to ask more questions.
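    The one-classifier-per-label idea (binary relevance) can also be sketched in plain scikit-learn, which is handy if you cannot install scikit-multilearn. This is a hand-rolled illustration, not scikit-multilearn's implementation, and the synthetic data is an assumption: LabelBinarizer turns the single 7-way target into a 0/1 indicator matrix, one DecisionTreeClassifier is fitted per column, and column j of scores is that model's probability that label lb.classes_[j] applies. Because the models are independent, a row of scores need not sum to 1.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.preprocessing import LabelBinarizer
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data: 7 classes, one label per sample
X, y = make_classification(
    n_samples=700, n_features=10, n_informative=5,
    n_classes=7, n_clusters_per_class=1, random_state=0,
)
X_train, X_test, y_train = X[:600], X[600:], y[:600]

# Turn the single 7-way label into a (n_samples, 7) 0/1 indicator matrix
lb = LabelBinarizer()
Y_train = lb.fit_transform(y_train)

# Binary relevance: fit one independent binary classifier per label column
models = [
    DecisionTreeClassifier(max_depth=5, random_state=0).fit(X_train, Y_train[:, j])
    for j in range(Y_train.shape[1])
]

# Column j = probability of the positive class (label lb.classes_[j] applies)
scores = np.column_stack([m.predict_proba(X_test)[:, 1] for m in models])
print(scores.shape)  # (100, 7)
```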

  • 2021-01-05 08:36

    If you want to keep the OneVsRestClassifier, you can also call predict_proba(X_test), since OneVsRestClassifier supports it as well.

    For example:

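    from sklearn.multiclass import OneVsRestClassifier
    from sklearn.tree import DecisionTreeClassifier

    clf = OneVsRestClassifier(DecisionTreeClassifier())
    clf.fit(X_train, y_train)
    # One column per class, in the order given by clf.classes_
    pred = clf.predict_proba(X_test)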
    

    The columns of the result follow the label order stored in:

    clf.classes_
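    To read the probabilities back by label name, you can zip clf.classes_ with a row of the result; classes_ is sorted, not listed in the order the labels first appear. The data and colour label names below are hypothetical, and the sketch swaps in LogisticRegression as the base estimator (an assumption): for single-label targets OneVsRestClassifier rescales each row to sum to 1, and a fully grown DecisionTreeClassifier can give a sample probability 0 under every per-class model, which turns that row into NaN.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

# Synthetic stand-in data with 7 classes
X, y_int = make_classification(
    n_samples=700, n_features=10, n_informative=5,
    n_classes=7, n_clusters_per_class=1, random_state=0,
)
# Hypothetical label names, to show that classes_ is sorted alphabetically
names = np.array(["red", "orange", "yellow", "green", "blue", "indigo", "violet"])
y = names[y_int]
X_train, X_test, y_train = X[:600], X[600:], y[:600]

clf = OneVsRestClassifier(LogisticRegression(max_iter=1000))
clf.fit(X_train, y_train)
proba = clf.predict_proba(X_test)

print(clf.classes_)
# ['blue' 'green' 'indigo' 'orange' 'red' 'violet' 'yellow']

# Pair each label with its probability for the first test sample
first = dict(zip(clf.classes_, proba[0]))

# Labels ranked from most to least likely
ranked = sorted(first.items(), key=lambda kv: kv[1], reverse=True)
print(ranked[:3])
```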
    