Question
I have a list of indices:
indx = torch.LongTensor([
    [ 0,  2,  0],
    [ 0,  2,  4],
    [ 0,  4,  0],
    [ 0, 10, 14],
    [ 1,  4,  0],
    [ 1,  8,  2],
    [ 1, 12,  0]
])
And I have a tensor of 2x2 blocks:
blocks = torch.FloatTensor([
    [[1.5818, 2.3108],
     [2.6742, 3.0024]],
    [[2.0472, 1.6651],
     [3.2807, 2.7413]],
    [[1.5587, 2.1905],
     [1.9231, 3.5083]],
    [[1.6007, 2.1426],
     [2.4802, 3.0610]],
    [[1.9087, 2.1021],
     [2.7781, 3.2282]],
    [[1.5127, 2.6322],
     [2.4233, 3.6836]],
    [[1.9645, 2.3831],
     [2.8675, 3.3770]]
])
What I want to do is to add each block at an index position to another tensor (i.e. so that it starts at that index). Let's assume that I want to add it to the following tensor:
a = torch.ones([2,18,18])
Is there any efficient way to do so? So far I have only come up with:
i = 0
for b, x, y in indx:
    a[b, x:x+2, y:y+2] += blocks[i]
    i += 1
This is quite inefficient. I also tried to use index_add, but it did not work properly.
Answer 1:
You are looking to index on three different dimensions at the same time. I had a look around in the documentation: torch.index_add will only receive a vector as index. My hopes were on torch.scatter, but it doesn't seem to fit this problem well. As it turns out, you can achieve this pretty easily with a little work; the most difficult parts are the setup and teardown. Please hang on tight.
I'll use a simplified example here, but the same can be applied with larger tensors.
>>> indx
tensor([[ 0,  2,  0],
        [ 0,  2,  4],
        [ 0,  4,  0]])
>>> blocks
tensor([[[1.5818, 2.3108],
         [2.6742, 3.0024]],
        [[2.0472, 1.6651],
         [3.2807, 2.7413]],
        [[1.5587, 2.1905],
         [1.9231, 3.5083]]])
>>> a
tensor([[[0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.]]])
The main issue here is that you are looking to index with slicing, which is not possible in vectorized form. To work around that, you can convert your a tensor into 2x2 chunks. This will be particularly handy since we will be able to access sub-tensors such as a[0, 2:4, 4:6] with just a[0, 1, 2]: the 2:4 slice on dim=1 is grouped together at index=1, while the 4:6 slice on dim=2 is grouped at index=2.
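As a quick sanity check (my own addition; run it after step 1 below has built a_chunks):
# The 2x2 block at rows 2:4, cols 4:6 is addressed by chunk index [0, 1, 2].
assert torch.equal(a[0, 2:4, 4:6], a_chunks[0, 1, 2])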
First we will convert a to a tensor made up of 2x2 chunks. Then we will update it with blocks. Finally, we will stitch the resulting tensor back into the original shape.
1. Converting a to a 2x2-chunks tensor
You can use a combination of torch.chunk and torch.cat (not torch.dog) twice: on dim=1 and dim=2. The shape of a is (1, h, w), so we're looking for a result of shape (1, h//2, w//2, 2, 2).
To do so, we will unsqueeze two axes on a:
>>> a_ = a[:, None, :, None, :]
>>> a_.shape
torch.Size([1, 1, 6, 1, 6])
Then make 3 chunks on dim=2 and concatenate on dim=1:
>>> a_row_chunks = torch.cat(torch.chunk(a_, 3, dim=2), dim=1)
>>> a_row_chunks.shape
torch.Size([1, 3, 2, 1, 6])
And make 3 chunks on dim=4 and concatenate on dim=3:
>>> a_col_chunks = torch.cat(torch.chunk(a_row_chunks, 3, dim=4), dim=3)
>>> a_col_chunks.shape
torch.Size([1, 3, 2, 3, 2])
Finally, permute the axes so that the two chunk indices come first. (A plain reshape to (1, 3, 3, 2, 2) would have the right shape but would reorder the underlying values rather than the axes; that goes unnoticed here only because a is constant.)
>>> a_chunks = a_col_chunks.permute(0, 1, 3, 2, 4)
>>> a_chunks.shape
torch.Size([1, 3, 3, 2, 2])
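As an aside (not part of the original answer), torch.Tensor.unfold can build the same (1, 3, 3, 2, 2) chunk layout in one step, as a drop-in replacement for the construction above. It returns a view of a, so the writes in step 2 land directly in a and step 3 becomes unnecessary:
# unfold(dim, size, step) slides a window of `size` with stride `step` along `dim`
# and appends the window as a trailing dimension; two calls give non-overlapping
# 2x2 tiles that share storage with `a`.
a_chunks = a.unfold(1, 2, 2).unfold(2, 2, 2)  # shape (1, 3, 3, 2, 2)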
Create a new index with adjusted values for our new tensor. Essentially, we divide all values by 2, except for the first column, which is the index of dim=0 in a and is left unchanged. There's some fiddling around with the types (in short: the values have to be floats in order to divide by 2, but need to be cast back to longs for the indexing to work):
>>> indx_ = indx.clone().float()
>>> indx_[:, 1:] /= 2
>>> indx_ = indx_.long()
>>> indx_
tensor([[0, 1, 0],
        [0, 1, 2],
        [0, 2, 0]])
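Since the offsets are non-negative integers, the float round-trip can also be avoided with floor division (a minor variation of the above; rounding_mode requires PyTorch 1.8+):
indx_ = indx.clone()
# Halve the row/column coordinates while staying in long dtype throughout.
indx_[:, 1:] = torch.div(indx_[:, 1:], 2, rounding_mode='floor')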
2. Updating with blocks
We will simply index and accumulate with:
>>> a_chunks[indx_[:, 0], indx_[:, 1], indx_[:, 2]] += blocks
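A caveat worth noting: if indx_ lists the same chunk position twice, the += above keeps only one of the duplicate blocks, because advanced-index assignment does not accumulate over repeated indices. Tensor.index_put_ with accumulate=True sums them properly (an equivalent spelling, not in the original answer):
# Same update as the += above, but duplicate (batch, row, col) positions
# in indx_ are accumulated instead of overwritten.
a_chunks.index_put_((indx_[:, 0], indx_[:, 1], indx_[:, 2]), blocks, accumulate=True)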
3. Putting it back together
I thought that was it, but actually converting a_chunks back to a 6x6 tensor is way trickier than it seems. Apparently torch.cat can only receive a tuple (or list) of tensors, not a tensor directly. I won't go into too much detail: tuple() will only iterate over the first axis, so as a workaround you can use torch.permute to switch the axes. This, combined with two torch.cat calls, will do:
>>> a_row_cat = torch.cat(tuple(a_chunks.permute(1, 0, 2, 3, 4)), dim=2)
>>> a_row_cat.shape
torch.Size([1, 3, 6, 2])
>>> A = torch.cat(tuple(a_row_cat.permute(1, 0, 2, 3)), dim=2)
>>> A.shape
torch.Size([1, 6, 6])
>>> A
tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [1.5818, 2.3108, 0.0000, 0.0000, 2.0472, 1.6651],
         [2.6742, 3.0024, 0.0000, 0.0000, 3.2807, 2.7413],
         [1.5587, 2.1905, 0.0000, 0.0000, 0.0000, 0.0000],
         [1.9231, 3.5083, 0.0000, 0.0000, 0.0000, 0.0000]]])
Et voilà.
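For what it's worth, the stitching can also be done in one line: reordering the axes back to (batch, row-chunk, row, col-chunk, col) makes a plain reshape valid (a sketch, not in the original answer):
# Invert step 1: (1, 3, 3, 2, 2) -> (1, 3, 2, 3, 2) -> (1, 6, 6).
A = a_chunks.permute(0, 1, 3, 2, 4).reshape(1, 6, 6)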
If you didn't quite get how the chunks work, run this:
for x in range(0, 6, 2):
    for y in range(0, 6, 2):
        a *= 0
        a[:, x:x+2, y:y+2] = 1
        print(a)
And see for yourself: each 2x2 block of 1s corresponds to one chunk in a_chunks.
So you can do the same with:
for x in range(3):
    for y in range(3):
        a_chunks *= 0
        a_chunks[:, x, y] = 1
        print(a_chunks)
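Putting it all together for the original tensors, here is a self-contained sketch (my own wrapper around the steps above; like the answer itself, it assumes even row/column offsets and 2x2 blocks):
import torch

def add_blocks(a, indx, blocks):
    # View `a` as (batch, h//2, w//2, 2, 2) tiles; writes go through to `a`.
    chunks = a.unfold(1, 2, 2).unfold(2, 2, 2)
    # Map (batch, row, col) coordinates to (batch, row-chunk, col-chunk).
    idx = indx.clone()
    idx[:, 1:] = torch.div(idx[:, 1:], 2, rounding_mode='floor')
    # Accumulate the blocks; duplicate positions in `indx` are summed correctly.
    chunks.index_put_((idx[:, 0], idx[:, 1], idx[:, 2]), blocks, accumulate=True)
    return a

a = torch.ones(2, 18, 18)
indx = torch.LongTensor([[0, 2, 0], [0, 2, 4], [1, 4, 0]])
blocks = torch.rand(3, 2, 2)  # sample data
add_blocks(a, indx, blocks)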
Source: https://stackoverflow.com/questions/65571114/add-blocks-of-values-to-a-tensor-at-specific-locations-in-pytorch