Add extra dimension to an axis

Submitted by 一个人想着一个人 on 2020-01-16 09:27:29

Question


I have a batch of segmentation masks of shape [5,1,100,100] (batch_size x channels x height x width) that I have to display in tensorboardX alongside an RGB image batch of shape [5,3,100,100]. I want to grow the second axis (the channel axis) of the masks from 1 to 3, i.e. add two dummy channels, so that the masks become [5,3,100,100] and I don't get a dimension mismatch error when I pass them to torchvision.utils.make_grid. I have tried unsqueeze, expand and view, but I have not been able to make it work. Any suggestions?
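The problem can be reproduced in a few lines. This is a sketch under assumptions: `images` and `masks` are hypothetical random stand-ins for the two batches, and `torch.cat` along the batch dimension is used in place of `torchvision.utils.make_grid` so the example runs without torchvision (assuming both batches are meant to end up in one grid). It also shows why `view` and `unsqueeze` do not help here.

```python
import torch

# Hypothetical stand-ins for the two batches in the question.
images = torch.rand(5, 3, 100, 100)                    # RGB batch
masks = torch.randint(0, 2, (5, 1, 100, 100)).float()  # binary masks as float

# view cannot help: it only reshapes, and 5*1*100*100 != 5*3*100*100 elements.
try:
    masks.view(5, 3, 100, 100)
    view_raised = False
except RuntimeError:
    view_raised = True

# unsqueeze adds a brand-new axis instead of growing the channel axis.
unsqueezed_shape = masks.unsqueeze(1).shape  # torch.Size([5, 1, 1, 100, 100])

# Joining the batches along dim 0 fails while channel counts differ (3 vs 1).
try:
    torch.cat([images, masks], dim=0)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

# Broadcasting the single channel to 3 channels removes the mismatch.
masks_rgb = masks.expand(-1, 3, -1, -1)
grid_input = torch.cat([images, masks_rgb], dim=0)
print(grid_input.shape)  # torch.Size([10, 3, 100, 100])
```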


Answer 1:


You can use expand, repeat, or repeat_interleave:

import torch

x = torch.randn((5, 1, 100, 100))
x1_3channels = x.expand(-1, 3, -1, -1)
x2_3channels = x.repeat(1, 3, 1, 1)
x3_3channels = x.repeat_interleave(3, dim=1)

print(x1_3channels.shape)  # torch.Size([5, 3, 100, 100])
print(x2_3channels.shape)  # torch.Size([5, 3, 100, 100])
print(x3_3channels.shape)  # torch.Size([5, 3, 100, 100])
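For a single-channel input the three calls produce identical values. They differ once the axis has more than one channel: `repeat` tiles the whole block, while `repeat_interleave` repeats each channel in place (and `expand` only works on axes of size 1). A small sketch with a made-up 2-channel tensor:

```python
import torch

# Single-channel input: all three methods give identical values.
x = torch.randn(2, 1, 4, 4)
a = x.expand(-1, 3, -1, -1)
b = x.repeat(1, 3, 1, 1)
c = x.repeat_interleave(3, dim=1)
same = torch.equal(a, b) and torch.equal(b, c)
print(same)  # True

# Two-channel input: channel 0 is all 0s, channel 1 is all 1s. Shape [1, 2, 2, 2].
y = torch.stack([torch.zeros(2, 2), torch.ones(2, 2)]).unsqueeze(0)

tiled = y.repeat(1, 2, 1, 1)                 # whole block tiled: 0, 1, 0, 1
interleaved = y.repeat_interleave(2, dim=1)  # each channel repeated: 0, 0, 1, 1

tiled_order = [int(ch.mean()) for ch in tiled[0]]
interleaved_order = [int(ch.mean()) for ch in interleaved[0]]
print(tiled_order)        # [0, 1, 0, 1]
print(interleaved_order)  # [0, 0, 1, 1]
```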

Note that, as stated in the docs:

  • torch.Tensor.expand():

Expanding a tensor does not allocate new memory, but only creates a new view on the existing tensor where a dimension of size one is expanded to a larger size by setting the stride to 0. Any dimension of size 1 can be expanded to an arbitrary value without allocating new memory.

  • torch.Tensor.repeat():

Unlike expand(), this function copies the tensor’s data.
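The practical consequence of that difference can be checked directly. In this sketch (small zero tensors chosen only for illustration), the expanded tensor shares memory with the original and has stride 0 on the channel axis, while the repeated one is an independent copy. Because every expanded channel aliases the same memory, PyTorch rejects in-place writes into the expanded view; `clone()` or `repeat()` it first if you need to modify the 3-channel mask.

```python
import torch

x = torch.zeros(2, 1, 3, 3)
expanded = x.expand(-1, 3, -1, -1)  # view: shares x's memory
repeated = x.repeat(1, 3, 1, 1)     # copy: owns new memory

print(expanded.stride())                    # (9, 0, 3, 1): stride 0 on the channel axis
print(expanded.data_ptr() == x.data_ptr())  # True
print(repeated.data_ptr() == x.data_ptr())  # False

# Writing into the original shows up in the view but not in the copy.
x += 1
print(expanded.max().item(), repeated.max().item())  # 1.0 0.0

# In-place writes into the expanded view itself are rejected, because several
# of its elements refer to one memory location.
try:
    expanded.add_(1)
    inplace_allowed = True
except RuntimeError:
    inplace_allowed = False
print(inplace_allowed)  # False
```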




Answer 2:


Expand is a method where I keep reminding myself not to read the docs too literally. They say:

Expanding a tensor does not allocate new memory, but only creates a new view on the existing tensor

I read that with care: no separate "view" object is created (at least I never saw one). The result is just a tensor that shares the original's storage; the only things that differ are the shape and the stride.

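And expand can also prepend new leading dimensions. What it cannot do is shrink: only dimensions of size 1 can be changed, so a dimension of size 2 must stay 2 (or be passed as -1).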
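A quick check of what expand accepts and rejects, using a throwaway [2, 1] tensor. Growing a size-1 dimension and prepending dimensions work; asking a size-2 dimension to become 1 raises a RuntimeError. Ordinary slicing is one way to actually take a smaller piece.

```python
import torch

t = torch.rand(2, 1)

grown = t.expand(2, 4)          # size-1 dim grown to 4
prepended = t.expand(3, -1, 1)  # new leading dim of size 3; -1 keeps the 2

print(grown.shape)      # torch.Size([2, 4])
print(prepended.shape)  # torch.Size([3, 2, 1])

# Shrinking the non-singleton dim (2 -> 1) is an error.
try:
    t.expand(1, 1)
    shrink_allowed = True
except RuntimeError:
    shrink_allowed = False
print(shrink_allowed)  # False

# To actually get a smaller tensor, slice instead, e.g. keep only the first row.
first_row = t[:1]
print(first_row.shape)  # torch.Size([1, 1])
```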

import torch

t = torch.rand(2, 1)
print(t)
print(t.shape)
print(t.stride())

t25 = t.expand(-1, 5)  # grow the size-1 dim to 5
print(t25.shape)
print(t25)
print(t25.stride())

t123 = t.expand(1, -1, 3)  # prepend a new dim of size 1, keep 2, grow 1 -> 3
print(t123.shape)
print(t123)
print(t123.stride())

# tensor([[0.1353],
#         [0.5809]])
# torch.Size([2, 1])
# (1, 1)
# torch.Size([2, 5])
# tensor([[0.1353, 0.1353, 0.1353, 0.1353, 0.1353],
#         [0.5809, 0.5809, 0.5809, 0.5809, 0.5809]])
# (1, 0)
# torch.Size([1, 2, 3])
# tensor([[[0.1353, 0.1353, 0.1353],
#          [0.5809, 0.5809, 0.5809]]])
# (2, 1, 0)
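To back up the point that only shape and stride change, the expanded tensors can be rebuilt from the original storage with `torch.as_strided`, a low-level PyTorch function that takes an explicit size and stride. Element [i, j] of a strided tensor lives at storage index `storage_offset + i*stride[0] + j*stride[1]`, so with stride (1, 0) every column reads the same stored number. The strides below are the ones printed in the output above; the random `t` is just for illustration.

```python
import torch

t = torch.rand(2, 1)

# Same shape and stride that expand produced, but built by hand over t's storage.
manual_25 = torch.as_strided(t, size=(2, 5), stride=(1, 0))
manual_123 = torch.as_strided(t, size=(1, 2, 3), stride=(2, 1, 0))

print(torch.equal(manual_25, t.expand(-1, 5)))      # True
print(torch.equal(manual_123, t.expand(1, -1, 3)))  # True
print(manual_25.data_ptr() == t.data_ptr())         # True: nothing was copied
```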


Source: https://stackoverflow.com/questions/56952598/add-extra-dimension-to-an-axes
