Multi-label classification in PyTorch

傲寒 2021-02-07 13:30

I have a multi-label classification problem. I have 11 classes and around 4k examples. Each example can have from 1 to 4-5 labels. At the moment, I'm training a classifier separate

1 Answer
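  • You are looking for torch.nn.BCELoss (binary cross-entropy). It scores each of the 11 outputs as an independent yes/no prediction, so a single model can assign any number of labels to an example. Here's example code: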

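    import torch
    
    batch_size = 2
    num_classes = 11
    
    loss_fn = torch.nn.BCELoss()
    
    # raw model outputs (logits), one score per class
    outputs_before_sigmoid = torch.randn(batch_size, num_classes)
    # sigmoid maps each score independently to a probability in (0, 1)
    sigmoid_outputs = torch.sigmoid(outputs_before_sigmoid)
    # multi-hot targets: 1 if the class applies to the example, else 0.
    # BCELoss requires float targets, so cast the integer tensor with .float().
    target_classes = torch.randint(0, 2, (batch_size, num_classes)).float()  # randints in [0, 2).
    
    loss = loss_fn(sigmoid_outputs, target_classes)
    
    # alternatively, use BCE with logits, on outputs before sigmoid.
    # It fuses the sigmoid into the loss and is more numerically stable.
    loss_fn_2 = torch.nn.BCEWithLogitsLoss()
    loss2 = loss_fn_2(outputs_before_sigmoid, target_classes)
    # the two values agree only up to floating-point rounding, so compare with a tolerance
    assert torch.allclose(loss, loss2)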
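
The targets above are random, but with real data each example's labels (1 to 4-5 of the 11 classes, per the question) have to become a multi-hot float vector before they can go into either loss. A minimal sketch, assuming labels are stored as lists of class indices; the helper `to_multi_hot` and the sample label lists are hypothetical:

```python
import torch

NUM_CLASSES = 11  # from the question


def to_multi_hot(label_lists, num_classes=NUM_CLASSES):
    """Turn lists of class indices into a float multi-hot tensor.

    Row i has 1.0 in every column listed in label_lists[i], 0.0 elsewhere.
    """
    targets = torch.zeros(len(label_lists), num_classes)  # float32 by default
    for row, labels in enumerate(label_lists):
        targets[row, labels] = 1.0  # set all of this example's classes at once
    return targets


# hypothetical examples with 1, 3 and 5 labels each
examples = [[0], [2, 5, 7], [1, 3, 4, 8, 10]]
targets = to_multi_hot(examples)
print(targets.shape)       # torch.Size([3, 11])
print(targets.sum(dim=1))  # tensor([1., 3., 5.])
```

Because the result is already float, it can be passed straight to BCELoss or BCEWithLogitsLoss as the target.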
    
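Instead of one classifier per class, a single network with 11 outputs can be trained jointly with BCEWithLogitsLoss and then thresholded per class at prediction time. A minimal sketch, assuming plain feature-vector inputs; the input size, layer widths, learning rate, step count, random data and the 0.5 threshold are all placeholder assumptions, not values from the question:

```python
import torch
from torch import nn

torch.manual_seed(0)

num_features = 20  # hypothetical input size
num_classes = 11   # from the question

# One network with one output (logit) per class replaces 11 separate classifiers.
model = nn.Sequential(
    nn.Linear(num_features, 64),
    nn.ReLU(),
    nn.Linear(64, num_classes),  # no sigmoid here: BCEWithLogitsLoss applies it
)
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# Hypothetical random data standing in for the real examples.
x = torch.randn(64, num_features)
y = (torch.rand(64, num_classes) < 0.3).float()  # multi-hot float targets

losses = []
for step in range(100):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

# At inference each class is decided independently: sigmoid, then threshold.
with torch.no_grad():
    probs = torch.sigmoid(model(x))
    preds = (probs > 0.5).float()  # assumed threshold; can be tuned per class

print(preds.shape)  # torch.Size([64, 11])
```

Since each row of `preds` is thresholded class by class, an example can end up with any number of positive labels, which matches the 1 to 4-5 labels per example described in the question.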