TensorFlow 2.0 dataset.__iter__() is only supported when eager execution is enabled

故里飘歌 2021-01-18 08:10

I'm using the following custom training code in TensorFlow 2:

def parse_function(filename, filename2):
    image = read_image(fn)
    def ret1(): return ima         


        
3 Answers
  •  臣服心动
    2021-01-18 08:39

    I fixed this by changing the train function to the following:

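    import tensorflow as tf

    # No @tf.function here: the loop runs eagerly, so iterating the dataset is allowed.
    def train(model, dataset, optimizer):
        # Each dataset element is an (x1, x2, y) tuple; enumerate() adds the step index.
        for step, (x1, x2, y) in enumerate(dataset):
            with tf.GradientTape() as tape:
                left, right = model([x1, x2])
                # contrastive_loss is defined elsewhere (not shown in the post).
                loss = contrastive_loss(left, right, tf.cast(y, tf.float32))
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))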
    

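    The two changes are removing the @tf.function decorator and fixing the enumeration so that each step unpacks as step, (x1, x2, y). Without @tf.function, the loop runs eagerly, so enumerate(dataset) can call dataset.__iter__() without raising the error.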
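Removing @tf.function avoids the error but also runs every training step eagerly. A common alternative is to keep the Python loop over the dataset eager and compile only the per-batch update with @tf.function. The sketch below shows that pattern under stated assumptions: make_siamese_model, contrastive_loss (with y = 1 meaning "similar"), and the random data are hypothetical stand-ins, since the post does not show the real model or loss.

```python
import numpy as np
import tensorflow as tf


def make_siamese_model(dim=4):
    # Hypothetical toy model: one Dense layer shared by both inputs.
    shared = tf.keras.layers.Dense(2)
    in1 = tf.keras.Input(shape=(dim,))
    in2 = tf.keras.Input(shape=(dim,))
    return tf.keras.Model([in1, in2], [shared(in1), shared(in2)])


def contrastive_loss(left, right, y, margin=1.0):
    # Hypothetical stand-in: margin-based contrastive loss, assuming
    # y = 1 for similar pairs and y = 0 for dissimilar pairs.
    # The small epsilon keeps the gradient of sqrt finite when d = 0.
    d = tf.sqrt(tf.reduce_sum(tf.square(left - right), axis=1) + 1e-9)
    similar = y * tf.square(d)
    dissimilar = (1.0 - y) * tf.square(tf.maximum(margin - d, 0.0))
    return tf.reduce_mean(similar + dissimilar)


def train(model, dataset, optimizer):
    @tf.function  # compiles only the per-batch update; no dataset iteration inside
    def train_step(x1, x2, y):
        with tf.GradientTape() as tape:
            left, right = model([x1, x2], training=True)
            loss = contrastive_loss(left, right, tf.cast(y, tf.float32))
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        return loss

    losses = []
    # This loop runs eagerly, so enumerate() may call dataset.__iter__().
    for step, (x1, x2, y) in enumerate(dataset):
        losses.append(float(train_step(x1, x2, y)))
    return losses


# Random toy data: 32 pairs of 4-dimensional vectors with 0/1 labels.
rng = np.random.default_rng(0)
x1 = rng.normal(size=(32, 4)).astype("float32")
x2 = rng.normal(size=(32, 4)).astype("float32")
y = rng.integers(0, 2, size=32).astype("int32")
dataset = tf.data.Dataset.from_tensor_slices((x1, x2, y)).batch(8)

model = make_siamese_model()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
losses = train(model, dataset, optimizer)
print(len(losses))  # 4 batches of 8 pairs
```

Defining train_step inside train keeps the sketch short, but it means the step is retraced each time train is called; in a longer program you would typically create the compiled step once and reuse it across epochs.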
