PyTorch custom loss function

后端 未结 3 1999
爱一瞬间的悲伤
爱一瞬间的悲伤 2021-02-04 07:46

How should a custom loss function be implemented ? Using below code is causing error :

import torch
import torch.nn as nn
import torchvision
import torchvision.         


        
3条回答
  •  挽巷
    挽巷 (楼主)
    2021-02-04 08:38

    If you use torch functions you should be fine

    import torch 
    
    def my_custom_loss(output, target):
        loss = torch.mean((output-target*2)**3)
        return loss
    
    # Forward pass to the Network
    # then, 
    loss.backward()
    
    

提交回复
热议问题