numpy einsum: nested dot products

太阳男子 2021-01-23 21:49

I have two n-by-k-by-3 arrays a and b, e.g.,

import numpy as np

a = np.array([
    [

1 answer
  • 2021-01-23 22:09

    You are losing the third axis of the two 3D input arrays with that sum-reduction, while keeping the first two axes aligned. So with np.einsum, the first two subscript letters are the same in both input strings, and the third letter is also shared, but it is left out of the output string. Omitting it signals that we reduce (sum) along that axis for both inputs. Thus, the solution would be:

    np.einsum('ijk,ijk->ij',a,b)
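One way to check the einsum call is to compare it against an explicit nested loop of per-pair dot products. The sketch below uses small random arrays as hypothetical stand-ins for a and b (the question's data is truncated). The sizes n = 2 and k = 4 are arbitrary assumptions. Note that the einsum letter j indexes the axis of length k, and the einsum letter k indexes the axis of length 3.

```python
import numpy as np

# Hypothetical example data (assumption): n = 2, k = 4, last axis of length 3.
rng = np.random.default_rng(0)
n, k = 2, 4
a = rng.standard_normal((n, k, 3))
b = rng.standard_normal((n, k, 3))

# 'ijk,ijk->ij': i and j appear in the output, so those axes stay aligned;
# k is omitted from the output, so the products are summed over it.
out = np.einsum('ijk,ijk->ij', a, b)

# Reference: explicit nested loop of 3-vector dot products.
ref = np.empty((n, k))
for i in range(n):
    for j in range(k):
        ref[i, j] = np.dot(a[i, j], b[i, j])

# Same reduction via broadcasting: elementwise product, then sum over the last axis.
alt = (a * b).sum(axis=-1)

print(out.shape)                                    # (2, 4)
print(np.allclose(out, ref), np.allclose(out, alt))  # True True
```

`(a * b).sum(axis=-1)` gives the same result but first builds the full n-by-k-by-3 product array. The einsum form expresses the multiply-and-reduce in a single call and may avoid that temporary.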