TensorFlow 2.0: turn off tf.function retracing for prediction

Submitted by 冷眼眸甩不掉的悲伤 on 2021-01-29 08:11:21

Question


I am trying to generate prediction intervals for a simple RNN using dropout. I'm using the functional API and pass training=True so that dropout is also applied at test time.
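For reference, here is a minimal sketch of that setup, assuming tf.keras, a toy SimpleRNN model and random input data. The shapes, layer sizes, the 100 samples and the 5th/95th percentiles are illustrative choices, not taken from the original post. Dropout stays active because training=True is passed when the Dropout layer is called, and an empirical 90% interval is read off the spread of repeated stochastic predictions:

```python
import numpy as np
import tensorflow as tf

# Toy data (assumed shapes): 32 sequences, 10 time steps, 1 feature.
x = np.random.rand(32, 10, 1).astype("float32")

# Functional model; training=True keeps this Dropout layer active at inference.
inputs = tf.keras.Input(shape=(10, 1))
h = tf.keras.layers.SimpleRNN(16)(inputs)
h = tf.keras.layers.Dropout(0.3)(h, training=True)
outputs = tf.keras.layers.Dense(1)(h)
model = tf.keras.Model(inputs, outputs)

# T stochastic forward passes; each call samples a fresh dropout mask.
T = 100
samples = np.stack([model(x).numpy() for _ in range(T)])  # shape (T, 32, 1)

# Empirical 90% interval per input from the sample percentiles.
lower, upper = np.percentile(samples, [5, 95], axis=0)
mean = samples.mean(axis=0)
```

The width of `upper - lower` then serves as a rough uncertainty estimate for each input under this assumed dropout rate.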

To try different dropout levels, I defined a small function that edits the model config and rebuilds the model:

from tensorflow.keras.models import Model

def dropout_model(model, dropout):
    # Rebuild `model` with a new dropout rate, keeping its trained weights.
    conf = model.get_config()
    for layer in conf['layers']:
        if layer["class_name"] == "Dropout":
            layer["config"]["rate"] = dropout
        elif "dropout" in layer["config"]:
            # Recurrent layers (e.g. SimpleRNN, LSTM) expose dropout in their config.
            layer["config"]["dropout"] = dropout
            layer["config"]["recurrent_dropout"] = 0

    # from_config creates a brand-new model with new layer objects.
    model_new = Model.from_config(conf)
    model_new.set_weights(model.get_weights())

    return model_new
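Every call to dropout_model builds a new Model, and each new model traces its own prediction function the first time predict() is called on it. One possible way around that, which is not from the original post, is to keep a single model and store the rate in a non-trainable tf.Variable that is read at run time. The sketch assumes tf.keras; the layer name VariableRateDropout is hypothetical, and it only replaces standalone Dropout layers, not the dropout argument of recurrent layers:

```python
import numpy as np
import tensorflow as tf


class VariableRateDropout(tf.keras.layers.Layer):
    """Hypothetical always-on dropout layer whose rate is a tf.Variable."""

    def __init__(self, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        # Non-trainable variable: its value can change without rebuilding the model.
        self.rate = tf.Variable(rate, trainable=False, dtype=tf.float32)

    def call(self, inputs):
        # Dropout is applied on every call (including predict), and the rate is
        # read from the variable at run time instead of being a Python constant.
        return tf.nn.dropout(inputs, rate=tf.identity(self.rate))

    def get_config(self):
        config = super().get_config()
        config["rate"] = float(self.rate.numpy())
        return config


# Usage: one model, several rates, no rebuilding.
inputs = tf.keras.Input(shape=(4,))
drop = VariableRateDropout(0.0)
model = tf.keras.Model(inputs, drop(inputs))
x = np.ones((8, 4), dtype="float32")

drop.rate.assign(0.0)
out_rate_0 = model.predict(x, verbose=0)     # rate 0: inputs pass through unchanged
drop.rate.assign(0.5)
out_rate_half = model.predict(x, verbose=0)  # kept values are scaled by 1 / (1 - 0.5)
```

Because the traced predict function reads the variable on every run, changing the rate with assign() should not by itself cause retracing. If such a model is saved and reloaded, the custom layer would have to be supplied through custom_objects.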

However, when I apply different dropout levels repeatedly, e.g. in a loop like

for i in range(1, 8):
  dropout = i/10
  model_new = dropout_model(model, dropout)
  print(model_new.predict(input_data))
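A workaround that is often suggested for this warning, hedged here because it is not from the original post, is to call the model directly instead of going through predict(). In TF 2.x eager mode a direct call runs the layers eagerly, so it does not build a new tf.function-wrapped predict function for every freshly built model. In the sketch, build_model is a hypothetical stand-in for dropout_model, and the input shape is assumed:

```python
import numpy as np
import tensorflow as tf


def build_model(rate):
    """Hypothetical stand-in for dropout_model(): builds a fresh model per rate."""
    inputs = tf.keras.Input(shape=(10, 1))
    h = tf.keras.layers.SimpleRNN(8)(inputs)
    h = tf.keras.layers.Dropout(rate)(h, training=True)
    return tf.keras.Model(inputs, tf.keras.layers.Dense(1)(h))


input_data = np.random.rand(4, 10, 1).astype("float32")

results = {}
for i in range(1, 8):
    dropout = i / 10
    model_new = build_model(dropout)
    # Direct call instead of model_new.predict(): runs eagerly, no per-model tracing.
    results[dropout] = model_new(input_data).numpy()
```

Unlike predict(), a direct call processes input_data as a single batch, so for large inputs you may need to split it into batches yourself.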

I get this warning message:

WARNING:tensorflow:9 out of the last 28 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7f5a8008cd90> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
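Cause (1) in the warning can be reproduced in isolation. The sketch is an illustration, not a diagnosis of the model code above: it counts traces with a Python side effect, which runs only while tf.function traces the function and not on later calls that reuse the traced graph. The function name double is arbitrary:

```python
import tensorflow as tf

traces = []  # one entry is appended per trace


def double(x):
    traces.append(1)  # Python side effect: executes only during tracing
    return x * 2


x = tf.constant([1.0, 2.0])

# Cause (1): a new tf.function object per iteration, so every call traces again.
for _ in range(3):
    tf.function(double)(x)
traces_in_loop = len(traces)

# Fix for (1): create the tf.function once, outside the loop, and reuse it.
traces.clear()
double_fn = tf.function(double)
for _ in range(3):
    double_fn(x)
traces_outside_loop = len(traces)

print(traces_in_loop, traces_outside_loop)  # 3 1
```

Rebuilding a model with Model.from_config in a loop and calling predict() on each new model appears to be the Keras-level analogue of the first loop, since every new model creates its own predict function.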

I have already tried @tf.function, @tf.function(experimental_relax_shapes=True), and tf.compat.v1.disable_eager_execution(), as suggested in answers to related questions, but none of them helped.

Source: https://stackoverflow.com/questions/65021750/tensorflow-2-0-turn-off-tf-function-retracing-for-prediction
