PyTorch reshape tensor dimension

忘掉有多难 2021-02-03 17:56

For example, I have a 1D vector with shape (5,). I would like to reshape it into a 2D matrix with shape (1, 5).

Here is how I do it with NumPy:

>>> import num         


        
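The NumPy snippet above is cut off. The following is a minimal sketch of the reshape the question describes. It assumes the vector holds the values 0 to 4 (the original data is not shown) and shows three equivalent ways to get shape (1, 5).

```python
import numpy as np

# Assumed example data: a 1D array with shape (5,)
a = np.arange(5)
print(a.shape)  # (5,)

# Explicit target shape: a 2D row matrix
row = a.reshape(1, 5)
print(row.shape)  # (1, 5)

# -1 asks NumPy to infer the remaining dimension (here 5)
row_inferred = np.reshape(a, (1, -1))

# Indexing with np.newaxis inserts a new leading axis of size 1
row_newaxis = a[np.newaxis, :]

print(np.array_equal(row, row_inferred), np.array_equal(row, row_newaxis))  # True True
```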
10 answers
  •  孤街浪徒
    2021-02-03 18:04

    torch.reshape() was made to mimic NumPy's reshape method.

    It was added after the Tensor methods view() and resize_(), and it is also exposed as a function in the torch namespace (it appears in dir(torch)).

    import torch

    x = torch.arange(24)
    print(x, x.shape)
    x_view = x.view(1, 2, 3, 4)  # view() needs a memory layout compatible with the new shape (e.g. a contiguous tensor)
    print(x_view.shape)
    x_reshaped = x.reshape(1, 2, 3, 4)  # reshape() works on any tensor: a view when possible, otherwise a copy
    print(x_reshaped.shape)
    x_reshaped2 = torch.reshape(x_reshaped, (-1,))  # function form; view() and resize_() are used as Tensor methods
    print(x_reshaped2.shape)
    

    Out:

    tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
            18, 19, 20, 21, 22, 23]) torch.Size([24])
    torch.Size([1, 2, 3, 4])
    torch.Size([1, 2, 3, 4])
    torch.Size([24])
    
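    The comments in the code above separate view() from reshape(). A minimal sketch of where that difference shows up, using a transposed (and therefore non-contiguous) tensor as an assumed example; data_ptr() is used here only to check whether storage is shared.

```python
import torch

x = torch.arange(6).reshape(2, 3)  # contiguous, shape (2, 3)
t = x.t()                          # transpose: shape (3, 2), not contiguous
print(t.is_contiguous())           # False

# view() cannot flatten t because its strides do not match a flat layout
try:
    t.view(6)
    view_ok = True
except RuntimeError:
    view_ok = False
print(view_ok)                     # False

# reshape() falls back to copying the data when a view is impossible
flat = t.reshape(6)
print(flat)                        # tensor([0, 3, 1, 4, 2, 5])
print(flat.data_ptr() == t.data_ptr())  # False: the copy has new storage

# Making the tensor contiguous first lets view() succeed
flat_via_view = t.contiguous().view(6)

# On a contiguous tensor, reshape() returns a view sharing x's storage
print(x.reshape(3, 2).data_ptr() == x.data_ptr())  # True
```

    In practice this is why reshape() is the safer default: it only pays for a copy when the layout requires one.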

    It can also serve as a replacement for squeeze() and unsqueeze():

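    import torch

    x = torch.tensor([1, 2, 3, 4])
    print(x.shape)
    x1 = torch.unsqueeze(x, 0)  # insert a size-1 dim at position 0 -> (1, 4)
    print(x1.shape)
    x2 = torch.unsqueeze(x1, 1)  # insert a size-1 dim at position 1 -> (1, 1, 4)
    print(x2.shape)
    x3 = x.reshape(1, 1, 4)  # same shape as the two unsqueeze() calls
    print(x3.shape)
    x4 = x3.reshape(4)  # back to 1D, like squeeze()
    print(x4.shape)
    x5 = x3.squeeze()  # drops every size-1 dim -> (4,)
    print(x5.shape)
    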

    Out:

    torch.Size([4])
    torch.Size([1, 4])
    torch.Size([1, 1, 4])
    torch.Size([1, 1, 4])
    torch.Size([4])
    torch.Size([4])
    
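    Applied to the shape change in the question, a minimal sketch (assuming a tensor holding 0 to 4) of the equivalent ways to turn shape (5,) into (1, 5), plus one caveat: squeeze() without an argument removes every size-1 dimension, so it is not always the exact inverse of a single unsqueeze().

```python
import torch

v = torch.arange(5)                # 1D tensor, shape (5,)

row = v.reshape(1, 5)              # explicit target shape
row_inferred = v.reshape(1, -1)    # -1: PyTorch infers the remaining size (5)
row_unsqueezed = v.unsqueeze(0)    # add a leading dimension of size 1
col = v.reshape(-1, 1)             # column vector instead, shape (5, 1)

print(row.shape, col.shape)        # torch.Size([1, 5]) torch.Size([5, 1])
print(torch.equal(row, row_inferred) and torch.equal(row, row_unsqueezed))  # True

# squeeze() with no argument drops every size-1 dimension,
# while squeeze(dim) only drops the one you name.
t = torch.zeros(1, 5, 1)
print(t.squeeze().shape)           # torch.Size([5])
print(t.squeeze(0).shape)          # torch.Size([5, 1])
```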
