torch nll_loss

不想你离开。 提交于 2020-11-28 13:29:22

 

正确格式:

1.

data:3,3,3,2

label:3,3,2

2.data:3:2

label:2

3.data:2,1,6

label:2,6

 

Traceback (most recent call last):
torch.Size([3, 3, 3, 2]) torch.Size([3, 2])
  File "F:/project/chedaoxian/Ultra-Fast-bar-Detection/utils/loss.py", line 91, in <module>
    out_size, target.size()))
ValueError: Expected target size (3, 3, 2), got torch.Size([3, 2])



 

data维度 [2,6]

label维度[2],这样才可以,增加维度就报错。

如果data是[2,1,6]

label就需要是[2,6]

  import torch
    import torch.nn as nn
    import torch.nn.functional as F

    data = torch.randn(2,6)

    target = torch.tensor([1,2])

    print('data:', data.size(),target.size())
    entropy_out = F.cross_entropy(data, target)
    print('entropy_out:', entropy_out, target.size())
    log_soft = F.log_softmax(data, dim=1)
    print('log_soft:', log_soft, '\n')
    nll_out = F.nll_loss(log_soft, target)


    print('nll_out:', nll_out)

 

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!