1. repeat_interleave(self: Tensor, repeats: _int, dim: Optional[_int]=None)
参数说明:
self: 传入的数据为tensor
repeats: 复制的份数
dim: 要复制的维度,可设定为0/1/2.....
2. 例子
2.1 Code
此处定义了一个4维tensor,要对第2个维度复制,由原来的1变为3,即将设定dim=1。
1 import torch
2
3
4 def function():
5 data1 = torch.rand([2, 1, 3, 3])
6 print("data1_shape: ", data1.shape)
7 print("data1: ", data1)
8
9 data2 = torch.repeat_interleave(data1, repeats=3, dim=1)
10 print("data2_shape: ", data2.shape)
11 print("data2: ", data2)
12
13
14 if __name__ == '__main__':
15 function()
2.2 输出显示
即可看到输入tensor形状为[2, 1, 3, 3],经过repeat后,tensor变为[2, 3, 3, 3],并在第二维度上保持相同的数据。
来源:oschina
链接:https://my.oschina.net/u/4299953/blog/4261813