tensor.squeeze(dim)
作用: 如果dim指定的维度的值为1,则将该维度删除,若指定的维度值不为1,则返回原来的tensor
例子:
x = torch.rand(2,1,3)
print(x)
print(x.squeeze(1))
print(x.squeeze(2))
输出:
tensor([[[0.7031, 0.7173, 0.0606]],
[[0.6884, 0.4072, 0.0516]]])
tensor([[0.7031, 0.7173, 0.0606],
[0.6884, 0.4072, 0.0516]])
tensor([[[0.7031, 0.7173, 0.0606]],
[[0.6884, 0.4072, 0.0516]]])
如上结果所示:x.shape=[2, 1, 3] , 第一维度的值为1, 因此x.squeeze(dim=1)的输出会将第一维度去掉,其输出shape=[2,3], 第二维度值不为1, 因此x.squeeze(dim=2)输出tensor的shape不变
tensor.unsqueeze(dim)
这个函数主要是对数据维度进行扩充。给指定位置加上维数为1的维度,比如原本有个三行的数据(3,),在0的位置加了一维就变成一行三列(1,3)。还有一种形式就是b=torch.squeeze(tensor,dim) 就是在tensor中指定位置 dim 加上一个维数为1的维度
例子:
x = torch.rand(2,3)
print(x)
print("x.shape:", x.shape)
y = torch.unsqueeze(x, 1)
print(y)
print("y.shape:", y.shape)
z = x.unsqueeze(2)
print(z)
print("z.shape:", z.shape)
输出:
tensor([[0.1255, 0.7249, 0.5253],
[0.9247, 0.4592, 0.3944]])
x.shape: torch.Size([2, 3])
tensor([[[0.1255, 0.7249, 0.5253]],
[[0.9247, 0.4592, 0.3944]]])
y.shape: torch.Size([2, 1, 3])
tensor([[[0.1255],
[0.7249],
[0.5253]],
[[0.9247],
[0.4592],
[0.3944]]])
z.shape: torch.Size([2, 3, 1])
[Finished in 2.6s]
来源:CSDN
作者:orangerfun
链接:https://blog.csdn.net/orangerfun/article/details/104012564