TensorFlow 2.0 dataset.__iter__() is only supported when eager execution is enabled

故里飘歌 2021-01-18 08:10

I'm using the following custom training code in TensorFlow 2:

def parse_function(filename, filename2):
    image = read_image(fn)
    def ret1(): return ima ...

3 Answers
  • 2021-01-18 08:30

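    I fixed it by enabling eager execution right after importing TensorFlow:

    import tensorflow as tf

    # TF 1.x API. TF 2.x already runs eagerly by default and only exposes
    # this call as tf.compat.v1.enable_eager_execution().
    tf.enable_eager_execution()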
    

    Reference: TensorFlow
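The error message is about eager mode, and `tf.executing_eagerly()` is one way to check which mode a piece of code runs in. The sketch below assumes TensorFlow 2.x: top-level code is already eager, while a body traced by `tf.function` runs in graph mode, which is the setting where a Python-level `iter()` call on a dataset (such as the one `enumerate(dataset)` makes) can raise this kind of error. The function name `traced` and the `modes` list are illustrative only.

```python
import tensorflow as tf

# Top-level TF 2.x code runs eagerly by default.
print(tf.executing_eagerly())  # True

modes = []

@tf.function
def traced():
    # Python side effects run only while tf.function traces the body into
    # a graph, so this records the execution mode seen during tracing.
    modes.append(tf.executing_eagerly())
    return tf.constant(0)

traced()
print(modes)
```

Every value recorded in `modes` should be `False`, because tracing happens in graph mode even though the surrounding program is eager.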

  • 2021-01-18 08:39

    I fixed this by changing the train function to the following:

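    import tensorflow as tf

    # No @tf.function decorator: the function runs eagerly, so iterating
    # over the tf.data.Dataset with a Python for-loop is allowed.
    def train(model, dataset, optimizer):
        for step, (x1, x2, y) in enumerate(dataset):
            with tf.GradientTape() as tape:
                # The model takes both inputs and returns one output per input
                left, right = model([x1, x2])
                # contrastive_loss: the asker's own loss function (not shown)
                loss = contrastive_loss(left, right, tf.cast(y, tf.float32))
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))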
    

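    The two changes are removing the @tf.function decorator, so the loop runs eagerly, and iterating with enumerate(dataset) so that step and the (x1, x2, y) tuple unpack correctly.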
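Removing `@tf.function` entirely gives up graph compilation. A common TF 2.x pattern that keeps some of it is to leave the loop over the dataset in eager Python and decorate only the per-batch step. This is a minimal sketch, assuming TF 2.x; the one-layer model, random data, and mean-squared-error loss are hypothetical stand-ins for the asker's model and `contrastive_loss`, which are not shown in the thread.

```python
import tensorflow as tf

# Hypothetical toy setup standing in for the asker's model and loss.
model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
loss_fn = tf.keras.losses.MeanSquaredError()

features = tf.random.normal([64, 3])
labels = tf.random.normal([64, 1])
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(16)

@tf.function  # only the per-batch step is traced into a graph
def train_step(x, y):
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

def train(dataset):
    losses = []
    # This loop runs eagerly, so enumerate(dataset) is allowed here.
    for step, (x, y) in enumerate(dataset):
        losses.append(float(train_step(x, y)))
    return losses

losses = train(dataset)
print(len(losses))  # 4 (64 examples / batch size 16)
```

Because every batch has the same shape and dtype, later calls to `train_step` reuse the traced graph instead of retracing.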

  • 2021-01-18 08:39

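    If you are using a Jupyter notebook, then after adding

    import tensorflow as tf

    tf.enable_eager_execution()

    you need to restart the kernel for it to work: eager execution can only be enabled at program startup, before any TensorFlow graph has been built.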
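A version-tolerant first notebook cell might look like the sketch below, assuming TF 1.14+ or TF 2.x (where the `tf.compat.v1` namespace exists). It only switches modes when the session is not already eager, so the same cell works on both major versions; on TF 1.x it still has to run before any other TensorFlow code, hence the kernel restart.

```python
import tensorflow as tf

# Hypothetical first notebook cell: run it before any other TensorFlow code.
if not tf.executing_eagerly():
    # Only reached on TF 1.x graph-mode setups; TF 2.x is already eager.
    tf.compat.v1.enable_eager_execution()

print(tf.executing_eagerly())  # True
```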
