I'm using the following custom training code in TensorFlow 2:
def parse_function(filename, filename2):
    image = read_image(fn)
    def ret1(): return ima
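I fixed it by enabling eager execution right after importing TensorFlow. Note that this call is a TensorFlow 1.x idiom: in TensorFlow 2, eager execution is already on by default and the top-level tf.enable_eager_execution() has been removed (it is only available as tf.compat.v1.enable_eager_execution()):
import tensorflow as tf
# tf.compat.v1 alias of the TF 1.x call; in TF 2.x eager mode is already the default.
tf.compat.v1.enable_eager_execution()
Reference: TensorFlow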
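Since TensorFlow 2 is already eager at the top level, the code that is not running eagerly is usually a function decorated with @tf.function, whose body is traced into a graph. The sketch below, assuming TensorFlow 2.3 or later (for `tf.config.run_functions_eagerly`), checks this directly; `body_runs_eagerly` and `probe` are hypothetical helpers written only for illustration.

```python
import tensorflow as tf


def body_runs_eagerly() -> bool:
    """Report whether a freshly created tf.function executes its body eagerly."""
    observed = []

    @tf.function
    def probe(x):
        # Python side effect: runs once while tracing (graph mode), or on
        # every call when run_functions_eagerly(True) bypasses tracing.
        observed.append(tf.executing_eagerly())
        return x + 1

    probe(tf.constant(1))
    return observed[0]


# TensorFlow 2.x: eager execution is already on at the top level.
print("top level eager:", tf.executing_eagerly())

# Inside a @tf.function the body is traced into a graph, so it is not eager.
print("inside tf.function:", body_runs_eagerly())

# Debugging switch: make every tf.function run its Python body eagerly.
tf.config.run_functions_eagerly(True)
print("inside tf.function, forced eager:", body_runs_eagerly())
tf.config.run_functions_eagerly(False)  # restore normal graph tracing
```

A fresh tf.function is created on each call so that a cached trace from an earlier call cannot hide the change made by `run_functions_eagerly`.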
I fixed this by changing the train function to the following:
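def train(model, dataset, optimizer):
    # No @tf.function decorator, so this loop runs eagerly.
    # enumerate() yields (step, element); each element is an (x1, x2, y) tuple.
    for step, (x1, x2, y) in enumerate(dataset):
        with tf.GradientTape() as tape:
            left, right = model([x1, x2])
            loss = contrastive_loss(left, right, tf.cast(y, tf.float32))
        # Compute and apply gradients outside the tape context.
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))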
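The two changes are removing the @tf.function decorator and fixing the enumeration: enumerate(dataset) yields (step, element) pairs, so the loop variables have to be unpacked as step, (x1, x2, y).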
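The corrected loop can be exercised end to end with a toy setup. This is a minimal sketch, assuming a small two-input Keras model with a shared encoder and a standard margin-based contrastive loss; the question's real model and `contrastive_loss` are not shown, so `build_siamese` and this `contrastive_loss` are hypothetical stand-ins.

```python
import numpy as np
import tensorflow as tf


def contrastive_loss(left, right, y, margin=1.0):
    """Hypothetical contrastive loss: y=1 for similar pairs, y=0 for dissimilar."""
    d = tf.norm(left - right, axis=1)
    similar = y * tf.square(d)
    dissimilar = (1.0 - y) * tf.square(tf.maximum(margin - d, 0.0))
    return tf.reduce_mean(similar + dissimilar)


def build_siamese(input_dim=4, embed_dim=2):
    """Hypothetical toy model: both inputs pass through one shared Dense layer."""
    encoder = tf.keras.layers.Dense(embed_dim)
    in1 = tf.keras.Input(shape=(input_dim,))
    in2 = tf.keras.Input(shape=(input_dim,))
    return tf.keras.Model([in1, in2], [encoder(in1), encoder(in2)])


def train(model, dataset, optimizer):
    losses = []
    # Plain Python loop without @tf.function, so every step runs eagerly.
    for step, (x1, x2, y) in enumerate(dataset):
        with tf.GradientTape() as tape:
            left, right = model([x1, x2])
            loss = contrastive_loss(left, right, tf.cast(y, tf.float32))
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        losses.append(float(loss))  # eager tensors convert to Python floats
    return losses


# Synthetic pairs: 32 examples in batches of 8, labels are 0 or 1.
rng = np.random.default_rng(0)
x1 = rng.normal(size=(32, 4)).astype("float32")
x2 = rng.normal(size=(32, 4)).astype("float32")
y = rng.integers(0, 2, size=(32,))
dataset = tf.data.Dataset.from_tensor_slices((x1, x2, y)).batch(8)

losses = train(build_siamese(), dataset, tf.keras.optimizers.SGD(0.1))
print(len(losses), "steps, first loss:", losses[0])
```

Running the whole loop eagerly makes per-step values easy to inspect. If speed matters later, one option is to decorate only a per-batch step function with @tf.function and keep the Python enumerate loop outside it.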
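If you are using a Jupyter notebook, restart the kernel after adding
import tensorflow as tf
tf.compat.v1.enable_eager_execution()
and it works. In TensorFlow 1.x, eager execution must be enabled at program startup, before any graph has been built, so the call fails in a kernel that has already run TensorFlow code.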