Applying callbacks in a custom training loop in TensorFlow 2.0

伪装坚强ぢ 2021-01-13 05:00

I'm writing a custom training loop using the code provided in the TensorFlow DCGAN implementation guide. I wanted to add callbacks in the training loop. In Keras I know we

3 answers
  •  情话喂你
    2021-01-13 05:21

    The simplest approach is to check whether the loss has changed over a chosen window of epochs and, if it has not, to break out of the training loop or otherwise adjust it. Here is one way to implement a custom early-stopping check:

    import numpy as np

    def Callback_EarlyStopping(LossList, min_delta=0.1, patience=20):
        # No early stopping until at least 2*patience epochs have been recorded
        if len(LossList) // patience < 2:
            return False
        # Mean loss over the last `patience` epochs and over the `patience` epochs before them
        mean_previous = np.mean(LossList[::-1][patience:2*patience])  # second-last window
        mean_recent = np.mean(LossList[::-1][:patience])  # last window
        # You can use the relative or the absolute change; the relative change is used here
        delta_abs = np.abs(mean_recent - mean_previous)  # absolute change
        delta_rel = np.abs(delta_abs / mean_previous)  # relative change (assumes mean_previous != 0)
        if delta_rel < min_delta:
            print("*CB_ES* Loss didn't change much over the last %d epochs" % patience)
            print("*CB_ES* Percent change in loss value:", delta_rel * 1e2)
            return True
        else:
            return False
    

    Callback_EarlyStopping is called with the loss history after every epoch. Once at least 2*patience epochs have been recorded, it compares the mean loss over the last patience epochs with the mean over the patience epochs before them, and returns True if the relative change is smaller than min_delta. You can capture this True signal and break out of the training loop. Within your sample training loop, you can use it like this:

    gen_loss_seq = []
    for epoch in range(epochs):
      # in your example, make sure your train_step returns gen_loss
      gen_loss = train_step(dataset)
      # ideally, you would have a validation_step and track gen_valid_loss instead
      gen_loss_seq.append(gen_loss)
      # checked every epoch; once 2*patience = 40 losses are recorded,
      # stop if the mean loss changed by less than 10% between the last two windows
      stopEarly = Callback_EarlyStopping(gen_loss_seq, min_delta=0.1, patience=20)
      if stopEarly:
        print("Callback_EarlyStopping signal received at epoch= %d/%d" % (epoch, epochs))
        print("Terminating training ")
        break
           
    

    Of course, you can extend this in many ways: which loss or metric to track, whether to look at the loss at a particular epoch or at a moving average of it, whether to use the relative or the absolute change, and so on. For reference, see the TensorFlow 2.x implementation of tf.keras.callbacks.EarlyStopping, which is generally used with the popular tf.keras.Model.fit method.
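
    As one illustration of those variations, here is a minimal sketch of a stopper that tracks the best loss seen so far and counts epochs without improvement, rather than comparing two window means. It only borrows the general patience idea and is not the Keras implementation. The class name PatienceEarlyStopper, its method should_stop, and the synthetic loss values are all hypothetical and chosen for illustration.

```python
class PatienceEarlyStopper:
    """Signal a stop when the loss has not improved by at least
    `min_delta` (absolute change) for `patience` consecutive epochs.
    Hypothetical helper, not part of TensorFlow or Keras."""

    def __init__(self, patience=20, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")  # best (lowest) loss seen so far
        self.wait = 0             # epochs since the last improvement

    def should_stop(self, loss):
        loss = float(loss)  # accepts anything float() can convert
        if loss < self.best - self.min_delta:
            # Improvement: remember it and reset the counter
            self.best = loss
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


# Synthetic per-epoch losses standing in for the values train_step would return
losses = [1.0, 0.8, 0.7, 0.7, 0.71, 0.69, 0.7, 0.7]
stopper = PatienceEarlyStopper(patience=3, min_delta=0.05)
stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper.should_stop(loss):
        stopped_at = epoch
        break
```

    In a real loop you would pass the value returned by train_step (or a validation loss) to should_stop after each epoch and break when it returns True, exactly as with Callback_EarlyStopping above.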
