直接从代码中学习tensor的一些维度变换操作:
import torch
torch.manual_seed(2020)
x = torch.rand(1, 2, 3)
print(x)
# tensor([[[0.4869, 0.1052, 0.5883],
# [0.1161, 0.4949, 0.2824]]])
print(x.view(-1, 3).size()) # torch.Size([2, 3])
print('\ntranspose:')
print(torch.transpose(x, 0, 1))
print(x.transpose(0, 1).size()) # torch.Size([2, 1, 3])
print(x.transpose(1, 2).size()) # torch.Size([1, 3, 2])
# transpose要指明待交换的维度
print('\ntorch.cat:')
y = torch.rand(1, 1, 3)
print(torch.cat((x, y), dim=1).size()) # torch.Size([1, 3, 3])
# dim指定待拼接的维度;待拼接的两个向量除了待拼接的维度,其余维度必须相等或为空
print('\ntorch.chunk:')
x_chunks = torch.chunk(x, chunks=2, dim=1) # x_chunks是一个tuple
print(x_chunks)
# (tensor([[[0.4869, 0.1052, 0.5883]]]), tensor([[[0.1161, 0.4949, 0.2824]]]))
print(x_chunks[0].size(), x_chunks[1].size())
# torch.Size([1, 1, 3]) torch.Size([1, 1, 3])
print(torch.chunk(x, 2, 2)) # 不能整除时,最后一个chunk较小
# (tensor([[[0.4869, 0.1052], [0.1161, 0.4949]]]),
# tensor([[[0.5883], [0.2824]]]))
print(torch.chunk(x, 4, 2)) # chunks大于tensor在维度dim上的值时,每个chunk均为1
# (tensor([[[0.4869], [0.1161]]]),
# tensor([[[0.1052], [0.4949]]]),
# tensor([[[0.5883], [0.2824]]]))
# torch.chunk将tensor在dim维度上划分为chunks块;
print('\ntorch.split:')
z = torch.rand(4, 6, 8)
z_split = torch.split(z, split_size_or_sections=2, dim=1)
print([z_split[i].size() for i in range(len(z_split))])
# [torch.Size([4, 2, 8]), torch.Size([4, 2, 8]), torch.Size([4, 2, 8])]
z_split = torch.split(z, split_size_or_sections=4, dim=1)
print([z_split[i].size() for i in range(len(z_split))])
# [torch.Size([4, 4, 8]), torch.Size([4, 2, 8])]
z_split = torch.split(z, split_size_or_sections=[3, 3], dim=1)
print([z_split[i].size() for i in range(len(z_split))])
# [torch.Size([4, 3, 8]), torch.Size([4, 3, 8])]
# torch.split也是将tensor在指定的维度上分成若干块,不同于torch.chunk的是:
# torch.chunk指定分成几个chunk;torch.split指定每个chunk的大小
# torch.chunk和torch.split可以看作是torch.cat的反面
print('\ntorch.stack:')
a = torch.rand(2, 3, 4)
b = torch.rand(2, 3, 4)
print(torch.stack((a, b), dim=0).size()) # torch.Size([2, 2, 3, 4])
c = torch.rand(2, 3, 4)
print(torch.stack((a, b, c), dim=0).size()) # torch.Size([3, 2, 3, 4])
# torch.stack与torch.cat的区别:前者在新的维度上拼接;后者在已有的维度上拼接
print('\ntorch.squeeze:')
d = torch.rand(1, 2, 3, 1)
print(torch.squeeze(d).size()) # torch.Size([2, 3])
print(torch.squeeze(d, dim=3).size()) # torch.Size([1, 2, 3])
# torch.squeeze去掉大小为1的维度;dim默认为None,去掉所有大小为1的维度;
# 指定dim时,只去掉指定的大小为1的维度;若指定的维度大小不为1,则不起作用
print(torch.unsqueeze(d, dim=0).size()) # torch.Size([1, 1, 2, 3, 1])
print(torch.unsqueeze(d, dim=-1).size()) # torch.Size([1, 2, 3, 1, 1])
# torch.unsqueeze在指定位置增加一个大小为1的维度
来源:CSDN
作者:Stoneplay26
链接:https://blog.csdn.net/qq_28753373/article/details/104216085