张量的操作:拼接、切分、索引和变换
1 张量的操作:拼接、切分、索引和变换
一 张量的拼接和切分
1.1 torch.cat()
功能:将张量按维度dim进行拼接
tensor:张量序列
dim:拼接维度
t=torch.ones((2,3))
torch.cat([t,t],dim=0)
torch.cat([t,t],dim=1)
torch.cat([t,t,t],dim=1)
1.2 torch.stack()
功能:在新创建的维度dim上进行拼接
tensor:张量序列
dim:要拼接的维度
与cat相比,stack创建在了一个新维度
1.3 torch.chunk()
功能:将张量按维度dim进行平均切分
返回值:张量列表
注意:若不能整除,最后一份张量小于其他张量
input:要切分的张量
chunks:要切分的份数
dim:要切分的维度
a = torch.ones((2, 7)) # 7
list_of_tensors = torch.chunk(a, dim=1, chunks=3) # 3
for idx, t in enumerate(list_of_tensors):
print("第{}个张量:{}, shape is {}".format(idx+1, t, t.shape))
1.4 torch.split()
功能:将张量按维度dim进行切分
返回值:张量列表
split_size_or_sections:为int时,表示每一份的长度;为list时,按list元素切分
t = torch.ones((2, 5))
list_of_tensors = torch.split(t, 2], dim=1) # [2 , 1, 2]
for idx, t in enumerate(list_of_tensors):
print("第{}个张量:{}, shape is {}".format(idx+1, t, t.shape))
t = torch.ones((2, 5))
list_of_tensors = torch.split(t, [2, 1, 1], dim=1) # [2 , 1, 2]
for idx, t in enumerate(list_of_tensors):
print("第{}个张量:{}, shape is {}".format(idx+1, t, t.shape))
list内元素之和等于维度上的长度
二张量的索引
2.1 torch.index_select()
功能:在维度dim上,按index索引数据
返回值:依index索引数据拼接的张量
input:要索引的张量
dim:要索引的维度
index:要索引数据的序号
t = torch.randint(0, 9, size=(3, 3))
idx = torch.tensor([0, 2], dtype=torch.long) # float
t_select = torch.index_select(t, dim=0, index=idx)
print("t:\n{}\nt_select:\n{}".format(t, t_select))
2.2 torch.masked_select()
功能:按mmask中的True进行索引
返回值:一维张量
input:要索引的张量
mask:与input同形状的布尔类型张量
t = torch.randint(0, 9, size=(3, 3))
mask = t.le(5) # ge is mean greater than or equal/ gt: greater than le lt
t_select = torch.masked_select(t, mask)
print("t:\n{}\nmask:\n{}\nt_select:\n{} ".format(t, mask, t_select))
三 张量变换
torch.reshape()
功能:变换张量的形状
注意:当张量在内存中是连续时,新张量与input共享内存
input:要变换的张量
shape:新张量的形状
t = torch.randperm(8)
t_reshape = torch.reshape(t, (-1, 2, 2)) # -1
print("t:{}\nt_reshape:\n{}".format(t, t_reshape))
共享内存
3.2 torch.transpose().
功能:交换维度
3.3 torch.t()
功能:2维张量转置
3.4 torch.squeeze()
功能:压缩长度为1的维度(轴)
dim:若为None,移除所有长度为1的轴;若指定维度,当且仅当该轴长度为1时,课被移除
3.5 torch.unsqueeze()
功能:依据dim扩展维度
t = torch.rand((1, 2, 3, 1))
t_sq = torch.squeeze(t)
t_0 = torch.squeeze(t, dim=0)
t_1 = torch.squeeze(t, dim=1)
print(t.shape)
print(t_sq.shape)
print(t_0.shape)
print(t_1.shape)
来源:CSDN
作者:Major_s
链接:https://blog.csdn.net/qq_41375318/article/details/103780104