Scoring metrics from Keras scikit-learn wrapper in cross validation with one-hot encoded labels

刺人心 2021-01-26 07:45

I am implementing a neural network and I would like to assess its performance with cross validation. Here is my current code:

def recall_m(y_true, y_pred):
    t         


        
3 Answers
  •  无人及你
    2021-01-26 08:20

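    For anybody who still wants to use cross_validate with one-hot encoded labels, here is a more scikit-learn-oriented way to go about it: wrap each metric so that it converts the one-hot targets back to integer labels, then pass the wrapped metrics to cross_validate as a scoring dictionary.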

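    import numpy as np
    import tensorflow as tf
    from sklearn.preprocessing import LabelEncoder
    from sklearn.metrics import make_scorer, f1_score, precision_score, recall_score, accuracy_score
    from sklearn.model_selection import cross_validate

    X, y = get_data()  # get_data() is application-specific and not shown here
    # in my application I have words as labels, so y is a np.array with strings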
    encoder = LabelEncoder()
    y_encoded = encoder.fit_transform(y)
    
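    # build a version of the scoring metrics for multi-class and one-hot encoding predictions
    # labels to score on: every encoded class except the 'nan' placeholder label in my data
    # (encoder.transform raises a ValueError if 'nan' never occurs in y, so drop that part for other data)
    labels = sorted(set(np.unique(y_encoded)) - set(encoder.transform(['nan'])))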
    
    # these functions compare y (one-hot encoded) to y_pred (integer encoded)
    # by making y integer encoded as well
    
    def f1_categorical(y, y_pred, **kwargs):
        return f1_score(y.argmax(1), y_pred, **kwargs)
    
    def precision_categorical(y, y_pred, **kwargs):
        return precision_score(y.argmax(1), y_pred, **kwargs)
    
    def recall_categorical(y, y_pred, **kwargs):
        return recall_score(y.argmax(1), y_pred, **kwargs)
    
    def accuracy_categorical(y, y_pred, **kwargs):
        return accuracy_score(y.argmax(1), y_pred, **kwargs)
    
    # Wrap the functions above with `make_scorer`; extra keyword arguments are forwarded to them
    # (here I chose the micro average because it worked for my multi-class application)
    our_f1 = make_scorer(f1_categorical, labels=labels, average="micro")
    our_precision = make_scorer(precision_categorical, labels=labels, average="micro")
    our_recall = make_scorer(recall_categorical, labels=labels, average="micro")
    our_accuracy = make_scorer(accuracy_categorical)
    scoring = {
        'accuracy':our_accuracy,
        'f1':our_f1,
        'precision':our_precision,
        'recall':our_recall
    }
    
    # one-hot encoding
    y_categorical = tf.keras.utils.to_categorical(y_encoded)
    
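    # keras wrapper; model_with_one_hot_encoded_output is my model-building function (not shown),
    # whose output layer has one unit per class to match the one-hot targets.
    # With one-hot y, this wrapper's predict() returns integer class indices,
    # which is what the *_categorical functions above expect as y_pred.
    # Note: tf.keras.wrappers.scikit_learn is deprecated and has been removed in recent
    # TensorFlow releases; SciKeras offers a replacement.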
    estimator = tf.keras.wrappers.scikit_learn.KerasClassifier(
                    build_fn=model_with_one_hot_encoded_output,
                    epochs=1,
                    batch_size=32,
                    verbose=1)
    
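    # cross validate as usual; X_scaled is the feature matrix after scaling
    # (the scaling step is not shown; pass X if you don't scale)
    # results contains 'test_accuracy', 'test_f1', 'test_precision' and 'test_recall'
    results = cross_validate(estimator,
                             X_scaled, y_categorical,
                             scoring=scoring)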
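
To see what the argmax conversion and the `labels` filter actually do, here is a small sketch with made-up labels and predictions (the data and the `y_pred` values are invented for illustration). It builds the one-hot matrix with `np.eye(n)[y_encoded]` instead of `to_categorical`, an assumption made only to avoid the TensorFlow dependency; for integer labels 0..n-1 both give the same 0/1 matrix (up to dtype).

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import f1_score

# Toy string labels; 'nan' plays the role of the placeholder class excluded in the answer
y = np.array(["cat", "dog", "nan", "dog", "cat", "nan"])
encoder = LabelEncoder()
y_encoded = encoder.fit_transform(y)  # classes_ are sorted: cat=0, dog=1, nan=2

# Same one-hot matrix as to_categorical(y_encoded), built with NumPy only
y_onehot = np.eye(len(encoder.classes_))[y_encoded]

# argmax along axis 1 recovers the integer labels from the one-hot rows
assert (y_onehot.argmax(1) == y_encoded).all()

# Invented integer predictions, the form the Keras wrapper's predict() returns
y_pred = np.array([0, 1, 1, 1, 0, 2])  # the first 'nan' sample is predicted as 'dog'

labels = sorted(set(np.unique(y_encoded)) - set(encoder.transform(["nan"])))
f1_all = f1_score(y_onehot.argmax(1), y_pred, average="micro")
f1_no_nan = f1_score(y_onehot.argmax(1), y_pred, labels=labels, average="micro")
print(f1_all, f1_no_nan)  # f1_all = 5/6 (~0.833), f1_no_nan = 8/9 (~0.889)
```

With 'nan' excluded, the mispredicted sample still counts as a false positive for 'dog', but the missed 'nan' is no longer a false negative, so the micro-averaged F1 rises from 5/6 to 8/9.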
    
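The same scoring dictionary can be exercised end to end without TensorFlow. The sketch below is a hypothetical stand-in, not the Keras wrapper: `OneHotStandIn` fits a scikit-learn `LogisticRegression` on the argmax of the one-hot targets and, like the wrapper, returns integer class indices from `predict()`. The Iris dataset and the shuffled `KFold` are illustrative choices, not part of the original answer.

```python
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import make_scorer, f1_score, accuracy_score
from sklearn.model_selection import KFold, cross_validate


class OneHotStandIn(BaseEstimator):
    """Hypothetical stand-in for the Keras wrapper: fits on one-hot targets,
    predicts integer class indices."""

    def __init__(self, C=1.0):
        self.C = C

    def fit(self, X, y_onehot):
        # Train an ordinary classifier on the integer labels encoded in the one-hot rows
        self.model_ = LogisticRegression(C=self.C, max_iter=1000)
        self.model_.fit(X, np.asarray(y_onehot).argmax(1))
        return self

    def predict(self, X):
        # Integer class indices, mirroring the wrapper's behaviour with one-hot targets
        return self.model_.predict(X)


def f1_categorical(y, y_pred, **kwargs):
    return f1_score(y.argmax(1), y_pred, **kwargs)


def accuracy_categorical(y, y_pred, **kwargs):
    return accuracy_score(y.argmax(1), y_pred, **kwargs)


X, y_int = load_iris(return_X_y=True)
y_onehot = np.eye(3)[y_int]  # one-hot targets, as to_categorical would produce

scoring = {
    'accuracy': make_scorer(accuracy_categorical),
    'f1': make_scorer(f1_categorical, average="micro"),
}
# Iris is ordered by class, so shuffle before splitting
results = cross_validate(OneHotStandIn(), X, y_onehot, scoring=scoring,
                         cv=KFold(n_splits=5, shuffle=True, random_state=0))
print(results['test_accuracy'], results['test_f1'])
```

Because every class is scored here and each sample has exactly one label, micro-averaged F1 equals accuracy on every fold; restricting `labels`, as the answer does for 'nan', is what makes the two differ.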
