TensorFlow 2.0: custom Keras metric causes tf.function retracing warning

借酒劲吻你 2021-01-21 01:00

When I use the following custom metric (Keras-style):

from sklearn.metrics import classification_report, f1_score
from ten         


        
3 answers
  •  小鲜肉 (OP)
    2021-01-21 01:53

    This warning appears when a tf.function is retraced because its arguments change: in shape or dtype for tensors, or even in value for Python objects, NumPy objects or variables.
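    A minimal sketch of these tracing rules, assuming TensorFlow 2.x running eagerly. The function `square` and the list `traces` are illustrative names, not code from the question. A Python side effect inside a tf.function runs only while the function is being traced, so appending to a list counts traces rather than calls.

```python
import tensorflow as tf

traces = []  # illustrative: grows by one entry each time `square` is traced

@tf.function
def square(x):
    traces.append(1)  # Python side effect: runs during tracing, not on every call
    return tf.square(x)

square(tf.constant([1.0, 2.0]))       # first call: traces for shape [2], float32
square(tf.constant([3.0, 4.0]))       # same shape and dtype: reuses that trace
square(tf.constant([1.0, 2.0, 3.0]))  # new shape [3]: retraces
square(tf.constant([1, 2]))           # new dtype int32: retraces
square(2.0)                           # Python float: traced by its value
square(3.0)                           # different Python value: retraces again

print(f"{len(traces)} traces for 6 calls")
```

    Only the second call reuses an existing graph; every other call changes the shape, dtype or Python value of the argument and triggers a new trace.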

    In the general case, the fix is to decorate the custom function that you pass to Keras or TF with @tf.function(experimental_relax_shapes=True). This lets tf.function trace a more general graph (for example, one with unknown dimensions) when input shapes change, which avoids some unnecessary retracing, but it is not guaranteed to solve the issue.
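    The effect of the flag can be checked with a sketch like the following, assuming TensorFlow 2.x; `make_counted_fn` and `traces` are hypothetical helpers used only to count traces. In later TensorFlow 2.x releases the argument was renamed `reduce_retracing`, with `experimental_relax_shapes` kept as a deprecated alias.

```python
import tensorflow as tf

def make_counted_fn(relax):
    """Hypothetical helper: returns a tf.function plus a list that records its traces."""
    traces = []

    @tf.function(experimental_relax_shapes=relax)
    def total(x):
        traces.append(1)  # Python side effect: runs once per trace, not per call
        return tf.reduce_sum(x)

    return total, traces

strict_fn, strict_traces = make_counted_fn(relax=False)
relaxed_fn, relaxed_traces = make_counted_fn(relax=True)

for n in (2, 3, 4, 5):
    x = tf.ones([n])  # same dtype, a new length on every call
    strict_fn(x)
    relaxed_fn(x)

print("strict:", len(strict_traces), "relaxed:", len(relaxed_traces))
```

    The strict function should trace once per vector length, while the relaxed one retraces once more with a generalized shape and then reuses that graph. The flag targets changing tensor shapes; it is not expected to help when the changing argument is a Python value.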

    In your case, I guess Predictor is a custom class, so place @tf.function(experimental_relax_shapes=True) before the definition of Predictor.predict().
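    Applied to the asker's code, this could look like the sketch below. The question's `Predictor` class is not shown, so the class here is a hypothetical stand-in: its constructor, arguments, fixed weights and `traces` attribute are assumptions made for illustration. Only the placement of the decorator on `predict` reflects the answer.

```python
import tensorflow as tf

class Predictor:
    """Hypothetical stand-in for the asker's Predictor class (not shown in the question)."""

    def __init__(self, n_features, n_classes):
        # Assumption: fixed weights keep the sketch deterministic; a real class would wrap a trained model.
        self.w = tf.Variable(tf.ones([n_features, n_classes]))
        self.traces = []  # appended to only while predict() is being traced

    @tf.function(experimental_relax_shapes=True)
    def predict(self, x):
        self.traces.append(x.shape)  # Python side effect: records the traced input shape
        return tf.nn.softmax(tf.matmul(x, self.w), axis=-1)

p = Predictor(n_features=4, n_classes=3)
for batch_size in (2, 5, 7):
    # Varying batch sizes would retrace on every call without the relaxed-shape flag.
    probs = p.predict(tf.ones([batch_size, 4]))
    print(batch_size, probs.shape)

print("traced shapes:", p.traces)
```

    With relaxed shapes, the second batch size should cause one retrace with the batch dimension generalized to None, and later batch sizes reuse that graph instead of retracing.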
