import torch
where
- where(condition,a,b) 满足条件,返回a里面的对应元素,不满足条件,返回b里面对应的元素
a = torch.rand(2,3)
b = torch.rand(2,3)
torch.where(a>b,a,b)
tensor([[0.2254, 0.7619, 0.9761],
[0.7787, 0.4238, 0.8476]])
gather
-
设X的size是P·Q·M
-
Z=X.gather(dim=0,index=Y) Z是一个size和Y一样的tensor
-
对Y的要求是:Y的size=N·Q·M N是任意正整数,对构成Y的元素的取值范围是[0,P-1]
-
现在有一个tensor W,W的size是Q·M,Wqm为X1qm,X2qm…Xnqmzip在一起的list,即Wqm=(X1qm,X2qm…Xnqm)
-
- step1:按照dim的取值对X中的数据进行zip,形成W
-
- step2:形成index矩阵Y
-
- step3:依照Y,从W中的每个小zip中取值
- step3:依照Y,从W中的每个小zip中取值
例1:
test=torch.randint(0,11,[2,3,4])
test
tensor([[[ 6, 2, 10, 1],
[ 6, 0, 0, 9],
[ 2, 7, 6, 5]],
[[ 1, 4, 8, 5],
[ 0, 10, 0, 0],
[ 4, 6, 8, 9]]])
index = torch.randint(0,2,[3,3,4])
index
tensor([[[0, 0, 1, 0],
[0, 1, 1, 1],
[1, 0, 0, 1]],
[[0, 1, 0, 1],
[0, 0, 0, 0],
[0, 0, 0, 1]],
[[0, 1, 0, 0],
[1, 1, 0, 0],
[1, 0, 1, 1]]])
test.gather(dim=0,index=index)
tensor([[[ 6, 2, 8, 1],
[ 6, 10, 0, 0],
[ 4, 7, 6, 9]],
[[ 6, 4, 10, 5],
[ 6, 0, 0, 9],
[ 2, 7, 6, 9]],
[[ 6, 4, 10, 1],
[ 0, 10, 0, 9],
[ 4, 7, 8, 9]]])
例2
a = torch.randint(1,11,[2,3,4])
a
tensor([[[ 5, 10, 5, 5],
[10, 8, 4, 2],
[ 6, 5, 1, 9]],
[[ 1, 7, 3, 3],
[ 8, 5, 8, 2],
[ 7, 5, 10, 1]]])
index = torch.randint(0,3,[2,2,4])
index
tensor([[[0, 0, 0, 0],
[1, 0, 1, 1]],
[[1, 2, 2, 0],
[2, 2, 2, 0]]])
a.gather(dim=1,index=index)
#这里dim = 1,所以把(5,10,6)zip在一起,(10,8,5)zip在一起...(3,2,1)zip在一起,然后按顺序把这些小zip排成2*4的tensor
#然后拿着index在这个tensor中选
tensor([[[ 5, 10, 5, 5],
[10, 10, 4, 2]],
[[ 8, 5, 10, 3],
[ 7, 5, 10, 3]]])
例3
a = torch.randint(1,11,[2,3,3,2])
a
tensor([[[[10, 7],
[ 9, 7],
[ 9, 1]],
[[10, 5],
[ 4, 5],
[ 3, 7]],
[[10, 8],
[ 4, 10],
[ 6, 1]]],
[[[ 1, 4],
[ 3, 2],
[ 1, 7]],
[[ 5, 2],
[ 1, 4],
[ 9, 10]],
[[ 2, 5],
[ 5, 5],
[ 7, 4]]]])
index=torch.randint(0,2,[2,3,3,1])
index
tensor([[[[0],
[1],
[0]],
[[0],
[0],
[1]],
[[1],
[1],
[1]]],
[[[1],
[0],
[1]],
[[1],
[0],
[0]],
[[0],
[0],
[0]]]])
a.gather(dim=3,index=index)
#dim=3 所以把(10,7)zip在一起,(9,7)zip在一起,(9,1)zip在一起,...(7,4)zip在一起
#...
tensor([[[[10],
[ 7],
[ 9]],
[[10],
[ 4],
[ 7]],
[[ 8],
[10],
[ 1]]],
[[[ 4],
[ 3],
[ 7]],
[[ 2],
[ 1],
[ 9]],
[[ 2],
[ 5],
[ 7]]]])
来源:CSDN
作者:缦旋律
链接:https://blog.csdn.net/weixin_41391619/article/details/104699555