Add blocks of values to a tensor at specific locations in PyTorch

Submitted by 心不动则不痛 on 2021-02-10 14:42:03

Question


I have a list of indices:

indx = torch.LongTensor([
    [ 0,  2,  0],
    [ 0,  2,  4],
    [ 0,  4,  0],
    [ 0, 10, 14],
    [ 1,  4,  0],
    [ 1,  8,  2],
    [ 1, 12,  0]
])

And I have a tensor of 2x2 blocks:

blocks = torch.FloatTensor([
    [[1.5818, 2.3108],
     [2.6742, 3.0024]],

    [[2.0472, 1.6651],
     [3.2807, 2.7413]],

    [[1.5587, 2.1905],
     [1.9231, 3.5083]],

    [[1.6007, 2.1426],
     [2.4802, 3.0610]],

    [[1.9087, 2.1021],
     [2.7781, 3.2282]],

    [[1.5127, 2.6322],
     [2.4233, 3.6836]],

    [[1.9645, 2.3831],
     [2.8675, 3.3770]]
])

What I want to do is to add each block at an index position to another tensor (i.e. so that it starts at that index). Let's assume that I want to add it to the following tensor:

a = torch.ones([2,18,18])

Is there any efficient way to do so? So far I have only come up with:

for i, (b, x, y) in enumerate(indx):
    a[b, x:x+2, y:y+2] += blocks[i]

It is quite inefficient. I also tried to use index_add, but it did not work properly.


Answer 1:


You are looking to index on three different dimensions at the same time. I had a look around in the documentation: torch.index_add only accepts a 1-D index along a single dimension. My hopes were on torch.scatter, but it doesn't seem to fit this problem well either. As it turns out, you can achieve this pretty easily with a little work; the most difficult parts are the setup and the teardown. Please hang on tight.
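For comparison, another built-in (not the one this answer uses) does accept indices on several dimensions at once: Tensor.index_put_ takes a tuple of index tensors, one per dimension. The sketch below expands every 2x2 slice into explicit row/column index tensors so that a single call adds all blocks. add_blocks_index_put is a hypothetical helper name, and the blocks are random stand-ins for the question's values.

```python
import torch


def add_blocks_index_put(a, indx, blocks):
    """Hypothetical helper: add blocks[i] (each k x k) into a in place so that its
    top-left corner lands at a[indx[i, 0], indx[i, 1], indx[i, 2]]."""
    n, k, _ = blocks.shape
    offs = torch.arange(k, device=a.device)
    # Expand every k x k slice into explicit (n, k, k) index tensors.
    b = indx[:, 0].reshape(n, 1, 1).expand(n, k, k)
    rows = (indx[:, 1].reshape(n, 1, 1) + offs.reshape(1, k, 1)).expand(n, k, k)
    cols = (indx[:, 2].reshape(n, 1, 1) + offs.reshape(1, 1, k)).expand(n, k, k)
    # accumulate=True sums contributions that hit the same element (overlapping blocks).
    a.index_put_((b, rows, cols), blocks, accumulate=True)
    return a


torch.manual_seed(0)
indx = torch.tensor([[0, 2, 0], [0, 2, 4], [0, 4, 0], [0, 10, 14],
                     [1, 4, 0], [1, 8, 2], [1, 12, 0]])
blocks = torch.rand(7, 2, 2)

# Reference result: the loop from the question.
a_loop = torch.ones(2, 18, 18)
for i, (b, x, y) in enumerate(indx.tolist()):
    a_loop[b, x:x + 2, y:y + 2] += blocks[i]

a_vec = add_blocks_index_put(torch.ones(2, 18, 18), indx, blocks)
print(torch.allclose(a_loop, a_vec))  # True
```

Unlike the chunk-based approach in this answer, this variant also works for positions that are not multiples of 2 and for overlapping blocks, at the cost of building index tensors with n*k*k entries.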

I'll use a simplified example here, but the same approach applies to larger tensors.

>>> indx 
tensor([[ 0,  2,  0],
        [ 0,  2,  4],
        [ 0,  4,  0]])

>>> blocks
tensor([[[1.5818, 2.3108],
         [2.6742, 3.0024]],

        [[2.0472, 1.6651],
         [3.2807, 2.7413]],

        [[1.5587, 2.1905],
         [1.9231, 3.5083]]])

>>> a
tensor([[[0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.]]])
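To follow along, the simplified inputs can be recreated as below. This is a sketch that assumes a was created with torch.zeros(1, 6, 6), which matches the printout; indx and blocks are the first three rows from the question.

```python
import torch

# First three rows of the question's indx and blocks.
indx = torch.tensor([[0, 2, 0],
                     [0, 2, 4],
                     [0, 4, 0]])
blocks = torch.tensor([[[1.5818, 2.3108],
                        [2.6742, 3.0024]],
                       [[2.0472, 1.6651],
                        [3.2807, 2.7413]],
                       [[1.5587, 2.1905],
                        [1.9231, 3.5083]]])
# A 1x6x6 target filled with zeros, matching the printout above.
a = torch.zeros(1, 6, 6)
```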

The main issue here is that you are trying to index with slices, which isn't possible in a vectorized form. To get around that, you can convert your a tensor into 2x2 chunks. This is particularly handy since we will then be able to access a sub-tensor such as a[0, 2:4, 4:6] with just a_chunks[0, 1, 2]: the 2:4 slice on dim=1 is grouped into index 1 of the row-chunk axis, while the 4:6 slice on dim=2 is grouped into index 2 of the column-chunk axis.
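The index arithmetic can be checked on a small tensor whose entries are all distinct. This sketch swaps in reshape + permute (a different technique from the chunk/cat steps that follow) to split each spatial axis into (chunk index, offset inside the chunk); the result is a view of a.

```python
import torch

# Fill a 1x6x6 tensor with 0..35 so that every element is distinguishable.
a = torch.arange(36.).reshape(1, 6, 6)

# Split rows and columns into (chunk index, offset), then move both chunk
# indices in front of the two offsets: (1, 3, 3, 2, 2).
a_chunks = a.reshape(1, 3, 2, 3, 2).permute(0, 1, 3, 2, 4)
print(a_chunks.shape)  # torch.Size([1, 3, 3, 2, 2])

# The slice a[0, 2:4, 4:6] is the single chunk a_chunks[0, 1, 2].
print(torch.equal(a[0, 2:4, 4:6], a_chunks[0, 1, 2]))  # True
```

In general a_chunks[b, i, j, r, c] is a[b, 2*i + r, 2*j + c].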

First, we will convert a to a tensor made up of 2x2 chunks. Then we will update it with blocks. Finally, we will stitch the resulting tensor back into the original shape.


1. Converting a to a 2x2-chunks tensor

You can use a combination of torch.chunk and torch.cat (not torch.dog) twice: on dim=1 and dim=2. The shape of a is (1, h, w) so we're looking for a result of shape (1, h//2, w//2, 2, 2).

To do so we will unsqueeze two axes on a:

>>> a_ = a[:, None, :, None, :]
>>> a_.shape
torch.Size([1, 1, 6, 1, 6])

Then make 3 chunks on dim=2, then concatenate on dim=1:

>>> a_row_chunks = torch.cat(torch.chunk(a_, 3, dim=2), dim=1)
>>> a_row_chunks.shape
torch.Size([1, 3, 2, 1, 6])

And make 3 chunks on dim=4, then concatenate on dim=3:

>>> a_col_chunks = torch.cat(torch.chunk(a_row_chunks, 3, dim=4), dim=3)
>>> a_col_chunks.shape
torch.Size([1, 3, 2, 3, 2])

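Finally, swap the two middle axes so that both chunk indices come first. A plain reshape(1, 3, 3, 2, 2) would give the right shape but scramble the elements, because the within-chunk row axis still sits between the two chunk axes (this goes unnoticed here only because a is all zeros):

>>> a_chunks = a_col_chunks.permute(0, 1, 3, 2, 4)
>>> a_chunks.shape
torch.Size([1, 3, 3, 2, 2])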
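The chunk/cat steps can be wrapped into a small function for any block size k. This is a sketch assuming a has shape (B, h, w) with h and w divisible by k; to_chunks is a hypothetical name. It ends with permute(0, 1, 3, 2, 4) so that both chunk indices precede the two in-chunk offsets, and the check compares it with a reshape/permute view of a.

```python
import torch


def to_chunks(a, k):
    """Hypothetical helper: turn a (B, h, w) tensor into (B, h//k, w//k, k, k) chunks
    with torch.chunk and torch.cat. Assumes h and w are multiples of k."""
    B, h, w = a.shape
    a_ = a[:, None, :, None, :]                                # (B, 1, h, 1, w)
    rows = torch.cat(torch.chunk(a_, h // k, dim=2), dim=1)    # (B, h//k, k, 1, w)
    cols = torch.cat(torch.chunk(rows, w // k, dim=4), dim=3)  # (B, h//k, k, w//k, k)
    # Put both chunk indices first; a plain reshape here would mix up elements.
    return cols.permute(0, 1, 3, 2, 4)                         # (B, h//k, w//k, k, k)


a = torch.arange(2 * 6 * 8.).reshape(2, 6, 8)
chunks = to_chunks(a, 2)
print(chunks.shape)  # torch.Size([2, 3, 4, 2, 2])
print(torch.equal(chunks[1, 2, 3], a[1, 4:6, 6:8]))  # True
```

Note that torch.cat copies, so the result of to_chunks is independent of a; that is why the answer needs a separate step to put the tensor back together.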

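Next, create a new index with values adjusted for the chunked tensor. Essentially we divide all values by 2, except for the first column, which is the index on dim=0 of a and stays unchanged. This only works if every x and y position is a multiple of 2, i.e. each block starts exactly on a chunk boundary (which holds for all the indices in the question). There's some fiddling around with the types (in short: in-place division with /= is not allowed on a long tensor, so it is cast to a float first and then back to a long so it can be used for indexing):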

>>> indx_ = indx.clone().float()
>>> indx_[:, 1:] /= 2
>>> indx_ = indx_.long()
>>> indx_
tensor([[0, 1, 0],
        [0, 1, 2],
        [0, 2, 0]])
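The float round trip can be avoided with integer floor division. A sketch, assuming a PyTorch version where torch.div accepts rounding_mode (recent releases do); it also checks the block-alignment assumption explicitly. The variable k for the block size is introduced here for illustration.

```python
import torch

indx = torch.tensor([[0, 2, 0],
                     [0, 2, 4],
                     [0, 4, 0]])
k = 2  # block size

# The chunk approach only works for block-aligned positions.
assert (indx[:, 1:] % k == 0).all(), "positions must be multiples of the block size"

indx_ = indx.clone()
# Floor division on integer tensors keeps the long dtype, so no cast is needed.
indx_[:, 1:] = torch.div(indx[:, 1:], k, rounding_mode='floor')
print(indx_)
```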

2. Updating with blocks

We will simply index and accumulate with:

>>> a_chunks[indx_[:, 0], indx_[:, 1], indx_[:, 2]] += blocks
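The += above works here because every block maps to a different chunk, which is true for the question's indices. If two blocks could land on the same chunk, the advanced-indexing += would not sum them. This sketch (with made-up duplicate indices) shows the difference and the Tensor.index_put_ call with accumulate=True that does sum them.

```python
import torch

a_chunks = torch.zeros(1, 3, 3, 2, 2)
# Two blocks that target the same chunk (0, 1, 0).
idx = torch.tensor([[0, 1, 0],
                    [0, 1, 0]])
blocks = torch.ones(2, 2, 2)

# Advanced-indexing += gathers the current values, adds blocks, then writes back;
# with repeated indices the second write overwrites the first instead of adding.
plain = a_chunks.clone()
plain[idx[:, 0], idx[:, 1], idx[:, 2]] += blocks
print(plain[0, 1, 0])  # every entry is 1

# index_put_ with accumulate=True sums every contribution.
summed = a_chunks.clone()
summed.index_put_((idx[:, 0], idx[:, 1], idx[:, 2]), blocks, accumulate=True)
print(summed[0, 1, 0])  # every entry is 2
```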

3. Putting it back together

I thought that was it, but converting a_chunks back to a 6x6 tensor is trickier than it seems. torch.cat expects a sequence of tensors (a tuple or a list), not a single stacked tensor. I won't go into too much detail: calling tuple() on a tensor only splits it along its first axis, so as a workaround you can use permute to move the axis you want to split along to the front. This, combined with two torch.cat calls, does the job:

>>> a_row_cat = torch.cat(tuple(a_chunks.permute(1, 0, 2, 3, 4)), dim=2)
>>> a_row_cat.shape
torch.Size([1, 3, 6, 2])

>>> A = torch.cat(tuple(a_row_cat.permute(1, 0, 2, 3)), dim=2)
>>> A.shape
torch.Size([1, 6, 6])

>>> A
tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [1.5818, 2.3108, 0.0000, 0.0000, 2.0472, 1.6651],
         [2.6742, 3.0024, 0.0000, 0.0000, 3.2807, 2.7413],
         [1.5587, 2.1905, 0.0000, 0.0000, 0.0000, 0.0000],
         [1.9231, 3.5083, 0.0000, 0.0000, 0.0000, 0.0000]]])

Et voilà.
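The stitching step can be written a bit more directly. As a sketch on a random chunk tensor: Tensor.unbind(dim) returns a tuple of slices along any axis, so no permute is needed before each torch.cat, and the whole step can also be done with one permute plus reshape (the inverse of splitting each axis into chunk index and offset). All three versions agree:

```python
import torch

a_chunks = torch.rand(1, 3, 3, 2, 2)  # any (1, 3, 3, 2, 2) chunk tensor

# Version from the answer: permute so that tuple() splits along the wanted axis.
a_row_cat = torch.cat(tuple(a_chunks.permute(1, 0, 2, 3, 4)), dim=2)
A = torch.cat(tuple(a_row_cat.permute(1, 0, 2, 3)), dim=2)

# unbind(dim=1) yields the same slices without the permute.
A_unbind = torch.cat(torch.cat(a_chunks.unbind(dim=1), dim=2).unbind(dim=1), dim=2)

# Or put each offset next to its chunk index again and merge: (1, 3, 2, 3, 2) -> (1, 6, 6).
A_reshape = a_chunks.permute(0, 1, 3, 2, 4).reshape(1, 6, 6)

print(torch.equal(A, A_unbind), torch.equal(A, A_reshape))  # True True
```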


If you're not quite sure how the chunks work, run this:

for x in range(0, 6, 2):
    for y in range(0, 6, 2):
        a *= 0
        a[:, x:x+2, y:y+2] = 1
        print(a)

And see for yourself: each 2x2 block of 1s corresponds to a chunk in a_chunks.

You can select the same blocks in a_chunks with:

for x in range(3):
    for y in range(3):
        a_chunks *= 0
        a_chunks[:, x, y] = 1
        print(a_chunks)
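Putting the pieces together for the question's full 2x18x18 case. This sketch departs from the answer in one way: it builds the chunked tensor with view + permute, which (unlike torch.cat, which copies) is a view of a, so writing into the chunks updates a directly and the stitching step is not needed. add_blocks_chunked is a hypothetical name; it assumes a is contiguous and that every position is a multiple of the block size. The result is compared with the loop from the question on random blocks.

```python
import torch


def add_blocks_chunked(a, indx, blocks):
    """Hypothetical helper: add k x k blocks into a contiguous (B, h, w) tensor in place.

    Assumes h and w are multiples of k and every (x, y) in indx is a multiple of k.
    """
    B, h, w = a.shape
    k = blocks.shape[-1]
    # view + permute gives a (B, h//k, w//k, k, k) view of a: writes go into a.
    a_chunks = a.view(B, h // k, k, w // k, k).permute(0, 1, 3, 2, 4)
    idx = indx.clone()
    idx[:, 1:] = torch.div(indx[:, 1:], k, rounding_mode='floor')
    # accumulate=True also sums blocks that land on the same chunk.
    a_chunks.index_put_((idx[:, 0], idx[:, 1], idx[:, 2]), blocks, accumulate=True)
    return a


torch.manual_seed(0)
indx = torch.tensor([[0, 2, 0], [0, 2, 4], [0, 4, 0], [0, 10, 14],
                     [1, 4, 0], [1, 8, 2], [1, 12, 0]])
blocks = torch.rand(7, 2, 2)

# Reference result: the loop from the question.
a_loop = torch.ones(2, 18, 18)
for i, (b, x, y) in enumerate(indx.tolist()):
    a_loop[b, x:x + 2, y:y + 2] += blocks[i]

a = torch.ones(2, 18, 18)
add_blocks_chunked(a, indx, blocks)
print(torch.allclose(a, a_loop))  # True
```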


Source: https://stackoverflow.com/questions/65571114/add-blocks-of-values-to-a-tensor-at-specific-locations-in-pytorch
