Applying callbacks in a custom training loop in Tensorflow 2.0

Asked 2021-01-13 05:00 by 伪装坚强ぢ · 3 answers · 461 views

I'm writing a custom training loop using the code provided in the Tensorflow DCGAN implementation guide. I wanted to add callbacks in the training loop. In Keras I know we …

3 Answers
  • 2021-01-13 05:21

    The simplest way would be to check whether the loss has changed over your expected period, and break or otherwise manipulate the training process if it hasn't. Here is one way you could implement a custom early-stopping callback:

    import numpy as np

    def Callback_EarlyStopping(LossList, min_delta=0.1, patience=20):
        # No early stopping for the first 2*patience epochs
        if len(LossList) // patience < 2:
            return False
        # Mean loss over the last `patience` epochs and the `patience` epochs before that
        mean_previous = np.mean(LossList[::-1][patience:2*patience])  # second-last block
        mean_recent = np.mean(LossList[::-1][:patience])              # last block
        # You can use either the relative or the absolute change
        delta_abs = np.abs(mean_recent - mean_previous)   # absolute change
        delta_rel = delta_abs / np.abs(mean_previous)     # relative change
        if delta_rel < min_delta:
            print("*CB_ES* Loss didn't change much over the last %d epochs" % patience)
            print("*CB_ES* Percent change in loss value: %.2f%%" % (delta_rel * 100))
            return True
        return False
    

    This Callback_EarlyStopping computes the moving average of the loss over each block of patience epochs, checks it every epoch, and returns True if the relative change between the two most recent averages is smaller than min_delta. You can then capture this True signal and break the training loop. To completely answer your question, within your sample training loop you can use it as:

    gen_loss_seq = []
    for epoch in range(epochs):
      # in your example, make sure your train_step returns gen_loss
      gen_loss = train_step(dataset)
      # ideally, you would also have a validation_step and track gen_valid_loss
      gen_loss_seq.append(gen_loss)
      # check every 20 epochs and stop if the loss hasn't changed by 10%
      stopEarly = Callback_EarlyStopping(gen_loss_seq, min_delta=0.1, patience=20)
      if stopEarly:
        print("Callback_EarlyStopping signal received at epoch = %d/%d" % (epoch, epochs))
        print("Terminating training")
        break
           
    

    Of course, you can increase the complexity in numerous ways: which loss or metric you would like to track, whether you care about the loss at a particular epoch or a moving average, whether you care about relative or absolute change, etc. For reference, see the Tensorflow 2.x implementation of tf.keras.callbacks.EarlyStopping, which is generally used with the popular tf.keras.Model.fit method.
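    For comparison, if you later switch back to Model.fit, the built-in callback implements the same idea; a minimal sketch (the monitored quantity and thresholds here are illustrative):

    ```python
    import tensorflow as tf

    # Built-in equivalent of the manual check above: stop when val_loss
    # improves by less than min_delta for `patience` consecutive epochs.
    early_stop = tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",
        min_delta=0.1,
        patience=20,
        restore_best_weights=True,
    )

    # model.fit(train_ds, validation_data=valid_ds, epochs=200,
    #           callbacks=[early_stop])
    ```

    restore_best_weights=True additionally rolls the model back to the epoch with the best monitored value, which the manual loop above does not do.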

  • 2021-01-13 05:23

    I think you would need to implement the functionality of the callback manually. It should not be too difficult. You could, for instance, have the train_step function return the losses and then implement the functionality of callbacks such as early stopping in your train function. For callbacks such as a learning rate schedule, the function tf.keras.backend.set_value(generator_optimizer.lr, new_lr) comes in handy. Either way, the functionality of the callback would be implemented in your train function.
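    As a sketch of that idea (the helper name and the step-decay schedule here are made up, not part of the original answer), a manual learning rate schedule for the DCGAN guide's generator optimizer could look like:

    ```python
    import tensorflow as tf

    # Optimizer as in the DCGAN guide
    generator_optimizer = tf.keras.optimizers.Adam(1e-4)

    def schedule_lr(optimizer, epoch, initial_lr=1e-4, decay=0.5, every=10):
        """Hypothetical step schedule: halve the learning rate every `every` epochs."""
        new_lr = initial_lr * decay ** (epoch // every)
        tf.keras.backend.set_value(optimizer.learning_rate, new_lr)
        return new_lr

    # Called once per epoch inside the custom training loop:
    # for epoch in range(epochs):
    #     schedule_lr(generator_optimizer, epoch)
    #     ...
    ```

    In recent versions you can also write optimizer.learning_rate.assign(new_lr) directly, since the learning rate is stored as a variable.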

  • 2021-01-13 05:36

    A custom training loop is just a normal Python loop, so you can use if statements to break the loop whenever some condition is met. For instance:

    if len(loss_history) > patience:
        if loss_history.popleft()*(1 - delta) < min(loss_history):
            print(f'\nEarly stopping. No improvement of more than {delta:.5%} in '
                  f'validation loss in the last {patience} epochs.')
            break
    

    If the loss has not improved by more than delta (as a fraction of the oldest recorded loss) in the past patience epochs, the loop is broken. Here I'm using a collections.deque, which works as a rolling list that only keeps in memory the information from the last patience + 1 epochs.
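    To illustrate the rolling behavior (the loss values here are made up):

    ```python
    from collections import deque

    # A deque with maxlen behaves as a rolling window: once full, each
    # append silently discards the oldest element on the left.
    loss_history = deque(maxlen=4)  # patience = 3, so keep patience + 1 values
    for loss in [0.9, 0.8, 0.7, 0.6, 0.5]:
        loss_history.append(loss)

    print(list(loss_history))  # the first value, 0.9, has been dropped
    ```

    This is why the early-stopping check only fires once len(loss_history) > patience, i.e. once the window is full.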

    Here's a full implementation, applied to the model from the Tensorflow documentation example:

    import tensorflow as tf
    from collections import deque

    # `model`, `optimizer`, `get_grad`, `train` and `test` are defined in the
    # Tensorflow documentation example that this loop is based on
    verbose = 'Epoch {:2d} Loss: {:.3f} TLoss: {:.3f} Acc: {:.3%} TAcc: {:.3%}'

    patience = 3
    delta = 0.001

    loss_history = deque(maxlen=patience + 1)

    for epoch in range(1, 25 + 1):
        train_loss = tf.metrics.Mean()
        train_acc = tf.metrics.CategoricalAccuracy()
        test_loss = tf.metrics.Mean()
        test_acc = tf.metrics.CategoricalAccuracy()

        for x, y in train:
            loss_value, grads = get_grad(model, x, y)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            train_loss.update_state(loss_value)
            train_acc.update_state(y, model(x, training=True))

        for x, y in test:
            loss_value, _ = get_grad(model, x, y)
            test_loss.update_state(loss_value)
            test_acc.update_state(y, model(x, training=False))

        print(verbose.format(epoch,
                             train_loss.result(),
                             test_loss.result(),
                             train_acc.result(),
                             test_acc.result()))

        loss_history.append(test_loss.result())

        if len(loss_history) > patience:
            if loss_history.popleft()*(1 - delta) < min(loss_history):
                print(f'\nEarly stopping. No improvement of more than {delta:.5%} in '
                      f'validation loss in the last {patience} epochs.')
                break
    
    Epoch  1 Loss: 0.191 TLoss: 0.282 Acc: 68.920% TAcc: 89.200%
    Epoch  2 Loss: 0.157 TLoss: 0.297 Acc: 70.880% TAcc: 90.000%
    Epoch  3 Loss: 0.133 TLoss: 0.318 Acc: 71.560% TAcc: 90.800%
    Epoch  4 Loss: 0.117 TLoss: 0.299 Acc: 71.960% TAcc: 90.800%
    
    Early stopping. No improvement of more than 0.10000% in validation loss in the last 3 epochs.
    