PyTorch, INPUT (normal tensor) and WEIGHT (cuda tensor) mismatch

Submitted by 孤街浪徒 on 2021-01-29 05:26:23

Question


DISCLAIMER: I know this question has already been asked multiple times, but I tried the solutions given there and none of them worked for me. After all that effort I can't find anything else, so I have to ask again.

I'm doing image classification with CNNs in PyTorch and I want to train on the GPU (an NVIDIA GPU with CUDA installed). I managed to put the net on the GPU, but the problem is with the data.
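A common reason the net ends up on the GPU while the data does not is that nn.Module.to() and Tensor.to() behave differently: the module version moves parameters in place, while the tensor version returns a new tensor and leaves the original untouched. A minimal sketch of that difference, assuming a single Conv2d and a random batch as hypothetical stand-ins for the question's Net and data; the dtype part shows the same behaviour on a machine without a GPU:

```python
import torch
import torch.nn as nn

# Use the GPU when one is present, otherwise the CPU, so the sketch runs anywhere.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# nn.Module.to() moves the parameters in place and returns the same module object.
net = nn.Conv2d(3, 6, 5)       # hypothetical stand-in for the question's Net
returned = net.to(device)
assert returned is net

# Tensor.to() never modifies the tensor it is called on; it returns a tensor on the target device.
x = torch.randn(4, 3, 32, 32)  # hypothetical CIFAR-sized batch
x.to(device)                   # result discarded: x itself stays where it was
x_dev = x.to(device)           # the moved tensor has to be assigned
print(x.device, x_dev.device, next(net.parameters()).device)

# The same difference is visible on a CPU-only machine with a dtype conversion.
y = torch.zeros(2)
y64 = y.to(torch.float64)
print(y.dtype, y64.dtype)      # torch.float32 torch.float64

m = nn.Linear(2, 2)
m.to(torch.float64)            # module converted in place, no assignment needed
print(m.weight.dtype)          # torch.float64
```

This is why net.to(device) works on its own line, while every tensor needs the x = x.to(device) form.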

import torch

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    print("Running on the GPU")
    print("Available GPU", torch.cuda.device_count())
else:
    device = torch.device("cpu")

Running on the GPU
Available GPU 1
net = Net()
net.to(device)
for epoch in range(2):
    running_loss=0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device) # putting data on gpu
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        if i % 2000 == 1999:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0
dataiter = iter(testloader)
images, labels = next(dataiter)

print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))

out = net(images)

ERROR
----------------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-81-76c52eabb174> in <module>
----> 1 out = net(images)

~/anaconda3/envs/home/lib/python3.8/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

<ipython-input-57-c74f8361a10b> in forward(self, x)
     11 
     12     def forward(self, x):
---> 13         x = self.pool(F.relu(self.conv1(x)))
     14         x = self.pool(F.relu(self.conv2(x)))
     15         x = x.view(-1, 16*5*5)

~/anaconda3/envs/home/lib/python3.8/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

~/anaconda3/envs/home/lib/python3.8/site-packages/torch/nn/modules/conv.py in forward(self, input)
    351 
    352     def forward(self, input):
--> 353         return self._conv_forward(input, self.weight)
    354 
    355 class Conv3d(_ConvNd):

~/anaconda3/envs/home/lib/python3.8/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight)
    347                             weight, self.bias, self.stride,
    348                             _pair(0), self.dilation, self.groups)
--> 349         return F.conv2d(input, weight, self.bias, self.stride,
    350                         self.padding, self.dilation, self.groups)
    351 

RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same
Checking the training inputs:

inputs.is_cuda

True

The same holds for labels.
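Note that inputs.is_cuda only inspects the last training batch; the tensor passed to net in the failing line is images. A minimal sketch of a check on the tensor that is actually fed to the model, assuming a hypothetical helper same_device and small stand-ins for net and images:

```python
import torch
import torch.nn as nn

def same_device(model: nn.Module, batch: torch.Tensor) -> bool:
    """Hypothetical helper: report where the weights and a batch live and whether they match."""
    weight_device = next(model.parameters()).device
    print(f"weights on {weight_device}, batch on {batch.device}")
    return weight_device == batch.device

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = nn.Linear(4, 2).to(device)  # stand-in for net, already moved
batch = torch.randn(3, 4)           # stand-in for images, created on the CPU

if not same_device(model, batch):   # only False on a machine with a GPU
    batch = batch.to(device)
assert same_device(model, batch)
out = model(batch)
```

On a GPU machine the first call prints cuda:0 for the weights and cpu for the batch, which is exactly the mismatch reported in the traceback.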

What I've tried:

https://stackoverflow.com/q/59013109/13287412 
https://github.com/sksq96/pytorch-summary/issues/57
https://blog.csdn.net/qq_27261889/article/details/86575033
https://blog.csdn.net/public669/article/details/96510293

but nothing has worked so far.


Answer 1:


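Your images tensor is on the CPU, while your net is on the GPU. The training loop moves every batch with .to(device), but the batch taken from testloader is never moved, so inputs.is_cuda being True says nothing about images. Even when evaluating, the input tensors and the model must be on the same device; otherwise you get this error, which reports the device mismatch as a type mismatch (torch.FloatTensor vs torch.cuda.FloatTensor).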

out = net(images.to(device))
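For a full pass over the test set the same rule applies to every batch. A minimal sketch, assuming a small stand-in model and random data in place of the question's Net and testloader (which are not shown in full); net.eval() and torch.no_grad() are the usual additions for evaluation, not something the fix above requires:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Stand-ins for the question's Net and testloader (assumptions, not the asker's code).
net = nn.Sequential(
    nn.Conv2d(3, 6, 5), nn.ReLU(), nn.Flatten(), nn.Linear(6 * 28 * 28, 10)
).to(device)
testset = TensorDataset(torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,)))
testloader = DataLoader(testset, batch_size=4)

net.eval()  # switch layers such as dropout and batch norm to evaluation behaviour
correct = total = 0
with torch.no_grad():  # no autograd bookkeeping needed at test time
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)  # same device as the weights
        outputs = net(images)
        predicted = outputs.argmax(dim=1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

print(f"accuracy on {total} images: {correct / total:.2f}")
```

Since the model is random, the printed accuracy is meaningless; the point is that images and labels are moved inside the loop, exactly as in the training loop.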


Source: https://stackoverflow.com/questions/63005606/pytorch-input-normal-tensor-and-weight-cuda-tensor-mismatch
