Multiply several matrices in numpy

借酒劲吻你 2021-02-02 06:52

Suppose you have n square matrices A1, ..., An. Is there any way to multiply these matrices neatly? As far as I know, numpy's dot accepts only two arguments. One obvious way

5 Answers
  •  走了就别回头了
    2021-02-02 07:18

    Another way to achieve this is to use einsum, which implements the Einstein summation convention in NumPy.

    To explain this convention very briefly as it applies to this problem: when you write your chained matrix product out as one big sum of products, you get something like:

    P_im = sum_j sum_k sum_l A1_ij A2_jk A3_kl A4_lm
    

    where P is the result of your product and A1, A2, A3, and A4 are the input matrices. Note that you sum over exactly those indices that appear twice in the summand, namely j, k, and l. Because sums of this form appear frequently in physics, vector calculus, and probably other fields, NumPy provides a dedicated tool for them: einsum.
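To make the sum concrete, here is a minimal sketch that evaluates P_im literally with nested loops for four square matrices and checks it against NumPy's ordinary matrix product. The function name chain_product_loops and the random 3x3 test matrices are illustrative assumptions; the loops cost O(n^5) and only show what the summation means, they are not meant for real use.

```python
import numpy as np

def chain_product_loops(A1, A2, A3, A4):
    """Hypothetical helper: P_im = sum_j sum_k sum_l A1_ij A2_jk A3_kl A4_lm.

    Written with explicit loops purely for illustration (O(n^5) for n x n inputs).
    """
    n = A1.shape[0]
    P = np.zeros((n, n), dtype=np.result_type(A1, A2, A3, A4))
    for i in range(n):
        for m in range(n):
            total = 0
            # j, k and l appear twice in the summand, so they are summed over.
            for j in range(n):
                for k in range(n):
                    for l in range(n):
                        total += A1[i, j] * A2[j, k] * A3[k, l] * A4[l, m]
            P[i, m] = total
    return P

rng = np.random.default_rng(0)
A1, A2, A3, A4 = (rng.random((3, 3)) for _ in range(4))
P = chain_product_loops(A1, A2, A3, A4)
# The repeated indices are exactly the ones contracted by matrix multiplication.
print(np.allclose(P, A1 @ A2 @ A3 @ A4))  # True
```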

    In the above example, you can use it to calculate your matrix product as follows:

    import numpy as np
    P = np.einsum("ij,jk,kl,lm->im", A1, A2, A3, A4)  # "->im" makes the output indices explicit
    

    Here, the first argument tells einsum which index labels belong to each input matrix. Every index that appears twice (j, k, and l) is summed over, and the remaining indices (i and m) form the result, which yields the desired product.
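A self-contained version of the call above, plus a hypothetical helper chain_einsum (not a NumPy function) that builds the subscripts string for any number of matrices, e.g. "ab,bc,cd,de->ae" for four. This is a sketch assuming square, compatible matrices and at most 25 operands (one lowercase letter per index). Passing optimize=True is a choice made here, not something the answer above uses; it lets einsum choose a pairwise contraction order.

```python
import string

import numpy as np

def chain_einsum(*mats):
    """Hypothetical helper: multiply A1 @ A2 @ ... @ An with one einsum call."""
    if not 1 <= len(mats) <= 25:
        raise ValueError("expected between 1 and 25 matrices")
    letters = string.ascii_lowercase
    # Matrix k gets indices (letters[k], letters[k+1]), so adjacent matrices share one index.
    inputs = ",".join(letters[k] + letters[k + 1] for k in range(len(mats)))
    # The output keeps only the first and last index, e.g. "ab,bc,cd,de->ae".
    subscripts = f"{inputs}->{letters[0]}{letters[len(mats)]}"
    # optimize=True lets einsum contract pairwise instead of in one big loop.
    return np.einsum(subscripts, *mats, optimize=True)

rng = np.random.default_rng(0)
A1, A2, A3, A4 = (rng.random((4, 4)) for _ in range(4))

# The call from the answer, with the output indices spelled out.
P = np.einsum("ij,jk,kl,lm->im", A1, A2, A3, A4)
print(np.allclose(P, A1 @ A2 @ A3 @ A4))             # True
print(np.allclose(chain_einsum(A1, A2, A3, A4), P))  # True
```

The subscripts could just as well be typed by hand; the helper only saves typing for long chains.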

    Note that whether einsum is faster or slower than NumPy's built-in functions depends on several factors, so you are probably best off testing it on your own matrices. See, for example:

    • Why is numpy's einsum slower than numpy's built-in functions?
    • Why is numpy's einsum faster than numpy's built in functions?
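One way to follow the advice to just test it is a small timeit comparison. The sketch below assumes four random 20x20 matrices; reduce(np.dot, ...) and np.linalg.multi_dot are included only as baselines and are not part of the answer above. np.einsum's optimize argument defaults to False, in which case a multi-operand expression like this one is evaluated in a single pass over all index combinations; with optimize=True it is contracted pairwise. Timings vary by machine, so no numbers are shown.

```python
import timeit
from functools import reduce

import numpy as np

# Hypothetical workload: four random 20x20 matrices; substitute your own.
rng = np.random.default_rng(0)
mats = [rng.random((20, 20)) for _ in range(4)]

candidates = {
    "einsum, optimize=False": lambda: np.einsum("ij,jk,kl,lm->im", *mats),
    "einsum, optimize=True": lambda: np.einsum("ij,jk,kl,lm->im", *mats, optimize=True),
    "reduce(np.dot, mats)": lambda: reduce(np.dot, mats),
    "np.linalg.multi_dot": lambda: np.linalg.multi_dot(mats),
}

reference = mats[0] @ mats[1] @ mats[2] @ mats[3]
repeats = 20
for name, fn in candidates.items():
    # Check correctness before timing so a fast-but-wrong variant is caught.
    assert np.allclose(fn(), reference), name
    seconds = timeit.timeit(fn, number=repeats)
    print(f"{name:24s} {seconds / repeats * 1e3:8.3f} ms per call")
```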
