Resize PyTorch Tensor

醉梦人生 2021-01-18 10:27

I am currently using the tensor.resize() function to resize a tensor to a new shape: t = t.resize(1, 2, 3).

This gives me a deprecation warning:

  •  悲&欢浪女
    2021-01-18 11:14

    You can use tensor.reshape(new_shape) or torch.reshape(tensor, new_shape) instead, as in:

    # a tensor that requires grad (what PyTorch before 0.4 called a `Variable`)
    In [15]: ten = torch.randn(6, requires_grad=True)
    
    # this raises a RuntimeError
    In [16]: ten.resize_(2, 3)
    ---------------------------------------------------------------------------
    RuntimeError                              Traceback (most recent call last)
    ----> 1 ten.resize_(2, 3)
    
    RuntimeError: cannot resize variables that require grad
    
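For comparison, here is a minimal sketch (assuming a recent PyTorch release; the variable names are illustrative). It shows that the in-place resize_() is rejected only because the tensor requires grad. The same call succeeds on a plain tensor with requires_grad=False.

```python
import torch

# Case 1: a tensor that requires grad cannot be resized in place.
ten = torch.randn(6, requires_grad=True)
try:
    ten.resize_(2, 3)
    resize_failed = False
except RuntimeError as err:
    resize_failed = True
    print("RuntimeError:", err)

# Case 2: a plain tensor (requires_grad=False) can be resized in place.
# resize_() modifies t itself (and also returns it); the existing storage
# is reinterpreted as a C-contiguous 2 x 3 tensor.
t = torch.arange(6.0)
t.resize_(2, 3)
print(t.shape)  # torch.Size([2, 3])
print(t)
```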

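    The RuntimeError above can be avoided by using tensor.reshape(new_shape). It returns a tensor with the new shape (a view of the original when possible, otherwise a copy) instead of resizing the original in place. Unlike resize_(), reshape() requires the new shape to have the same number of elements as the original: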

    In [17]: ten.reshape(2, 3)
    Out[17]: 
    tensor([[-0.2185, -0.6335, -0.0041],
            [-1.0147, -1.6359,  0.6965]])
    
    # yet another way of changing tensor shape
    In [18]: torch.reshape(ten, (2, 3))
    Out[18]: 
    tensor([[-0.2185, -0.6335, -0.0041],
            [-1.0147, -1.6359,  0.6965]])
    
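The following is a minimal sketch, assuming a recent PyTorch release; the variable names are illustrative. It checks two properties of reshape() that matter here. First, gradients flow back through it to the original tensor. Second, it rejects a target shape with a different number of elements, which resize_() would accept on a tensor that does not require grad.

```python
import torch

ten = torch.randn(6, requires_grad=True)

# reshape() is differentiable, so the result stays connected to the graph.
out = ten.reshape(2, 3)
print(out.shape)          # torch.Size([2, 3])
print(out.requires_grad)  # True

# Backpropagating through the reshaped tensor fills ten.grad in the
# original (6,) shape; d(sum)/d(x) is 1 for every element.
out.sum().backward()
print(ten.grad)

# The functional form gives the same values.
same = torch.reshape(ten, (2, 3))
print(torch.equal(out, same))  # True

# Unlike resize_(), reshape() cannot change the number of elements:
# 6 elements cannot be laid out as 1 x 2 x 4 = 8.
try:
    ten.reshape(1, 2, 4)
    bad_shape_rejected = False
except RuntimeError as err:
    bad_shape_rejected = True
    print("RuntimeError:", err)
```

Because reshape() returns a new tensor rather than modifying ten in place, assign the result (for example t = t.reshape(1, 2, 3)), just as the question does with resize().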
