numpy einsum: nested dot products

北城余情 提交于 2019-12-02 11:31:05

问题


I have two n-by-k-by-3 arrays a and b, e.g.,

import numpy as np

a = np.array([
    [
        [1, 2, 3],
        [3, 4, 5]
    ],
    [
        [4, 2, 4],
        [1, 4, 5]
    ]
    ])
b = np.array([
    [
        [3, 1, 5],
        [0, 2, 3]
    ],
    [
        [2, 4, 5],
        [1, 2, 4]
    ]
    ])

and it like to compute the dot-product of all pairs of "triplets", i.e.,

np.sum(a*b, axis=2)

A better way to do that is perhaps einsum, but I can't seem to get the indices straight.

Any hints here?


回答1:


You are loosing the third axis on those two 3D input arrays with that sum-reduction, while keeping the first two axes aligned. Thus, with np.einsum, we would have the first two strings identical alongwith the third string being identical too, but would be skipped in the output string notation signalling we are reducing along that axis for both the inputs. Thus, the solution would be -

np.einsum('ijk,ijk->ij',a,b)


来源:https://stackoverflow.com/questions/38413913/numpy-einsum-nested-dot-products

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!