How does PyTorch broadcasting work?

终归单人心 2020-12-16 22:58
torch.add(torch.ones(4,1), torch.randn(4))

produces a tensor with size torch.Size([4,4]).

Can someone explain the logic behind this?

Answer (2020-12-16 23:47):

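    PyTorch broadcasting follows NumPy's broadcasting semantics, which are described in the NumPy broadcasting rules and the PyTorch broadcasting guide. In short, the shapes are compared starting from the trailing (rightmost) dimension, and each pair of sizes must be equal, or one of them must be 1, or one of them must not exist; a size-1 or missing dimension is stretched to match the other. In the question, (4, 1) and (4,) line up as (4, 1) and (1, 4), so the result has shape (4, 4). An example makes this easier to see: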
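    The shape part of the rules can be applied mechanically. Below is a minimal plain-Python sketch; the helper `broadcast_shape` is hypothetical, not a PyTorch function (recent PyTorch releases provide `torch.broadcast_shapes` for this, which should agree with it). It walks the shapes from the trailing dimension, treats missing leading dimensions as size 1, and requires each aligned pair of sizes to be equal or 1.

```python
def broadcast_shape(*shapes):
    """Return the broadcast result shape, or raise ValueError if incompatible."""
    ndim = max((len(s) for s in shapes), default=0)
    result = []
    # Walk from the last (rightmost) dimension towards the first.
    for i in range(1, ndim + 1):
        size = 1
        for shape in shapes:
            d = shape[-i] if i <= len(shape) else 1  # a missing dim acts as size 1
            if d == 1:
                continue  # size 1 stretches to whatever the other operand has
            if size != 1 and d != size:
                raise ValueError(f"shapes {shapes} are not broadcastable at dim -{i}")
            size = d
        result.append(size)
    return tuple(reversed(result))


print(broadcast_shape((4, 1), (4,)))  # the question: ones(4, 1) + randn(4) -> (4, 4)
print(broadcast_shape((4, 1), (3,)))  # t_ones + t_rand below -> (4, 3)
try:
    broadcast_shape((4, 2), (3,))     # trailing sizes 2 and 3 clash
except ValueError as err:
    print("error:", err)
```

    The same function covers more than two operands, because each dimension simply keeps the single non-1 size it has seen so far.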

    In [27]: t_rand
    Out[27]: tensor([ 0.23451,  0.34562,  0.45673])
    
    In [28]: t_ones
    Out[28]: 
    tensor([[ 1.],
            [ 1.],
            [ 1.],
            [ 1.]])
    

    Now, for torch.add(t_rand, t_ones), visualize it like this:

                   # shape of (3,)
                   tensor([ 0.23451,      0.34562,       0.45673])
          # (4, 1)          | | | |       | | | |        | | | |
          tensor([[ 1.],____+ | | |   ____+ | | |    ____+ | | |
                  [ 1.],______+ | |   ______+ | |    ______+ | |
                  [ 1.],________+ |   ________+ |    ________+ |
                  [ 1.]])_________+   __________+    __________+
    

    which gives an output tensor of shape (4, 3):

    # shape of (4,3)
    In [33]: torch.add(t_rand, t_ones)
    Out[33]: 
    tensor([[ 1.23451,  1.34562,  1.45673],
            [ 1.23451,  1.34562,  1.45673],
            [ 1.23451,  1.34562,  1.45673],
            [ 1.23451,  1.34562,  1.45673]])
    

    Also note that we get exactly the same result if we pass the arguments in the reverse order:

    # shape of (4, 3)
    In [34]: torch.add(t_ones, t_rand)
    Out[34]: 
    tensor([[ 1.23451,  1.34562,  1.45673],
            [ 1.23451,  1.34562,  1.45673],
            [ 1.23451,  1.34562,  1.45673],
            [ 1.23451,  1.34562,  1.45673]])
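    Both transcripts can be mimicked in plain Python to show the mechanism behind the diagram. In the sketch below, the helper `broadcast_add_2d` is hypothetical and limited to 1-D and 2-D nested lists, and the values are copied from `t_rand`. A size-1 dimension is always read at index 0, so it behaves as if it were repeated along that axis; this is similar in spirit to how `Tensor.expand` returns a view with stride 0 instead of copying data.

```python
def to_2d(x):
    """View a flat list (shape (m,)) as shape (1, m); nested lists are already 2-D."""
    return x if isinstance(x[0], list) else [x]


def broadcast_add_2d(a, b):
    """Element-wise a + b for 1-D/2-D nested lists, following the broadcasting rules."""
    a, b = to_2d(a), to_2d(b)  # a missing leading dimension acts as size 1
    rows = max(len(a), len(b))
    cols = max(len(a[0]), len(b[0]))
    for t in (a, b):
        if len(t) not in (1, rows) or len(t[0]) not in (1, cols):
            raise ValueError("shapes are not broadcastable")

    def at(t, i, j):
        # A size-1 dimension is always read at index 0, so it acts as if repeated.
        return t[i if len(t) > 1 else 0][j if len(t[0]) > 1 else 0]

    return [[at(a, i, j) + at(b, i, j) for j in range(cols)] for i in range(rows)]


t_ones = [[1.0], [1.0], [1.0], [1.0]]  # shape (4, 1), like torch.ones(4, 1)
t_rand = [0.23451, 0.34562, 0.45673]   # shape (3,), values copied from the transcript

r1 = broadcast_add_2d(t_rand, t_ones)  # like In [33]
r2 = broadcast_add_2d(t_ones, t_rand)  # like In [34]
print(len(r1), len(r1[0]))             # -> 4 3
print([round(v, 5) for v in r1[0]])    # -> [1.23451, 1.34562, 1.45673]
print(r1 == r2)                        # -> True
```

    Each output element is the sum of the same two values whichever argument comes first, so swapping the arguments cannot change the result.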
    

    Either way, I find the former view (t_rand first) more intuitive.


    For a more pictorial understanding, I collected more examples, enumerated below:

    Example-1: (image not preserved)

    Example-2: (image not preserved)

    T and F stand for True and False respectively and indicate along which dimensions we allow broadcasting (source: Theano).
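    To connect the T/F notation to shapes: in Theano, a tensor type carries a `broadcastable` tuple with one boolean per dimension, True meaning that dimension has length 1 and may be broadcast. The helpers below (`broadcastable_pattern`, `stretched_dims`) are hypothetical, not Theano or PyTorch APIs; they only compute the analogous flags from a concrete shape and report which output dimensions an operand gets repeated along.

```python
def broadcastable_pattern(shape):
    """Theano-style T/F flags: True where a dimension has size 1 and may be broadcast."""
    return tuple(size == 1 for size in shape)


def stretched_dims(shape, out_shape):
    """Output dimensions along which an operand of `shape` is repeated to reach out_shape."""
    # Left-pad with size-1 dims so both shapes have the same length (NumPy/PyTorch rule).
    padded = (1,) * (len(out_shape) - len(shape)) + tuple(shape)
    return [i for i, (d, o) in enumerate(zip(padded, out_shape)) if d == 1 and o != 1]


matrix, row, col = (2, 3), (1, 3), (2, 1)
print(broadcastable_pattern(matrix))  # -> (False, False), i.e. (F, F)
print(broadcastable_pattern(row))     # -> (True, False), i.e. (T, F)
print(broadcastable_pattern(col))     # -> (False, True), i.e. (F, T)
print(stretched_dims(row, matrix))    # -> [0]: the row is repeated down the rows
print(stretched_dims(col, matrix))    # -> [1]: the column is repeated across the columns
```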


    Example-3: (image not preserved)

    Here are some shapes where the array b is broadcast appropriately to match the shape of the array a.
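    Checking whether b can be stretched to exactly a's shape is a one-directional version of the rule, in the spirit of `numpy.broadcast_to` or `torch.Tensor.expand`. The helper `broadcasts_to` below is hypothetical, and the shapes it is tried on are illustrative, not the ones from the original image.

```python
def broadcasts_to(b_shape, a_shape):
    """True if an array of shape b_shape can be stretched to exactly a_shape."""
    # b may not have more dimensions than a: broadcasting only prepends dimensions.
    if len(b_shape) > len(a_shape):
        return False
    # Align from the right; each of b's sizes must be 1 or equal to a's size.
    return all(bd == 1 or bd == ad for bd, ad in zip(reversed(b_shape), reversed(a_shape)))


a = (2, 3, 4)
for b in [(4,), (3, 1), (1, 3, 4), (2, 1, 1), (), (3,), (2, 3)]:
    print(f"a {a} <- b {b}: {broadcasts_to(b, a)}")
# The first five print True; (3,) and (2, 3) print False (3 does not match 4).
```

    Unlike the symmetric rule, this check never lets a grow: (3, 1) cannot be stretched to (1, 4), even though the two shapes broadcast together to (3, 4).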
