Question
I have a batch of segmentation masks of shape [5,1,100,100] (batch_size x dims x ht x wd), which I have to display in tensorboardX alongside an RGB image batch of shape [5,3,100,100]. I want to add two dummy channels along the second axis of the segmentation mask to make it [5,3,100,100], so that there is no dimension mismatch error when I pass it to torchvision.utils.make_grid. I have tried unsqueeze, expand and view, but I am not able to do it. Any suggestions?
Answer 1:
You can use expand, repeat, or repeat_interleave:
import torch
x = torch.randn((5, 1, 100, 100))
x1_3channels = x.expand(-1, 3, -1, -1)
x2_3channels = x.repeat(1, 3, 1, 1)
x3_3channels = x.repeat_interleave(3, dim=1)
print(x1_3channels.shape) # torch.Size([5, 3, 100, 100])
print(x2_3channels.shape) # torch.Size([5, 3, 100, 100])
print(x3_3channels.shape) # torch.Size([5, 3, 100, 100])
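To tie this back to the question, here is a minimal sketch of building one batch that holds both the RGB images and the masks. The names images, masks, masks_rgb and combined, and the random stand-in data, are assumptions made for illustration; they are not from the original post.

```python
import torch

# Stand-in data with the shapes from the question (random values, illustration only).
images = torch.rand(5, 3, 100, 100)                     # RGB batch
masks = torch.randint(0, 2, (5, 1, 100, 100)).float()   # single-channel binary masks

# Broadcast the single mask channel to three channels without copying data.
masks_rgb = masks.expand(-1, 3, -1, -1)

# Stack images and masks along the batch axis; torch.cat writes into a new tensor.
combined = torch.cat([images, masks_rgb], dim=0)

print(masks_rgb.shape)
print(combined.shape)
```

The combined batch has three channels throughout, so it should be usable as input to torchvision.utils.make_grid and, from there, tensorboardX.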
Note that, as stated in the docs:
- torch.Tensor.expand():
Expanding a tensor does not allocate new memory, but only creates a new view on the existing tensor where a dimension of size one is expanded to a larger size by setting the stride to 0. Any dimension of size 1 can be expanded to an arbitrary value without allocating new memory.
- torch.Tensor.repeat():
Unlike expand(), this function copies the tensor’s data.
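The difference between the two can be checked directly by comparing data pointers and strides. This is a small sketch rather than a benchmark; the variable names expanded, repeated and writable are made up for illustration.

```python
import torch

x = torch.randn(5, 1, 100, 100)
expanded = x.expand(-1, 3, -1, -1)   # same storage as x, no copy
repeated = x.repeat(1, 3, 1, 1)      # freshly allocated copy

# expand: same starting address, and a stride of 0 along the channel axis.
print(expanded.data_ptr() == x.data_ptr())
print(expanded.stride())

# repeat: different memory, ordinary contiguous strides.
print(repeated.data_ptr() == x.data_ptr())
print(repeated.stride())

# The three expanded channels alias the same memory, so clone before
# writing in place if the channels should become independent.
writable = expanded.clone()
writable[:, 0] = 0.0
print(torch.equal(x, expanded[:, 1:2]))  # x itself is left untouched
```

If the mask is only read (for example, when it is passed on for display), expand is the cheaper choice; if the channels will be modified separately, repeat or expand followed by clone avoids aliasing.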
Answer 2:
expand is a method whose documentation I keep telling myself not to take literally, where it reads:
Expanding a tensor does not allocate new memory, but only creates a new view on the existing tensor
PyTorch has no separate view type, at least I have never seen views as objects, so nothing of that kind is created. The result is an ordinary tensor on the same storage; the only thing that changes is the stride.
expand can also prepend new leading dimensions (it cannot shrink an existing one).
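import torch
t = torch.rand(2, 1)
print(t)
print(t.shape)
print(t.stride())
# Expand the size-1 dimension to 5; -1 keeps the existing size.
t25 = t.expand(-1, 5)
print(t25.shape)
print(t25)
print(t25.stride())
# Prepend a new leading dimension of size 1 and expand the last one to 3.
t123 = t.expand(1, -1, 3)
print(t123.shape)
print(t123)
print(t123.stride())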
# Example output (the random values will differ between runs):
# tensor([[0.1353],
#         [0.5809]])
# torch.Size([2, 1])
# (1, 1)
# torch.Size([2, 5])
# tensor([[0.1353, 0.1353, 0.1353, 0.1353, 0.1353],
#         [0.5809, 0.5809, 0.5809, 0.5809, 0.5809]])
# (1, 0)
# torch.Size([1, 2, 3])
# tensor([[[0.1353, 0.1353, 0.1353],
#          [0.5809, 0.5809, 0.5809]]])
# (2, 1, 0)
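As a quick check on which target shapes expand accepts, the sketch below tries a new leading dimension and then a request that would change a dimension of size 2. The names lead and resize_rejected are made up for this example.

```python
import torch

t = torch.rand(2, 1)

# Extra leading dimensions are allowed and can have any size.
lead = t.expand(4, 2, 3)
print(lead.shape)
print(lead.stride())  # expanded dimensions get stride 0

# A dimension whose size is not 1 cannot be changed, so this call raises.
try:
    t.expand(1, 3)
    resize_rejected = False
except RuntimeError as err:
    resize_rejected = True
    print("expand refused:", err)

print(resize_rejected)
```

Only dimensions of size 1 (and newly prepended ones) can take a different size; every other dimension has to match the existing size or be given as -1.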
Source: https://stackoverflow.com/questions/56952598/add-extra-dimension-to-an-axes