Why doesn't my simple pytorch network work on GPU device?

鱼传尺愫 2020-12-20 23:45

I built a simple network from a tutorial and I got this error:

RuntimeError: Expected object of type torch.cuda.FloatTensor but found type torch.Flo

2 Answers
  • 2020-12-21 00:03
    import torch
    import numpy as np
    
    x = torch.tensor(np.array(1), device='cuda:0')
    
    print(x.device)  # Answerer reports `cpu`; recent PyTorch releases print `cuda:0` here as well
    
    x = torch.tensor(1, device='cuda:0')
    
    print(x.device)  # Prints `cuda:0`
    

    Now the tensor resides on the GPU.
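A quick way to confirm where tensors actually live is to print or compare their .device attribute. The sketch below assumes a recent PyTorch release; it picks cuda:0 when a GPU is available and otherwise falls back to the CPU so it runs anywhere.

```python
import numpy as np
import torch

# Fall back to the CPU so the sketch also runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build a tensor directly on the target device from a Python scalar.
a = torch.tensor(1, device=device)

# Convert a NumPy array first, then move the result explicitly.
b = torch.from_numpy(np.array([1.0, 2.0])).to(device)

# In recent releases, torch.tensor() also honors device= for NumPy input.
c = torch.tensor(np.array(1), device=device)

print(a.device, b.device, c.device)
```

Checking .device like this before calling the model catches a CPU tensor that was never moved.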

  • 2020-12-21 00:16

    TL;DR
    This is the fix:

    inputs = inputs.to(device)  
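To show where that line sits, here is a minimal training-step sketch. The toy model (nn.Linear), the random data, and the SGD settings are hypothetical stand-ins for the tutorial's network and data loader. The labels need the same reassignment as the inputs, because the loss also fails if they are left on the CPU. It falls back to the CPU when no GPU is available.

```python
import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hypothetical toy model standing in for the tutorial's network.
net = nn.Linear(4, 2)
net.to(device)  # Module.to() moves the parameters in place

criterion = nn.CrossEntropyLoss()
# Create the optimizer after moving the model, so it sees the moved parameters.
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

# Hypothetical batch, created on the CPU like a typical DataLoader batch.
inputs = torch.randn(8, 4)
labels = torch.randint(0, 2, (8,))

# Tensor.to() returns a new tensor, so reassign both names.
inputs = inputs.to(device)
labels = labels.to(device)

optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
print(loss.item())
```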
    

    Why?!
    There is a subtle difference between torch.nn.Module.to() and torch.Tensor.to(): Module.to() modifies the module in place (and returns it), whereas Tensor.to() leaves the original tensor unchanged and returns the converted tensor. Therefore

    net.to(device)
    

    changes net itself, moving its parameters and buffers to device. On the other hand,

    inputs.to(device)
    

    does not change inputs; instead it returns a copy of inputs that resides on device (or inputs itself, if it is already there). To use that on-device copy, you need to assign it to a variable, hence

    inputs = inputs.to(device)
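The same asymmetry can be observed without a GPU, because both Module.to() and Tensor.to() also accept a dtype. A minimal sketch, assuming a recent PyTorch with the default float32 dtype:

```python
import torch
import torch.nn as nn

net = nn.Linear(3, 1)
returned = net.to(torch.float64)
# Module.to() converted net's own parameters and returned the same object.
print(returned is net, net.weight.dtype)   # True torch.float64

t = torch.zeros(3)
converted = t.to(torch.float64)
# Tensor.to() left t untouched and produced a new tensor.
print(t.dtype, converted.dtype)            # torch.float32 torch.float64

# When no conversion is needed, Tensor.to() returns t itself, so no copy is made.
same = t.to(torch.float32)
print(same.data_ptr() == t.data_ptr())     # True
```

Swapping torch.float64 for a CUDA device gives exactly the behavior described above: net.to(device) is enough on its own, while a bare inputs.to(device) is silently discarded.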
    