Resize PyTorch Tensor

醉梦人生 2021-01-18 10:27

I am currently using the tensor.resize() function to resize a tensor to a new shape: t = t.resize(1, 2, 3).

This gives me a deprecation warning:

3 Answers
  • 2021-01-18 11:03

    Simply use t = t.contiguous().view(1, 2, 3) if you don't actually need to change its data (the new shape must hold the same number of elements as the old one).
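    A minimal sketch of this approach, assuming a recent PyTorch release; the tensors t and m are illustrative. It shows that the result of view() stays attached to the autograd graph, and why the contiguous() call matters for non-contiguous inputs such as a transposed tensor.

```python
import torch

# A tensor that requires grad, as in the question (6 elements).
t = torch.randn(6, requires_grad=True)

# view() keeps the data and needs the same element count (1 * 2 * 3 == 6).
# contiguous() is a no-op here but guarantees a layout view() can use.
v = t.contiguous().view(1, 2, 3)
print(v.shape)          # torch.Size([1, 2, 3])
print(v.requires_grad)  # True: v is still connected to t's graph

# A transposed tensor is not contiguous, so view() alone rejects it,
# while contiguous() first copies it into a compatible layout.
m = torch.randn(2, 3).t()
try:
    m.view(6)
except RuntimeError as err:
    print("view failed:", err)
print(m.contiguous().view(6).shape)  # torch.Size([6])
```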

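    If you do need to change its data, be aware that the in-place resize_ operation breaks the gradient computation graph of t.
    If that doesn't matter to you, older PyTorch versions let you use t = t.data.resize_(1, 2, 3). Recent versions refuse to change the shape of a tensor obtained through .data, so resize a detached copy instead: t = t.detach().clone().resize_(1, 2, 3).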
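    A minimal sketch of resizing outside the autograd graph, assuming a recent PyTorch release; the names t and r are illustrative. Because newer releases refuse to change the shape of a tensor obtained through .data, this version resizes a detached clone, which leaves the original t untouched.

```python
import torch

t = torch.randn(6, requires_grad=True)

# resize_ is refused on tensors that require grad.
try:
    t.resize_(1, 2, 3)
except RuntimeError as err:
    print("resize_ failed:", err)

# Resize a detached copy instead. clone() gives it its own storage and
# metadata, so resize_ may change its shape in place.
r = t.detach().clone().resize_(1, 2, 3)
print(r.shape)          # torch.Size([1, 2, 3])
print(r.requires_grad)  # False: r is cut off from t's autograd graph
```

    The cost is one copy of the data; in exchange, r no longer takes part in gradient computation, so its shape can be changed freely.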

  • 2021-01-18 11:14

    You can use tensor.reshape(new_shape) or torch.reshape(tensor, new_shape) instead, as in:

    # a tensor that requires grad (what older PyTorch called a `Variable`)
    In [15]: ten = torch.randn(6, requires_grad=True)
    
    # resize_ on such a tensor raises a RuntimeError
    In [16]: ten.resize_(2, 3)
    ---------------------------------------------------------------------------
    RuntimeError                              Traceback (most recent call last)
    <ipython-input-16-094491c46baa> in <module>()
    ----> 1 ten.resize_(2, 3)
    
    RuntimeError: cannot resize variables that require grad
    

    The RuntimeError above can be avoided by using tensor.reshape(new_shape), which returns a reshaped tensor instead of modifying ten in place:

    In [17]: ten.reshape(2, 3)
    Out[17]: 
    tensor([[-0.2185, -0.6335, -0.0041],
            [-1.0147, -1.6359,  0.6965]])
    
    # yet another way of changing tensor shape
    In [18]: torch.reshape(ten, (2, 3))
    Out[18]: 
    tensor([[-0.2185, -0.6335, -0.0041],
            [-1.0147, -1.6359,  0.6965]])
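    A short sketch, assuming a recent PyTorch release, of why reshape avoids the error: it is an ordinary out-of-place operation, so its result stays in the autograd graph and gradients reach ten. The variable out is illustrative.

```python
import torch

ten = torch.randn(6, requires_grad=True)

# reshape is out of place and differentiable, so the result stays in the graph.
out = ten.reshape(2, 3)
print(out.requires_grad)                 # True

# For a contiguous input, reshape returns a view that shares memory.
print(out.data_ptr() == ten.data_ptr())  # True

# Gradients flow back through the reshape to the original 6-element tensor.
out.sum().backward()
print(ten.grad)                          # tensor([1., 1., 1., 1., 1., 1.])
```

    For inputs whose layout does not allow a view (for example a transposed tensor), reshape falls back to copying the data, so it never needs an explicit contiguous() call.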
    
  • 2021-01-18 11:17

    If the tensor does not require grad, you can call resize_ directly; try something like:

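    import torch
    # x does not require grad, so the in-place resize_ is allowed
    x = torch.tensor([[1, 2], [3, 4], [5, 6]])
    # shrink to 4 elements: keeps the first 4 values in memory order
    print(":::", x.resize_(2, 2))
    # grow to 9 elements: the storage is enlarged and the new elements are uninitialized
    print("::::", x.resize_(3, 3))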
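    A sketch, assuming a recent PyTorch release, that contrasts resize_ with reshape on the same 6-element tensor: reshape and view must keep the number of elements, whereas resize_ can change it in place (and, as here, is only allowed on a tensor that does not require grad).

```python
import torch

x = torch.tensor([[1, 2], [3, 4], [5, 6]])  # 6 elements, no grad required

# reshape/view must preserve the element count, so 6 -> 9 is rejected.
try:
    x.reshape(3, 3)
except RuntimeError as err:
    print("reshape failed:", err)

# resize_ may change the element count in place; shrinking keeps the
# leading elements in memory order.
x.resize_(2, 2)
print(x.tolist())  # [[1, 2], [3, 4]]

# Growing succeeds too, but the added elements hold undefined values.
x.resize_(3, 3)
print(x.shape)     # torch.Size([3, 3])
```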
    