Question
I am using a DNNClassifier as my estimator and want to add some additional metrics to it. The code I am using is basically the one from the tf.estimator.add_metrics documentation (https://www.tensorflow.org/api_docs/python/tf/estimator/add_metrics).
import tensorflow as tf

def my_auc(labels, predictions):
    auc_metric = tf.keras.metrics.AUC(name="my_auc")
    auc_metric.update_state(y_true=labels, y_pred=predictions['logits'])
    return {'auc': auc_metric}
# Number of feature columns in the input dataset
hidden_layers = len(training_data().element_spec[0])
# Number of classes (labels is the list of class names)
final_layer = len(labels)
est = tf.estimator.DNNClassifier(feature_columns=features,
                                 hidden_units=[hidden_layers, hidden_layers // 2,
                                               hidden_layers // 4, final_layer],
                                 n_classes=final_layer,
                                 label_vocabulary=labels)
est = tf.estimator.add_metrics(est, my_auc)
# Training
est.train(training_data, max_steps=100)
# Validation
result = est.evaluate(validation_data)
The model works fine without the add_metrics statement, but when I include it I get a ValueError: "Shapes (None, 12) and (None,) are incompatible". The error occurs in the line:
auc_metric.update_state(y_true=labels, y_pred=predictions['logits'])
which is called by est.evaluate(validation_data).
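The two shapes in the message match the two arguments. (None, 12) is predictions['logits'], with one column per class (n_classes=final_layer). (None,) is labels, with one class label per example. tf.keras.metrics.AUC needs y_true and y_pred to have compatible shapes.

One possible way to make them agree is sketched below. It rests on two assumptions that are not confirmed here:
- labels reach the metric function unchanged from the input function, i.e. as raw class-name strings from label_vocabulary.
- Column i of predictions['probabilities'] corresponds to label_vocabulary[i].

make_auc_metric_fn is a hypothetical helper name. The sketch one-hot encodes the string labels by comparing them against the vocabulary. It also passes probabilities rather than logits, because AUC's default thresholds lie in [0, 1].

```python
import tensorflow as tf


def make_auc_metric_fn(label_vocabulary):
    """Hypothetical helper: builds a metric_fn for tf.estimator.add_metrics.

    Assumes labels are raw strings from label_vocabulary, shaped (batch,) or
    (batch, 1), and that predictions['probabilities'][:, i] belongs to
    label_vocabulary[i].
    """

    def metric_fn(labels, predictions):
        # Created at call time, so it lives in whatever graph is being built.
        vocab = tf.constant(label_vocabulary)
        auc_metric = tf.keras.metrics.AUC(name="my_auc")
        # (batch, 1) == (1, n_classes) broadcasts to a one-hot (batch, n_classes) matrix.
        y_true = tf.cast(
            tf.equal(tf.reshape(labels, [-1, 1]), vocab[tf.newaxis, :]),
            tf.float32)
        # Probabilities lie in [0, 1] and have the same shape as y_true.
        auc_metric.update_state(y_true=y_true, y_pred=predictions["probabilities"])
        return {"auc": auc_metric}

    return metric_fn
```

With the estimator from the question, this would be wired up as est = tf.estimator.add_metrics(est, make_auc_metric_fn(labels)). AUC is built with the default multi_label=False, so the one-hot matrix is flattened. The result is therefore a single AUC pooled over all class columns, not a per-class average.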
It is not clear to me why this happens, but it seems that the y_true parameter is not filled correctly, i.e. the label column is not passed to the function properly. This is strange, since the model works correctly without the additional metric. The training and validation data are created by the following function:
def get_dataset_from_tensor_slices(data_input, label_column, n_epochs=None, shuffle=True):
    def get_dataset():
        dataset = tf.data.Dataset.from_tensor_slices((dict(data_input), label_column))
        if shuffle:
            dataset = dataset.shuffle(len(label_column))
        # For training, cycle through the dataset as many times as needed (n_epochs=None).
        dataset = dataset.repeat(n_epochs)
        # In-memory training doesn't use mini-batching: each batch is the whole dataset.
        dataset = dataset.batch(len(label_column))
        return dataset
    return get_dataset
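The shapes the estimator actually receives can be checked by inspecting the dataset's element_spec. Below is a minimal sketch that mirrors the from_tensor_slices and batch steps above. It uses hypothetical toy data: feature_a, feature_b and the cat/dog labels are made up.

After batching, the label component has shape (None,), one class name per example. That is the (None,) side of the error. The classifier's logits, by contrast, carry one column per class.

```python
import tensorflow as tf

# Hypothetical toy data standing in for data_input / label_column.
data_input = {"feature_a": [1.0, 2.0, 3.0], "feature_b": [0.5, 0.1, 0.3]}
label_column = ["cat", "dog", "cat"]

dataset = tf.data.Dataset.from_tensor_slices((dict(data_input), label_column))
dataset = dataset.batch(len(label_column))  # a single batch, as in get_dataset()

feature_spec, label_spec = dataset.element_spec
# Every feature and the labels are rank-1 per batch: one value per example.
print({name: spec.shape for name, spec in feature_spec.items()})
print(label_spec.shape, label_spec.dtype)
```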
Any help is appreciated. Thank you very much!
Source: https://stackoverflow.com/questions/60789274/tf-estimator-add-metrics-ends-in-shapes-none-12-and-none-are-incompatible