TensorFlow 2.0: custom Keras metric caused tf.function retracing warning

Posted by 假如想象 on 2021-01-20 07:14:22

Question


When I use the following custom metric (keras-style):

from sklearn.metrics import classification_report, f1_score
from tensorflow.keras.callbacks import Callback

# Predictor is the asker's own helper class; its definition is not shown.

class Metrics(Callback):
    def __init__(self, dev_data, classifier, dataloader):
        super().__init__()  # initialize the Keras Callback base class
        self.best_f1_score = 0.0
        self.dev_data = dev_data
        self.classifier = classifier
        self.predictor = Predictor(classifier, dataloader)
        self.dataloader = dataloader

    def on_epoch_end(self, epoch, logs=None):
        print("start to evaluate....")
        # Predict on the dev set and compare against the gold labels.
        _, preds = self.predictor(self.dev_data)
        y_trues = [self.dataloader.label_vector(v["label"]) for v in self.dev_data]
        y_preds = preds
        f1 = f1_score(y_trues, y_preds, average="weighted")
        print(classification_report(y_trues, y_preds,
                                    target_names=self.dataloader.vocab.labels))
        # Save the model whenever the weighted F1 improves.
        if f1 > self.best_f1_score:
            self.best_f1_score = f1
            self.classifier.save_model()
            print("best metrics, save model...")

I obtained the following warning:

W1106 10:49:14.171694 4745115072 def_function.py:474] 6 out of the last 11 calls to .distributed_function at 0x14a3f9d90> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.


Answer 1:


This warning occurs when a tf.function is retraced because its arguments change in shape or dtype (for tensors) or even in value (for Python or NumPy objects, or variables).
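
To make the caching rule concrete, here is a toy model in plain Python. It is not TensorFlow's implementation: FakeTensor, cache_key and TracingFunction are hypothetical names invented for this sketch. It only assumes the rule stated above, namely that tensors are matched by shape and dtype while plain Python values are matched by value.

```python
# Toy model only: this is NOT TensorFlow's implementation.
from typing import Any, Callable, Dict, Hashable, Tuple


class FakeTensor:
    """Hypothetical stand-in for a tensor: only shape and dtype matter here."""

    def __init__(self, shape: Tuple[int, ...], dtype: str = "float32") -> None:
        self.shape = shape
        self.dtype = dtype


def cache_key(arg: Any) -> Hashable:
    """Key tensors by (shape, dtype) and plain Python values by their value."""
    if isinstance(arg, FakeTensor):
        return ("tensor", arg.shape, arg.dtype)
    return ("python", type(arg).__name__, arg)


class TracingFunction:
    """Stores one 'trace' per distinct argument key, so traces can be counted."""

    def __init__(self, fn: Callable[..., Any]) -> None:
        self.fn = fn
        self.traces: Dict[Hashable, Callable[..., Any]] = {}

    def __call__(self, *args: Any) -> Any:
        key = tuple(cache_key(a) for a in args)
        if key not in self.traces:
            # In real TensorFlow this is the expensive step: building a new graph.
            self.traces[key] = self.fn
        return self.traces[key](*args)


step = TracingFunction(lambda x: x)

# Same shape and dtype on every call: traced once, then reused.
for _ in range(5):
    step(FakeTensor(shape=(32, 10)))
traces_same_shape = len(step.traces)

# A different Python value on every call: one new trace per call.
for epoch in range(5):
    step(epoch)
traces_python_values = len(step.traces) - traces_same_shape

# A different batch size on every call: also one new trace per call.
for batch_size in (1, 2, 3):
    step(FakeTensor(shape=(batch_size, 10)))
traces_varying_shape = len(step.traces) - traces_same_shape - traces_python_values

print(traces_same_shape, traces_python_values, traces_varying_shape)  # prints: 1 5 3
```

In this model, passing a Python counter or batches of varying size fills the cache with a new entry on every call, which is the pattern the warning complains about.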

In the general case, the fix is to decorate the custom function that you pass to Keras or TensorFlow with @tf.function(experimental_relax_shapes=True). This option tries to detect and avoid unnecessary retracing, but it is not guaranteed to solve the issue.

In your case, I guess the Predictor class is a custom class, so place @tf.function(experimental_relax_shapes=True) before the definition of Predictor.predict().
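
A minimal sketch of that placement, assuming TensorFlow 2.x is installed. The Predictor below is a hypothetical stand-in, since the real class is not shown in the question; the constant weights matrix, the feature and class counts, and the trace_count attribute are likewise made up for illustration. As far as I know, newer TensorFlow releases call this option reduce_retracing and treat experimental_relax_shapes as deprecated, so adjust the keyword to your version.

```python
import tensorflow as tf


class Predictor:
    """Hypothetical stand-in for the asker's Predictor (the real code is not shown)."""

    def __init__(self, model):
        self.model = model
        self.trace_count = 0  # illustration only: counts how often predict() is traced

    @tf.function(experimental_relax_shapes=True)
    def predict(self, inputs):
        # A Python side effect runs only while TensorFlow traces the function,
        # not on every call, so this counter tracks the number of traces.
        self.trace_count += 1
        return tf.argmax(self.model(inputs), axis=-1)


# Assumed toy "model": 4 input features projected onto 3 classes.
weights = tf.random.normal([4, 3])
predictor = Predictor(lambda x: tf.matmul(x, weights))

# Five batches with five different batch sizes.
for batch_size in (1, 2, 3, 5, 8):
    labels = predictor.predict(tf.zeros([batch_size, 4]))

print("number of traces:", predictor.trace_count)
```

Without the relaxed-shapes option, each of the five batch sizes would be expected to trigger its own trace; with it, the count should be lower because the batch dimension is generalized after the first mismatch.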




Answer 2:


Add this line after importing tensorflow:

tf.compat.v1.disable_eager_execution()



Answer 3:


Using @tf.function(experimental_relax_shapes=True) will probably solve your problem.



Source: https://stackoverflow.com/questions/58814130/tensorflow-2-0-custom-keras-metric-caused-tf-function-retracing-warning
