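Correct shapes for F.cross_entropy (data = input logits, label = class-index target). The label has the data's shape with dimension 1, the class dimension, removed: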
1. data: 3,3,3,2
   label: 3,3,2
2. data: 3,2
   label: 3
3. data: 2,1,6
   label: 2,6
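The shape rule above can be checked with a small sketch. It assumes class-index (integer) targets and uses a hypothetical helper, expected_target_shape, that drops dimension 1 from the input shape. It then runs F.cross_entropy on random data with a random label of that shape. Under this rule, a [3, 2] input takes a [3] label.

```python
import torch
import torch.nn.functional as F


def expected_target_shape(input_shape):
    """Hypothetical helper: input (N, C, d1, ..., dk) -> target (N, d1, ..., dk)."""
    return tuple(input_shape[:1]) + tuple(input_shape[2:])


def check(input_shape):
    """Build random data and a matching class-index label, then compute the loss."""
    data = torch.randn(*input_shape)
    num_classes = input_shape[1]
    # Class indices must be integers in [0, C).
    target = torch.randint(0, num_classes, expected_target_shape(input_shape))
    loss = F.cross_entropy(data, target)
    return target.shape, loss


for shape in [(3, 3, 3, 2), (3, 2), (2, 1, 6)]:
    t_shape, loss = check(shape)
    print(shape, '->', tuple(t_shape), float(loss))
```

Removing dimension 1, rather than the last dimension, matters because F.cross_entropy always reads the class scores from dimension 1.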
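Passing data of shape [3, 3, 3, 2] with a label of shape [3, 2] fails. The first line is the printed shapes, and the rest is the error:
torch.Size([3, 3, 3, 2]) torch.Size([3, 2])
Traceback (most recent call last):
  File "F:/project/chedaoxian/Ultra-Fast-bar-Detection/utils/loss.py", line 91, in <module>
    out_size, target.size()))
ValueError: Expected target size (3, 3, 2), got torch.Size([3, 2])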
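The error can be reproduced with a minimal sketch. loss_ok is a hypothetical helper that tries F.cross_entropy with all-zero labels, since class 0 is valid for any C >= 1, and reports whether the shapes were accepted. The exception type differs by version: the traceback above shows a ValueError, and newer PyTorch releases may raise a RuntimeError, so both are caught.

```python
import torch
import torch.nn.functional as F


def loss_ok(data_shape, target_shape):
    """Hypothetical helper: True if F.cross_entropy accepts these shapes."""
    data = torch.randn(*data_shape)
    target = torch.zeros(target_shape, dtype=torch.long)  # class 0 is always valid
    try:
        F.cross_entropy(data, target)
        return True
    except (ValueError, RuntimeError) as err:
        # Older PyTorch raised ValueError here; newer versions may raise RuntimeError.
        print('error:', err)
        return False


print(loss_ok((3, 3, 3, 2), (3, 2)))     # label missing a dimension: rejected
print(loss_ok((3, 3, 3, 2), (3, 3, 2)))  # label is (N, d1, d2): accepted
```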
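With data of shape [2, 6], the label must have shape [2]. Adding an extra dimension to the label raises an error.
If data has shape [2, 1, 6], the label must have shape [2, 6].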
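Both cases can be shown in a short sketch. It assumes the label picked up an extra size-1 dimension, for example [2, 1], and uses squeeze(1) as one possible fix. This is an illustration, not the original author's code. For data [2, 1, 6] there is only one class, so the sketch fills the [2, 6] label with 0, the only valid index.

```python
import torch
import torch.nn.functional as F

data = torch.randn(2, 6)              # (N=2, C=6)
target = torch.tensor([[1], [2]])     # (2, 1): one extra dimension

rejected = False
try:
    F.cross_entropy(data, target)
except (ValueError, RuntimeError) as err:
    rejected = True
    print('label shape (2, 1) rejected:', err)

# Dropping the extra size-1 dimension restores the expected (N,) shape.
fixed = target.squeeze(1)             # (2,)
loss = F.cross_entropy(data, fixed)

# data (2, 1, 6): class dim has size 1, so the label is (2, 6) and every index is 0.
data3 = torch.randn(2, 1, 6)
target3 = torch.zeros(2, 6, dtype=torch.long)
loss3 = F.cross_entropy(data3, target3)  # a single class gives zero loss
print(fixed.shape, loss.item(), loss3.item())
```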
import torch
import torch.nn.functional as F

# data: (N=2, C=6) logits; target: (N,) class indices in [0, 6)
data = torch.randn(2, 6)
target = torch.tensor([1, 2])
print('data:', data.size(), target.size())

# cross_entropy = log_softmax over dim 1, then nll_loss
entropy_out = F.cross_entropy(data, target)
print('entropy_out:', entropy_out, target.size())

# The same loss computed in two explicit steps
log_soft = F.log_softmax(data, dim=1)
print('log_soft:', log_soft, '\n')
nll_out = F.nll_loss(log_soft, target)
print('nll_out:', nll_out)  # same value as entropy_out
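The equivalence that the script prints can also be asserted directly. This sketch adds an explicit -log p[target] computation and uses reduction='none' to show that the unreduced loss has exactly the label's shape, which is another way to see the shape rule. The seed and the extra [3, 3, 3, 2] example are illustrative choices.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
data = torch.randn(2, 6)
target = torch.tensor([1, 2])

# cross_entropy combines log_softmax (over the class dim) and nll_loss.
ce = F.cross_entropy(data, target)
nll = F.nll_loss(F.log_softmax(data, dim=1), target)

# Written out by hand: mean of -log p[target] over the batch.
manual = -F.log_softmax(data, dim=1)[torch.arange(2), target].mean()

# reduction='none' returns one loss per label entry, i.e. the label's shape.
per_elem = F.cross_entropy(torch.randn(3, 3, 3, 2),
                           torch.zeros(3, 3, 2, dtype=torch.long),
                           reduction='none')
print(ce.item(), nll.item(), manual.item(), tuple(per_elem.shape))
```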
Source: oschina
Link: https://my.oschina.net/u/4302478/blog/4757730