TensorFlow 2.0: custom Keras metric causes tf.function retracing warning

借酒劲吻你 2021-01-21 01:00

When I use the following custom metric (keras-style):

from sklearn.metrics import classification_report, f1_score
from ten         


        
3 Answers
  • 2021-01-21 01:40

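    Add this line right after importing TensorFlow, before any other TensorFlow call. It disables eager execution globally, so the program runs in TF 1.x-style graph mode: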

    tf.compat.v1.disable_eager_execution()
    
  • 2021-01-21 01:51

    Decorating the function with @tf.function(experimental_relax_shapes=True) will probably solve your problem.

  • 2021-01-21 01:53

    This warning occurs when a tf.function is retraced because its arguments change in shape or dtype (for Tensors), or even just in value (for Python or NumPy objects, or variables).
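To make the caching rule above concrete, here is a toy model in plain Python. It is only a sketch and not TensorFlow's implementation: `FakeTensor`, `cache_key` and `TracedFunction` are hypothetical names invented for this illustration. Tensor-like arguments are keyed by shape and dtype, while plain Python arguments are keyed by value.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a tensor: only shape and dtype matter here."""
    shape: tuple
    dtype: str


def cache_key(arg):
    """Build the key that decides whether an existing trace can be reused."""
    if isinstance(arg, FakeTensor):
        # Tensors are keyed by shape and dtype, never by their contents.
        return ("tensor", arg.shape, arg.dtype)
    # Plain Python objects are keyed by value, so each new value is a new key.
    return ("python", arg)


class TracedFunction:
    """Toy trace cache: one entry per distinct argument key."""

    def __init__(self):
        self.traces = {}  # key -> trace number

    def __call__(self, arg):
        key = cache_key(arg)
        if key not in self.traces:
            # In this toy model, a cache miss is where a (re)trace would happen.
            self.traces[key] = len(self.traces) + 1
        return self.traces[key]


fn = TracedFunction()
fn(FakeTensor((32, 10), "float32"))  # first call: trace 1
fn(FakeTensor((32, 10), "float32"))  # same shape and dtype: trace reused
fn(FakeTensor((17, 10), "float32"))  # new shape: retrace
fn(FakeTensor((17, 10), "int32"))    # new dtype: retrace
fn(1)                                # Python value: retrace
fn(2)                                # different Python value: retrace again
print(len(fn.traces))  # 5
```

In this model, passing Python numbers that change on every call (for example a step counter) produces one trace per value, which is the usual way the warning gets triggered in a loop.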

    In the general case, the fix is to use @tf.function(experimental_relax_shapes=True) before the definition of the custom function that you pass to Keras or TF somewhere. This tries to detect and avoid unnecessary retracing, but is not guaranteed to solve the issue.

    In your case, I guess the Predictor class is a custom class, so place @tf.function(experimental_relax_shapes=True) directly above the definition of Predictor.predict().
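A minimal sketch of that suggestion follows. The question's own Predictor code is not shown, so the class below is a hypothetical stand-in, and the `trace_count` counter exists only to make tracing visible. It assumes a TensorFlow 2.x release that accepts `experimental_relax_shapes`; from TensorFlow 2.9 on, the same option is exposed as `reduce_retracing`.

```python
import tensorflow as tf

trace_count = 0  # incremented only while TensorFlow traces the Python body


class Predictor:
    """Hypothetical stand-in for the custom Predictor class."""

    @tf.function(experimental_relax_shapes=True)
    def predict(self, x):
        global trace_count
        trace_count += 1  # Python side effect: runs during tracing only
        return tf.reduce_sum(x, axis=-1)


predictor = Predictor()
for batch_size in (1, 2, 5):
    # The batch dimension changes on every call, which can trigger retracing.
    out = predictor.predict(tf.ones([batch_size, 3]))

print("traces:", trace_count)
```

If the number of traces printed is lower than the number of distinct input shapes, the relaxed-shape trace is being reused. As noted above, this is not guaranteed to remove every retrace.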
