Keras custom loss function: Accessing current input pattern

孤独总比滥情好 2020-11-30 02:32

In Keras (with the TensorFlow backend), is the current input pattern available to my custom loss function?

The current input pattern is defined as the input vector used

1 answer
  • 2020-11-30 03:02

    You can wrap the loss function as an inner function and pass your input tensor to it (as is commonly done when passing additional arguments to a loss function).

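    # Assumes standalone Keras 2.x with a graph-mode (TF 1.x-style) TensorFlow backend
    from keras import backend as K
    from keras.layers import Input, Dense
    from keras.models import Model
    
    def custom_loss_wrapper(input_tensor):
        # The closure captures input_tensor so the inner loss can use the current batch
        def custom_loss(y_true, y_pred):
            # Per-sample binary cross-entropy plus the mean of the current input batch
            return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
        return custom_loss
    
    input_tensor = Input(shape=(10,))
    hidden = Dense(100, activation='relu')(input_tensor)
    out = Dense(1, activation='sigmoid')(hidden)
    model = Model(input_tensor, out)
    # Calling the wrapper returns the (y_true, y_pred) loss function Keras expects
    model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')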
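To see where the loss values in the verification below come from, here is a NumPy restatement of what custom_loss evaluates for one batch: the batch-averaged binary cross-entropy plus the mean of the raw input batch. This is an illustrative sketch, not Keras code; the helper name wrapped_loss_numpy is hypothetical, and the clipping epsilon assumes the Keras default of 1e-7.

```python
import numpy as np


def wrapped_loss_numpy(x_batch, y_true, y_pred, eps=1e-7):
    """Hypothetical NumPy restatement of custom_loss, averaged over one batch.

    eps assumes the Keras default K.epsilon() used to clip predictions.
    """
    y_true = np.asarray(y_true, dtype=float).reshape(-1)
    y_pred = np.clip(np.asarray(y_pred, dtype=float).reshape(-1), eps, 1 - eps)
    # Per-sample binary cross-entropy
    bce = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    # The per-sample loss is averaged; the input term is one batch-wide scalar
    return float(np.mean(bce) + np.mean(x_batch))


# A prediction of 0.5 gives a per-sample BCE of ln 2 (about 0.693)
x = np.full((4, 10), 0.5)
y = np.array([0, 1, 0, 1])
p = np.full(4, 0.5)
print(wrapped_loss_numpy(x, y, p))         # ln 2 + 0.5
print(wrapped_loss_numpy(x * 1000, y, p))  # ln 2 + 500
```

With np.random.rand inputs, mean(X) is about 0.5, and an untrained network predicting roughly 0.5 contributes about ln 2 ≈ 0.69, which is in line with the ≈1.2 reported below; multiplying X by 1000 pushes the input term alone to roughly 500.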
    

    You can verify that both input_tensor and the loss value (mainly the K.mean(input_tensor) term) change when a different X is passed to the model.

    import numpy as np
    
    X = np.random.rand(1000, 10)
    y = np.random.randint(2, size=1000)
    model.test_on_batch(X, y)  # => 1.1974642
    
    X *= 1000
    model.test_on_batch(X, y)  # => 511.15466
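The closure above captures a symbolic Input tensor, which relies on graph-mode Keras; under TensorFlow 2.x eager training, feeding a symbolic Keras input into a loss function may raise an error. One alternative, not part of the original answer, is to register the input-dependent term with Layer.add_loss from inside a layer, so Keras computes it on each actual batch and adds it to the compiled loss. A minimal sketch, assuming TensorFlow 2.x with its bundled keras; the layer name InputMeanPenalty is hypothetical.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


class InputMeanPenalty(keras.layers.Layer):
    """Hypothetical pass-through layer that adds mean(inputs) to the model's loss."""

    def call(self, inputs):
        # add_loss registers an extra scalar that Keras adds to the compiled loss
        # in fit/evaluate/test_on_batch, computed from the actual input batch
        self.add_loss(tf.reduce_mean(inputs))
        return inputs


inputs = keras.Input(shape=(10,))
x = InputMeanPenalty()(inputs)
hidden = keras.layers.Dense(100, activation="relu")(x)
out = keras.layers.Dense(1, activation="sigmoid")(hidden)
model = keras.Model(inputs, out)
model.compile(loss="binary_crossentropy", optimizer="adam")

X = np.random.rand(1000, 10).astype("float32")
y = np.random.randint(2, size=(1000, 1)).astype("float32")

# np.ravel(...)[0] handles both a scalar and a [loss, ...] list return value
loss_small = float(np.ravel(model.test_on_batch(X, y))[0])
loss_large = float(np.ravel(model.test_on_batch(X * 1000, y))[0])
print(loss_small, loss_large)
```

Since the extra term sits on top of loss="binary_crossentropy", the total has the same form as the wrapper (batch-mean BCE plus mean of the input batch), so scaling X by 1000 should again raise the reported loss by roughly 500.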
    