Keras, append to logs from callback

孤城傲影 2020-12-28 18:04

I have a callback that computes a couple of additional metrics in on_epoch_end for validation data and every 10 epochs for test data.

I also have a

1 Answer
  • 2020-12-28 18:47

    You can insert your additional metrics into the logs dictionary that is passed to on_epoch_end.

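    import numpy as np
    from keras.callbacks import Callback
    
    class ComputeMetrics(Callback):
        def on_epoch_end(self, epoch, logs):
            # logs is shared with the other callbacks, so new keys are visible to them
            logs['val_metric'] = epoch ** 2  # replace it with your metrics
            if (epoch + 1) % 10 == 0:
                logs['test_metric'] = epoch ** 3  # same
            else:
                # pad with NaN so the key is present on every epoch
                logs['test_metric'] = np.nan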
    

    Just remember to place this callback before CSVLogger in your fit call. Callbacks run in list order, so those that appear later in the list receive the modified version of logs. For example,

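    import numpy as np
    from keras.callbacks import CSVLogger
    from keras.layers import Dense
    from keras.models import Sequential
    
    # ComputeMetrics is the callback defined above
    model = Sequential([Dense(1, input_shape=(10,))])
    model.compile(loss='mse', optimizer='adam')
    model.fit(np.random.rand(100, 10),
              np.random.rand(100),
              epochs=30,
              validation_data=(np.random.rand(100, 10), np.random.rand(100)),
              callbacks=[ComputeMetrics(), CSVLogger('1.log')])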
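
To see why the order matters, here is a toy model of the dispatch that runs without Keras. It is only a sketch: `AddMetric`, `Recorder`, and `run_epoch` are hypothetical stand-ins, not Keras classes, and the assumption is simply that one logs dict is handed to each callback in list order.

```python
# Toy model of callback dispatch (an illustration, not Keras source code).

class AddMetric:
    """Stands in for ComputeMetrics: writes an extra key into logs."""
    def on_epoch_end(self, epoch, logs):
        logs['val_metric'] = epoch ** 2


class Recorder:
    """Stands in for CSVLogger: records which keys it sees."""
    def __init__(self):
        self.seen = []

    def on_epoch_end(self, epoch, logs):
        self.seen.append(sorted(logs))


def run_epoch(callbacks, epoch):
    """Hypothetical dispatcher: one shared dict, passed to each callback in turn."""
    logs = {'loss': 0.5}
    for cb in callbacks:
        cb.on_epoch_end(epoch, logs)
    return logs


late = Recorder()
run_epoch([AddMetric(), late], epoch=3)   # recorder placed after the metric callback
early = Recorder()
run_epoch([early, AddMetric()], epoch=3)  # recorder placed before it

print(late.seen)   # [['loss', 'val_metric']]
print(early.seen)  # [['loss']]
```

The recorder placed first never sees val_metric, which is what would happen to CSVLogger if it came before ComputeMetrics in the callbacks list.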
    

    Now if you take a look at the output log file, you'll see two additional columns test_metric and val_metric:

    epoch,loss,test_metric,val_loss,val_metric
    0,0.547923130989,nan,0.370979120433,0
    1,0.525437340736,nan,0.35585285902,1
    2,0.501358469725,nan,0.341958616376,4
    3,0.479624577463,nan,0.329370084703,9
    4,0.460121934414,nan,0.317930338383,16
    5,0.440655426979,nan,0.307486981452,25
    6,0.422990380526,nan,0.298160370588,36
    7,0.406809270382,nan,0.289906248748,49
    8,0.3912438941,nan,0.282540213466,64
    9,0.377326357365,729,0.276457450986,81
    10,0.364721306562,nan,0.271435074806,100
    11,0.353612961769,nan,0.266939682364,121
    12,0.343238875866,nan,0.263228923082,144
    13,0.333940329552,nan,0.260326927304,169
    14,0.325931007862,nan,0.25773427248,196
    15,0.317790198028,nan,0.255648627281,225
    16,0.310636150837,nan,0.25411529541,256
    17,0.304091459513,nan,0.252928718328,289
    18,0.298703012466,nan,0.252127869725,324
    19,0.292693507671,6859,0.251701972485,361
    20,0.287824733257,nan,0.251610517502,400
    21,0.283586999774,nan,0.251790778637,441
    22,0.27927801609,nan,0.252100949883,484
    23,0.276239238977,nan,0.252632959485,529
    24,0.273072380424,nan,0.253150621653,576
    25,0.270296501517,nan,0.253555388451,625
    26,0.268056542277,nan,0.254015884399,676
    27,0.266158599854,nan,0.254496408701,729
    28,0.264166412354,nan,0.254723013639,784
    29,0.262506003976,24389,0.255338237286,841
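
If you only want the epochs where test_metric was actually computed, you can filter out the NaN padding when reading the log back. This is a minimal sketch using the standard library; `rows_with_test_metric` is a hypothetical helper name, and the sample rows are copied from the log above.

```python
import csv
import io
import math


def rows_with_test_metric(log_file):
    """Return (epoch, test_metric) pairs for rows whose test_metric is not NaN."""
    rows = []
    for row in csv.DictReader(log_file):
        value = float(row['test_metric'])  # float() parses the string 'nan' as NaN
        if not math.isnan(value):
            rows.append((int(row['epoch']), value))
    return rows


# Three rows copied from the log above; for a real run, pass open('1.log', newline='').
sample = io.StringIO(
    "epoch,loss,test_metric,val_loss,val_metric\n"
    "8,0.3912438941,nan,0.282540213466,64\n"
    "9,0.377326357365,729,0.276457450986,81\n"
    "10,0.364721306562,nan,0.271435074806,100\n"
)
print(rows_with_test_metric(sample))  # [(9, 729.0)]
```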
    