Question
I am trying to generate prediction intervals for a simple RNN using dropout. I'm using the functional API with training=True
to enable dropout during testing.
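For context, the model is built roughly like this (a minimal sketch with placeholder layer sizes, not my exact architecture):

from tensorflow import keras
from tensorflow.keras import layers

# Functional API: calling the layers with training=True keeps dropout
# active at prediction time (Monte Carlo dropout).
inputs = keras.Input(shape=(None, 1))
x = layers.SimpleRNN(32, dropout=0.2)(inputs, training=True)
x = layers.Dropout(0.2)(x, training=True)
outputs = layers.Dense(1)(x)
model = keras.Model(inputs, outputs)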
To try different dropout levels, I defined a small function to edit the model configs:
from keras.models import Model, Sequential

def dropout_model(model, dropout):
    # Rebuild the model from its config with the requested dropout rate,
    # then copy the trained weights over.
    conf = model.get_config()
    for layer in conf["layers"]:
        if layer["class_name"] == "Dropout":
            layer["config"]["rate"] = dropout
        elif "dropout" in layer["config"].keys():
            layer["config"]["dropout"] = dropout
            layer["config"]["recurrent_dropout"] = 0
    model_new = Model.from_config(conf)
    model_new.set_weights(model.get_weights())
    return model_new
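As a sanity check (purely illustrative), inspecting the cloned model's config confirms the new rate is in place while the weights come from the original model:

model_new = dropout_model(model, 0.5)
for layer in model_new.get_config()["layers"]:
    cfg = layer["config"]
    if layer["class_name"] == "Dropout":
        print(layer["class_name"], cfg["rate"])
    elif "dropout" in cfg:
        print(layer["class_name"], cfg["dropout"])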
However, when I apply different dropout levels several times in a row, e.g. with a loop like

for i in range(1, 8):
    dropout = i / 10
    model_new = dropout_model(model, dropout)
    print(model_new.predict(input_data))
I get this warning message:
WARNING:tensorflow:9 out of the last 28 calls to <function Model.make_predict_function..predict_function at 0x7f5a8008cd90> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
I have already tried @tf.function, @tf.function(experimental_relax_shapes=True), and tf.compat.v1.disable_eager_execution(), as suggested for related questions, without success.
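Roughly what I tried (a sketch only; the exact placement of the decorator may differ from the related answers):

import tensorflow as tf

# Attempt 1: wrap the forward pass in a tf.function with relaxed shapes,
# hoping to avoid retracing on shape changes (did not help).
@tf.function(experimental_relax_shapes=True)
def predict_fn(x):
    return model_new(x)

# Attempt 2: disable eager execution entirely (did not help either).
# tf.compat.v1.disable_eager_execution()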
Source: https://stackoverflow.com/questions/65021750/tensorflow-2-0-turn-off-tf-function-retracing-for-prediction