Question
TL;DR: How do I create a variable that holds the confusion matrix used to compute custom metrics and accumulates its values across all evaluation steps?
I have implemented custom metrics for use in the `tf.estimator.train_and_evaluate` pipeline, with a confusion matrix at the core of all of them. My goal is to make this confusion matrix persist over multiple evaluation steps so that I can track learning progress.
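To illustrate how one confusion matrix can serve as the basis for several metrics, here is a minimal plain-Python sketch (not TensorFlow, and not the asker's code) that derives per-class precision and recall from a single matrix. It assumes rows are true classes and columns are predicted classes, which is the convention `tf.math.confusion_matrix` uses; `precision_recall` is a hypothetical helper.

```python
from typing import List, Tuple


def precision_recall(matrix: List[List[float]]) -> Tuple[List[float], List[float]]:
    """Per-class precision and recall from a square confusion matrix.

    Assumed convention: matrix[true][pred] counts examples of class `true`
    that were predicted as class `pred`.
    """
    n = len(matrix)
    precision, recall = [], []
    for k in range(n):
        tp = matrix[k][k]
        predicted_k = sum(matrix[i][k] for i in range(n))  # column sum
        actual_k = sum(matrix[k])  # row sum
        # Report 0.0 for empty rows/columns instead of dividing by zero.
        precision.append(tp / predicted_k if predicted_k else 0.0)
        recall.append(tp / actual_k if actual_k else 0.0)
    return precision, recall


cm = [[5, 1],
      [2, 2]]
p, r = precision_recall(cm)
print(p, r)
```

Because both metrics are ratios of cells in the same matrix, accumulating the matrix correctly over all evaluation batches is enough to get correct metrics at the end.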
Using `tf.get_variable` inside a variable scope did not work, since it does not seem to save the variable to the checkpoint.
This does not work:
```python
@property
def confusion_matrix(self):
    with tf.variable_scope(
        f"{self.name}-{self.metric_type}", reuse=tf.AUTO_REUSE
    ):
        confusion_matrix = tf.get_variable(
            name="confusion-matrix",
            initializer=tf.zeros(
                [self.n_classes, self.n_classes],
                dtype=tf.float32,
                name=f"{self.name}/{self.metric_type}-confusion-matrix",
            ),
            aggregation=tf.VariableAggregation.SUM,
        )
    return confusion_matrix
```
Simply storing the matrix as a class attribute works, but it obviously does not persist over multiple steps:
```python
self.confusion_matrix = tf.zeros(
    [self.n_classes, self.n_classes],
    dtype=tf.float32,
    name=f"{self.name}/{self.metric_type}-confusion-matrix",
)
```
The full example is linked in the original question.
I want this confusion matrix to persist from start to finish during evaluation, but I do not need it in the final SavedModel. How can I achieve this? Do I need to save the matrix to an external file, or is there a better way?
Answer 1:
You can define a custom metric:
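```python
import tensorflow as tf  # TF 1.x API
N_CLASSES = 3  # hypothetical: set to your number of classes

def confusion_matrix(labels, predictions):
    # Per-batch confusion matrix; assumes a canned classifier's 'class_ids' prediction key.
    matrix = tf.math.confusion_matrix(
        tf.reshape(labels, [-1]), tf.reshape(predictions['class_ids'], [-1]),
        num_classes=N_CLASSES, dtype=tf.float32)
    mean, update_op = tf.metrics.mean_tensor(matrix)  # accumulates over eval batches
    return {'confusion_matrix': (mean, update_op)}
```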
Then add it to your estimator:

```python
estimator = tf.estimator.add_metrics(estimator, confusion_matrix)
```
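`tf.metrics.mean_tensor` returns a `(value, update_op)` pair: during evaluation the Estimator runs `update_op` once per batch and reads `value` at the end. The plain-Python class below is a hypothetical model of that contract (not TensorFlow code), assuming unweighted updates, in which case the result is the element-wise average of the per-batch matrices. `batch_confusion_matrix` and `MeanTensorModel` are illustrative names only.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def batch_confusion_matrix(labels: Sequence[int], preds: Sequence[int], num_classes: int) -> Matrix:
    """Rows are true classes, columns are predicted classes."""
    m = [[0.0] * num_classes for _ in range(num_classes)]
    for t, p in zip(labels, preds):
        m[t][p] += 1.0
    return m


class MeanTensorModel:
    """Models the (value, update_op) pair of an unweighted streaming mean."""

    def __init__(self, num_classes: int) -> None:
        # Stand-ins for the metric's local `total` and `count` variables.
        self.total = [[0.0] * num_classes for _ in range(num_classes)]
        self.count = 0

    def update(self, values: Matrix) -> None:
        """Analogue of running `update_op` on one evaluation batch."""
        for i, row in enumerate(values):
            for j, v in enumerate(row):
                self.total[i][j] += v
        self.count += 1

    def value(self) -> Matrix:
        """Analogue of reading `value`: element-wise mean over all updates."""
        if self.count == 0:
            return [[0.0 for _ in row] for row in self.total]
        return [[v / self.count for v in row] for row in self.total]


metric = MeanTensorModel(num_classes=2)
metric.update(batch_confusion_matrix([0, 0, 1], [0, 1, 1], 2))  # batch 1
metric.update(batch_confusion_matrix([1, 1, 0], [1, 0, 0], 2))  # batch 2
print(metric.value())  # → [[1.0, 0.5], [0.5, 1.0]]
```

Since every cell is divided by the same number of updates, ratios computed from the averaged matrix (such as precision or recall) equal those from the summed one; only the raw counts differ.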
If you need the sum instead of the mean, you can take inspiration from the `tf.metrics.mean_tensor` implementation.
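A streaming sum keeps only the running `total` and drops the `count` and the division. The plain-Python sketch below models that idea (it is not TensorFlow code; in TF 1.x the analogue would presumably be a non-trainable local variable updated with `tf.assign_add`, which is an assumption about how one would adapt `mean_tensor`). `StreamingSumModel` and `batch_confusion_matrix` are hypothetical names. It also checks that summing per-batch matrices gives the same result as one matrix over all evaluation examples.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def batch_confusion_matrix(labels: Sequence[int], preds: Sequence[int], num_classes: int) -> Matrix:
    """Rows are true classes, columns are predicted classes."""
    m = [[0.0] * num_classes for _ in range(num_classes)]
    for t, p in zip(labels, preds):
        m[t][p] += 1.0
    return m


class StreamingSumModel:
    """Models a streaming sum: keep `total`, no `count`, no division."""

    def __init__(self, num_classes: int) -> None:
        self.total = [[0.0] * num_classes for _ in range(num_classes)]

    def update(self, values: Matrix) -> None:
        """Analogue of `update_op`: add this batch's matrix to the running total."""
        for i, row in enumerate(values):
            for j, v in enumerate(row):
                self.total[i][j] += v

    def value(self) -> Matrix:
        """Analogue of `value`: a copy of the accumulated counts."""
        return [row[:] for row in self.total]


batches = [([0, 0, 1], [0, 1, 1]), ([1, 1, 0], [1, 0, 0])]
metric = StreamingSumModel(num_classes=2)
for labels, preds in batches:
    metric.update(batch_confusion_matrix(labels, preds, 2))

# Summing per-batch matrices equals one matrix over all evaluation examples.
all_labels = [t for labels, _ in batches for t in labels]
all_preds = [p for _, preds in batches for p in preds]
assert metric.value() == batch_confusion_matrix(all_labels, all_preds, 2)
print(metric.value())  # → [[2.0, 1.0], [1.0, 2.0]]
```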
Source: https://stackoverflow.com/questions/57401924/how-to-create-a-variable-that-persists-over-tf-estimator-train-and-evaluate-eval