PyTorch reshape tensor dimension

忘掉有多难 2021-02-03 17:56

For example, I have a 1D vector with shape (5,). I would like to reshape it into a 2D matrix of shape (1, 5).

Here is how I do it with NumPy:

>>> import num         
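For reference, the NumPy operation the question describes, turning a shape-(5,) array into shape (1, 5), can be sketched as below alongside its PyTorch counterpart. This is an illustrative sketch, not the asker's original snippet; the variable names are arbitrary.

```python
import numpy as np
import torch

# NumPy: a 1D array of shape (5,)
v = np.arange(5)
m = v.reshape(1, 5)             # explicit target shape
m_inferred = v.reshape(1, -1)   # -1 asks NumPy to infer the remaining size

# PyTorch: the same reshape on a tensor
t = torch.arange(5)
tm = t.reshape(1, 5)

print(m.shape, m_inferred.shape, tm.shape)
```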


        
10 answers
  • 2021-02-03 18:04

    torch.reshape() was designed to mirror NumPy's reshape method.

    It was added after Tensor.view() and Tensor.resize_(), and it is also available as a function in the torch namespace (it appears in dir(torch)).

    import torch
    x=torch.arange(24)
    print(x, x.shape)
    x_view = x.view(1,2,3,4) # view() needs a compatible memory layout, e.g. an is_contiguous() tensor
    print(x_view.shape)
    x_reshaped = x.reshape(1,2,3,4) # works on any tensor; copies the data only if a view is not possible
    print(x_reshaped.shape)
    x_reshaped2 = torch.reshape(x_reshaped, (-1,)) # function form in the torch package, while view() and resize_() are not
    print(x_reshaped2.shape)
    

    Out:

    tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
            18, 19, 20, 21, 22, 23]) torch.Size([24])
    torch.Size([1, 2, 3, 4])
    torch.Size([1, 2, 3, 4])
    torch.Size([24])
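    To make the comments on view() and reshape() concrete, here is a small sketch (variable names are illustrative) using a transposed tensor, which is not contiguous: view() raises a RuntimeError because the requested shape is incompatible with the tensor's strides, while reshape() falls back to copying. On a contiguous tensor, reshape() returns a view that shares memory with the input.

```python
import torch

base = torch.arange(6).reshape(2, 3)  # contiguous: [[0, 1, 2], [3, 4, 5]]
t = base.t()                          # transpose: shape (3, 2), not contiguous

try:
    t.view(6)
    view_ok = True
except RuntimeError:
    view_ok = False                   # strides are incompatible with a flat view

flat = t.reshape(6)                     # reshape copies when a view is impossible
flat_via_view = t.contiguous().view(6)  # explicit copy first, then view

# On a contiguous input, reshape returns a view over the same memory
same_memory = base.reshape(6).data_ptr() == base.data_ptr()

print(t.is_contiguous(), view_ok, flat.tolist(), same_memory)
```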
    

    reshape() can also serve as a replacement for squeeze() and unsqueeze():

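    x = torch.tensor([1, 2, 3, 4])
    print(x.shape)
    x1 = torch.unsqueeze(x, 0)   # add a size-1 dim at position 0
    print(x1.shape)
    x2 = torch.unsqueeze(x1, 1)  # add another size-1 dim at position 1
    print(x2.shape)
    x3 = x.reshape(1,1,4)        # same shape as x2, in a single reshape
    print(x3.shape)
    x4 = x3.reshape(4)           # same shape as x3.squeeze()
    print(x4.shape)
    x5 = x3.squeeze()            # remove all size-1 dims
    print(x5.shape)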
    
    

    Out:

    torch.Size([4])
    torch.Size([1, 4])
    torch.Size([1, 1, 4])
    torch.Size([1, 1, 4])
    torch.Size([4])
    torch.Size([4])
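    One difference to keep in mind when swapping reshape() for squeeze(): squeeze() without an argument removes every size-1 dimension, squeeze(dim) removes only the given one, and reshape() needs the target shape spelled out (with at most one -1). A minimal sketch; the (1, 4, 1) shape is just an example:

```python
import torch

x = torch.zeros(1, 4, 1)

all_squeezed = x.squeeze()     # removes every size-1 dim
first_squeezed = x.squeeze(0)  # removes only dim 0
reshaped = x.reshape(4, 1)     # reshape spells out the target shape

print(all_squeezed.shape, first_squeezed.shape, reshaped.shape)
```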
    
  • 2021-02-03 18:07
    import torch
    t = torch.ones((2, 3, 4))
    t.size()
    
    >>torch.Size([2, 3, 4])
    
    # -1 lets PyTorch infer this dimension: 24 elements / (3 * 4) = 2
    a = t.view(-1,t.size()[1]*t.size()[2])
    a.size()
    
    >>torch.Size([2, 12])
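    The -1 tells view() to infer that dimension from the total number of elements. The sketch below (assuming the same (2, 3, 4) tensor) shows two other ways to merge the last two dimensions, and that at most one dimension may be -1:

```python
import torch

t = torch.ones((2, 3, 4))

a = t.view(-1, t.size(1) * t.size(2))  # first dim inferred as 24 / 12 = 2
b = t.reshape(t.size(0), -1)           # infer the last dim instead
c = t.flatten(start_dim=1)             # merge every dim from 1 onward

try:
    t.view(-1, -1)                     # two unknowns cannot be inferred
    two_unknowns_ok = True
except RuntimeError:
    two_unknowns_ok = False

print(a.shape, b.shape, c.shape, two_unknowns_ok)
```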
    
  • 2021-02-03 18:15

    Use torch.unsqueeze(input, dim):

    >>> import torch
    >>> a = torch.Tensor([1,2,3,4,5])
    >>> a
    tensor([1., 2., 3., 4., 5.])
    >>> a = a.unsqueeze(0)
    >>> a
    tensor([[1., 2., 3., 4., 5.]])
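    The dim argument decides where the new size-1 dimension goes; for a 1D tensor the valid values are -2 through 1. A short sketch (names are illustrative) that also shows the function form and the method form are interchangeable:

```python
import torch

a = torch.Tensor([1, 2, 3, 4, 5])

row = torch.unsqueeze(a, 0)   # function form -> shape (1, 5)
row_m = a.unsqueeze(0)        # method form, same result
col = a.unsqueeze(1)          # shape (5, 1)
col_neg = a.unsqueeze(-1)     # negative dims count from the end, also (5, 1)

try:
    a.unsqueeze(2)            # out of range for a 1D tensor
    out_of_range_ok = True
except (IndexError, RuntimeError):
    out_of_range_ok = False
```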
    
  •

    >>> import torch
    >>> a = torch.Tensor([1,2,3,4,5])
    >>> a.size()
    torch.Size([5])
    >>> # use view to reshape
    >>> b = a.view(1,a.shape[0])
    >>> b
    tensor([[1., 2., 3., 4., 5.]])
    >>> b.size()
    torch.Size([1, 5])
    >>> b.type()
    'torch.FloatTensor'
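    view(), like reshape(), only works when the new shape holds the same number of elements as the original, and -1 can stand in for one dimension. A minimal sketch with the same five-element tensor:

```python
import torch

a = torch.Tensor([1, 2, 3, 4, 5])

b = a.view(1, -1)          # same as a.view(1, a.shape[0])

try:
    a.view(1, 4)           # 4 elements requested, 5 available
    mismatch_ok = True
except RuntimeError:
    mismatch_ok = False

print(b.shape, b.dtype, mismatch_ok)
```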
    
  • 2021-02-03 18:23

    For in-place modification of the shape of the tensor, you should use tensor.resize_():

    In [23]: a = torch.Tensor([1, 2, 3, 4, 5])
    
    In [24]: a.shape
    Out[24]: torch.Size([5])
    
    
    # tensor.resize_(new_shape)
    In [25]: a.resize_((1,5))
    Out[25]: tensor([[1., 2., 3., 4., 5.]])
    
    In [26]: a.shape
    Out[26]: torch.Size([1, 5])
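    To contrast the two approaches, the sketch below (variable names are illustrative) shows that reshape() returns a new tensor and leaves the original shape alone, while resize_() changes the shape of the tensor it is called on. Because the element count is unchanged here, the values are kept.

```python
import torch

a = torch.Tensor([1, 2, 3, 4, 5])

r = a.reshape(1, 5)           # new tensor object; a keeps shape (5,)
shape_after_reshape = a.shape

a.resize_((1, 5))             # modifies a itself
shape_after_resize = a.shape

print(shape_after_reshape, shape_after_resize, a.tolist())
```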
    

    In PyTorch, a trailing underscore in an operation name (as in tensor.resize_()) marks it as an in-place operation: it modifies the original tensor instead of returning a new one.
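    The same naming convention applies across the Tensor API; for example, unsqueeze_() and squeeze_() are the in-place counterparts of unsqueeze() and squeeze(). A short sketch contrasting them:

```python
import torch

x = torch.tensor([1, 2, 3])

y = x.unsqueeze(0)     # out-of-place: returns a new (1, 3) tensor, x unchanged
shape_before = x.shape

x.unsqueeze_(0)        # in-place: x itself becomes (1, 3)
shape_after = x.shape

x.squeeze_(0)          # in-place: back to (3,)
```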


    You can also index a torch tensor with np.newaxis (which is simply an alias for None) to add a dimension. Here is an example:

    In [33]: import numpy as np
    In [34]: list_ = range(5)
    In [35]: a = torch.Tensor(list_)
    In [36]: a.shape
    Out[36]: torch.Size([5])
    
    In [37]: new_a = a[np.newaxis, :]
    In [38]: new_a.shape
    Out[38]: torch.Size([1, 5])
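    Since np.newaxis is None, plain None works just as well, so NumPy is not actually required for this; it is imported below only to compare the two. A short sketch (the tensor values are arbitrary):

```python
import numpy as np
import torch

a = torch.arange(5.0)

row = a[None, :]        # shape (1, 5), same as a[np.newaxis, :]
col = a[:, None]        # shape (5, 1)
lead = a[None]          # trailing ':' may be omitted -> (1, 5)

print(np.newaxis is None, row.shape, col.shape, lead.shape)
```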
    
  • 2021-02-03 18:23

    Assume the following code:

    import torch
    import numpy as np
    a = torch.tensor([1, 2, 3, 4, 5])
    

    The following three calls have the exact same effect:

    res_1 = a.unsqueeze(0)
    res_2 = a.view(1, 5)
    res_3 = a[np.newaxis,:]
    res_1.shape == res_2.shape == res_3.shape == (1,5)  # Returns true
    

    Note that none of these calls copies the data: each result is a view that shares memory with a, so modifying the data in any of them also modifies a.

    res_1[0,0] = 2
    a[0] == res_1[0,0] == 2  # Returns true
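    If you need a result that does not share memory with a, call clone() on it. A minimal sketch (names are illustrative; data_ptr() is used only to show whether memory is shared):

```python
import torch

a = torch.tensor([1, 2, 3, 4, 5])

shared = a.view(1, 5)                # view: same memory as a
independent = a.view(1, 5).clone()   # copy: its own memory

shared[0, 0] = 10                    # also changes a[0]
independent[0, 1] = 20               # does not touch a

print(a.tolist(), shared.data_ptr() == a.data_ptr(), independent.data_ptr() == a.data_ptr())
```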
    

    Alternatively, you can use the in-place resize_() operation:

    a.shape == res_1.shape  # False
    a.resize_((1, 5))
    a.shape == res_1.shape  # True
    

    Be careful when using resize_() or other in-place operations with autograd. See the autograd notes on in-place operations: https://pytorch.org/docs/stable/notes/autograd.html#in-place-operations-with-autograd
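    As a rough illustration of that caution (names are illustrative): an out-of-place reshape() is tracked by autograd and gradients flow back to the original tensor, whereas an in-place operation on a leaf tensor that requires grad raises a RuntimeError.

```python
import torch

w = torch.ones(3, requires_grad=True)  # leaf tensor tracked by autograd

y = w.reshape(1, 3)      # out-of-place, differentiable
y.sum().backward()       # gradient flows back through the reshape

try:
    w.add_(1)            # in-place op on a leaf that requires grad
    inplace_ok = True
except RuntimeError:
    inplace_ok = False
```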
