How can I print the intermediate variables in the loss function in TensorFlow and Keras?

囚心锁ツ 2021-02-19 10:13

I'm writing a custom objective to train a Keras (with TensorFlow backend) model, but I need to debug some intermediate computations. For simplicity, let's say I have:

         


        
    谎友^ (OP)
    2021-02-19 10:45

    In TensorFlow 2, you can set IDE breakpoints inside TensorFlow Keras models, layers, and losses, including code that runs during fit, evaluate, and predict. However, for the tensor values to be available in the debugger at the breakpoint, you must set model.run_eagerly = True after calling model.compile(), because compile() resets that flag. Without eager execution, fit runs the loss inside a traced graph function, so the debugger only sees symbolic tensors. For example:

    import tensorflow as tf
    from tensorflow.keras.layers import Dense
    from tensorflow.keras.models import Model
    from tensorflow.keras.optimizers import Adam
    
    def custom_loss(y_true, y_pred):  # Keras calls loss functions as loss(y_true, y_pred).
        diff = y_pred - y_true
        return tf.square(diff)  # Breakpoint in IDE here. =====
    
    class SimpleModel(Model):
    
        def __init__(self):
            super().__init__()
            self.dense0 = Dense(2)
            self.dense1 = Dense(1)
    
        def call(self, inputs):
            z = self.dense0(inputs)
            z = self.dense1(z)
            return z
    
    x = tf.convert_to_tensor([[1, 2, 3], [4, 5, 6]], dtype=tf.float32)
    y = tf.convert_to_tensor([[0], [1]], dtype=tf.float32)  # Shape (2, 1), matching the model output.
    
    model0 = SimpleModel()
    model0.run_eagerly = True
    model0.compile(optimizer=Adam(), loss=custom_loss)  # compile() resets run_eagerly, undoing the line above.
    history0 = model0.fit(x, y, epochs=1)  # Values of diff *not* shown at breakpoint. =====
    
    model1 = SimpleModel()
    model1.compile(optimizer=Adam(), loss=custom_loss)
    model1.run_eagerly = True  # Set after compile(), so the flag is kept.
    history1 = model1.fit(x, y, epochs=1)  # Values of diff shown at breakpoint. =====
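If you only want to look at the values rather than step through them, the same eager-mode requirement also lets you call .numpy() or a plain print() inside the loss. The sketch below is not from the original answer: it assumes a TensorFlow 2.x release whose Model.compile() accepts a run_eagerly argument, and the small functional model, the toy data, and the recorded_diffs list are illustrative stand-ins for SimpleModel.

```python
import numpy as np
import tensorflow as tf

# Hypothetical list, used only to keep the concrete values for later inspection.
recorded_diffs = []

def custom_loss(y_true, y_pred):
    # Keras calls loss functions as loss(y_true, y_pred).
    diff = y_pred - y_true
    # Only eager tensors carry concrete values; the guard skips any symbolic
    # tracing pass that some Keras versions run before training starts.
    if tf.executing_eagerly():
        print("diff:", diff.numpy().ravel(), "shape:", diff.shape)
        recorded_diffs.append(diff.numpy())
    return tf.square(diff)

inputs = tf.keras.Input(shape=(3,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, outputs)

# Passing run_eagerly to compile() itself sidesteps the ordering issue,
# since compile() is the call that sets (and would otherwise reset) the flag.
model.compile(optimizer="adam", loss=custom_loss, run_eagerly=True)

x = np.array([[1, 2, 3], [4, 5, 6]], dtype="float32")
y = np.array([[0], [1]], dtype="float32")
model.fit(x, y, epochs=1, verbose=0)
```

Eager execution makes training noticeably slower, so it is worth switching it off again once the loss behaves as expected.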
    

    This also works for debugging the outputs of intermediate network layers, for example by setting a breakpoint inside SimpleModel.call.
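A different option, which the answer above does not use, is tf.print. Because it is a TensorFlow op rather than a Python call, it also runs inside the compiled training graph, so it does not need run_eagerly = True at all. A minimal sketch, where PrintingModel is a hypothetical copy of SimpleModel with one added print of the hidden activations:

```python
import numpy as np
import tensorflow as tf

class PrintingModel(tf.keras.Model):
    """Hypothetical variant of SimpleModel that logs its hidden activations."""

    def __init__(self):
        super().__init__()
        self.dense0 = tf.keras.layers.Dense(2)
        self.dense1 = tf.keras.layers.Dense(1)

    def call(self, inputs):
        z = self.dense0(inputs)
        # tf.print is executed as part of the graph on every training step,
        # even though run_eagerly is left at its default below.
        tf.print("hidden:", z, "shape:", tf.shape(z))
        return self.dense1(z)

model = PrintingModel()
model.compile(optimizer="adam", loss="mse")  # run_eagerly not set
x = np.array([[1, 2, 3], [4, 5, 6]], dtype="float32")
y = np.array([[0], [1]], dtype="float32")
history = model.fit(x, y, epochs=1, verbose=0)
```

By default tf.print writes to standard error; passing output_stream=sys.stdout sends it to standard output instead. The same call works inside a custom loss.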

    Note: this was tested in TensorFlow 2.0.0-rc0.
