PyTorch custom loss function

爱一瞬间的悲伤 2021-02-04 07:46

How should a custom loss function be implemented? Using the code below causes an error:

import torch
import torch.nn as nn
import torchvision
import torchvision.         


        
3 Answers
  • 2021-02-04 08:29

    Solution

    Here are a few examples of custom loss functions that I came across in this Kaggle Notebook. It provides implementations of the following custom loss functions in both PyTorch and TensorFlow.

    Loss Function Reference for Keras & PyTorch

    I hope this will be helpful for anyone looking to see how to write their own custom loss functions.

    • Dice Loss
    • BCE-Dice Loss
    • Jaccard/Intersection over Union (IoU) Loss
    • Focal Loss
    • Tversky Loss
    • Focal Tversky Loss
    • Lovasz Hinge Loss
    • Combo Loss
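    To give an idea of what such an implementation looks like, here is a minimal sketch of the first item, Dice loss, written as an nn.Module. It is not copied from the notebook; it assumes binary segmentation with raw logits and a target mask of the same shape with values in {0, 1}, and the `smooth` term is a common convention rather than something the answer specifies.

```python
import torch
import torch.nn as nn


class DiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation (illustrative sketch).

    Assumes raw logits and a {0, 1} target mask of the same shape.
    """

    def __init__(self, smooth: float = 1.0):
        super().__init__()
        self.smooth = smooth  # keeps the ratio defined when both masks are empty

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        probs = torch.sigmoid(logits).reshape(-1)  # flatten predictions to 1-D
        targets = targets.reshape(-1).float()      # flatten targets to 1-D
        intersection = (probs * targets).sum()
        dice = (2.0 * intersection + self.smooth) / (
            probs.sum() + targets.sum() + self.smooth
        )
        return 1.0 - dice  # 0 for a perfect overlap, up to 1 for none


if __name__ == "__main__":
    logits = torch.randn(2, 1, 8, 8, requires_grad=True)
    targets = (torch.rand(2, 1, 8, 8) > 0.5).float()
    loss = DiceLoss()(logits, targets)
    loss.backward()  # autograd handles the gradient; no custom backward needed
    print(loss.item())
```

    Because every step is a differentiable torch operation, the module needs only a forward() method.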
  • 2021-02-04 08:38

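    If you build the loss from torch functions, you should be fine: autograd tracks those operations and computes the gradients for you, so you don't need to write a backward pass yourself.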

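    import torch

    def my_custom_loss(output, target):
        # Built only from torch ops, so autograd can differentiate it
        loss = torch.mean((output - target * 2) ** 3)
        return loss

    # Forward pass through the network, e.g. output = model(inputs);
    # a random tensor stands in for the model output here
    output = torch.randn(4, 1, requires_grad=True)
    target = torch.randn(4, 1)
    loss = my_custom_loss(output, target)
    loss.backward()  # fills output.grad via autograd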
    
    
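    The same idea also works when the loss is wrapped in an nn.Module and used in an ordinary training loop. The sketch below uses a hypothetical `ScaledMSELoss` (half the mean squared error by default here) on a toy linear-regression problem; the model, data, and SGD settings are all assumptions chosen only to show that gradients flow through a custom loss built from torch operations.

```python
import torch
import torch.nn as nn


class ScaledMSELoss(nn.Module):
    """Hypothetical custom loss: mean squared error times a constant.

    Only torch operations are used, so autograd derives the backward pass.
    """

    def __init__(self, scale: float = 1.0):
        super().__init__()
        self.scale = scale

    def forward(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return self.scale * torch.mean((output - target) ** 2)


def train_steps(num_steps: int = 200) -> list:
    torch.manual_seed(0)
    model = nn.Linear(3, 1)
    criterion = ScaledMSELoss(scale=0.5)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    inputs = torch.randn(64, 3)
    targets = inputs @ torch.tensor([[1.0], [-2.0], [0.5]])  # toy linear targets
    losses = []
    for _ in range(num_steps):
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)  # forward pass
        loss.backward()                           # gradients via autograd
        optimizer.step()
        losses.append(loss.item())
    return losses


if __name__ == "__main__":
    history = train_steps()
    print(history[0], history[-1])
```

    Subclassing nn.Module is optional; a plain function like my_custom_loss behaves the same, but a module can hold settings such as `scale`.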
  • 2021-02-04 08:46

    Your loss function is programmatically correct except for the following:

        # the number of tokens is the sum of elements in mask
        num_tokens = int(torch.sum(mask).data[0])
    

    When you call torch.sum over the whole tensor, it returns a 0-dimensional tensor, which is why you get the message that it can't be indexed. To fix this, use int(torch.sum(mask).item()), as the message suggests; int(torch.sum(mask)) works too.
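    A quick check of this behavior, using a hypothetical padding mask (the actual mask from the question is not shown). In current PyTorch versions, indexing the 0-dimensional result raises an IndexError.

```python
import torch

mask = torch.tensor([[1, 1, 0], [1, 0, 0]])  # hypothetical padding mask

total = torch.sum(mask)
print(total.dim())  # 0: a 0-dimensional (scalar) tensor

try:
    total.data[0]  # the problematic pattern: indexing a 0-dim tensor
except IndexError as err:
    print("IndexError:", err)

num_tokens = int(torch.sum(mask).item())  # .item() returns a Python number
num_tokens_alt = int(torch.sum(mask))     # int() on a one-element tensor also works
print(num_tokens, num_tokens_alt)
```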

    Next, are you trying to emulate the cross-entropy (CE) loss with your custom loss? If so, you are missing the log_softmax.

    To fix that, add outputs = torch.nn.functional.log_softmax(outputs, dim=1) before statement 4. Note that in the tutorial you linked, log_softmax is already applied in the forward() call; you can do the same.
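    Since the question's full code isn't shown, here is a hypothetical masked cross-entropy of the kind described, with the log_softmax added. It assumes `outputs` are raw logits of shape (N, C) and `labels` has shape (N,), with -1 marking padding tokens (an assumed convention). The result can be checked against F.cross_entropy with ignore_index=-1, which averages over the non-ignored targets.

```python
import torch
import torch.nn.functional as F


def masked_cross_entropy(outputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Cross-entropy averaged over non-padding tokens (illustrative sketch).

    outputs: raw logits, shape (N, C); labels: shape (N,), -1 = padding (assumed).
    """
    labels = labels.view(-1)
    mask = (labels >= 0).float()               # 1 for real tokens, 0 for padding
    num_tokens = int(torch.sum(mask).item())   # the number of tokens is sum(mask)
    log_probs = F.log_softmax(outputs, dim=1)  # the missing log_softmax step
    safe_labels = labels.clamp(min=0)          # avoid indexing with -1
    picked = log_probs[torch.arange(outputs.shape[0]), safe_labels] * mask
    return -torch.sum(picked) / num_tokens


if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(6, 4)
    labels = torch.tensor([0, 3, -1, 2, -1, 1])
    print(masked_cross_entropy(logits, labels))
    print(F.cross_entropy(logits, labels, ignore_index=-1))  # should match
```

    If the model's forward() already returns log_softmax output, as in the tutorial, drop the log_softmax line so it isn't applied twice.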

    Also, I noticed that the learning rate is too low, and even with the CE loss the results are not consistent. Increasing the learning rate to 1e-3 works well for me with both the custom loss and the CE loss.
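    Changing the learning rate is a one-line edit where the optimizer is created (pass lr=1e-3), or it can be updated on an existing optimizer. A minimal sketch, assuming SGD and a placeholder model, since the thread doesn't say which optimizer or network was used:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # placeholder for the network in the question

# Hypothetical too-low starting value; SGD is an assumption
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

# Raise the learning rate in place; every parameter group gets the new value
for group in optimizer.param_groups:
    group["lr"] = 1e-3
```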
