What is the role of keepdims in NumPy (Python)?

北海茫月 2021-02-10 10:26

While using np.sum, I came across a parameter called keepdims. Even after reading the docs, I still don't understand what keepdims means.

4 Answers
  •  無奈伤痛
    2021-02-10 11:05

    Consider a small 2d array:

    In [180]: A=np.arange(12).reshape(3,4)
    In [181]: A
    Out[181]: 
    array([[ 0,  1,  2,  3],
           [ 4,  5,  6,  7],
           [ 8,  9, 10, 11]])
    

    Sum along axis 1 (the entries of each row are added up); the result is a (3,) array:

    In [182]: A.sum(axis=1)
    Out[182]: array([ 6, 22, 38])
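    A quick way to see what keepdims is about is to look at shapes. This sketch uses plain NumPy and the same A as above (the variable names are just for illustration): each reduction drops the axis it runs along, and that lost axis is exactly what keepdims preserves.

```python
import numpy as np

A = np.arange(12).reshape(3, 4)   # shape (3, 4)

# A reduction removes the axis it runs along from the result's shape.
row_sums = A.sum(axis=1)   # one total per row    -> shape (3,)
col_sums = A.sum(axis=0)   # one total per column -> shape (4,)
total = A.sum()            # all axes reduced     -> shape () (a scalar)

print(row_sums.shape, col_sums.shape, np.shape(total))
# (3,) (4,) ()
```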
    

    But subtracting the row sums from A (or dividing A by them) requires reshaping the sums first:

    In [183]: A-A.sum(axis=1)
    ...
    ValueError: operands could not be broadcast together with shapes (3,4) (3,) 
    In [184]: A-A.sum(axis=1)[:,None]   # turn sum into (3,1)
    Out[184]: 
    array([[ -6,  -5,  -4,  -3],
           [-18, -17, -16, -15],
           [-30, -29, -28, -27]])
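    The [:,None] trick is one of several ways to put the reduced axis back by hand. The sketch below compares a few standard NumPy spellings with keepdims=True, which yields the same (3, 1) column directly from the reduction.

```python
import numpy as np

A = np.arange(12).reshape(3, 4)
s = A.sum(axis=1)                        # shape (3,)

# Equivalent ways to turn the (3,) sums into a (3, 1) column:
col_none = s[:, None]                    # index with None (same as np.newaxis)
col_reshape = s.reshape(-1, 1)           # reshape, letting NumPy infer the row count
col_expand = np.expand_dims(s, axis=1)   # insert a new length-1 axis at position 1
col_keep = A.sum(axis=1, keepdims=True)  # ask the reduction not to drop the axis

for c in (col_none, col_reshape, col_expand):
    assert c.shape == (3, 1)
    assert np.array_equal(c, col_keep)

print(A - col_keep)   # same values as Out[184] above
```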
    

    With keepdims=True, the summed axis is kept in the result as a dimension of length 1, so, in the words of the docs, "the result will broadcast correctly against" A.

    In [185]: A.sum(axis=1, keepdims=True)   # (3,1) array
    Out[185]: 
    array([[ 6],
           [22],
           [38]])
    In [186]: A-A.sum(axis=1, keepdims=True)
    Out[186]: 
    array([[ -6,  -5,  -4,  -3],
           [-18, -17, -16, -15],
           [-30, -29, -28, -27]])
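    keepdims is not specific to sum: other NumPy reductions such as max and mean accept it too, and it also works with axis=None or a tuple of axes. A small sketch on a hypothetical 3-d array B: the result always has the same number of dimensions as the input, with length 1 on every reduced axis, so it broadcasts straight back.

```python
import numpy as np

B = np.arange(24).reshape(2, 3, 4)

m = B.max(axis=(1, 2), keepdims=True)   # reduce two axes: shape (2, 1, 1)
t = B.sum(keepdims=True)                # axis=None reduces all: shape (1, 1, 1)
mu = np.mean(B, axis=1, keepdims=True)  # shape (2, 1, 4)

# Every result keeps B's ndim, so arithmetic with B broadcasts without reshaping.
print((B - m).shape, (B / t).shape, (B - mu).shape)
# (2, 3, 4) (2, 3, 4) (2, 3, 4)
```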
    

    If I sum the other way (over axis 0), I don't need keepdims: the (4,) result broadcasts against the (3,4) array automatically, as if it were A.sum(axis=0)[None,:]. But there's no harm in using keepdims.

    In [190]: A.sum(axis=0)
    Out[190]: array([12, 15, 18, 21])    # (4,)
    In [191]: A-A.sum(axis=0)
    Out[191]: 
    array([[-12, -14, -16, -18],
           [ -8, -10, -12, -14],
           [ -4,  -6,  -8, -10]])
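    Why the axis=0 case needs no help: broadcasting compares shapes starting from the trailing axis and treats missing leading axes as length 1. A short check, using only standard NumPy calls, that the plain (4,) sum and the keepdims (1, 4) sum give identical results against A.

```python
import numpy as np

A = np.arange(12).reshape(3, 4)

s0 = A.sum(axis=0)                      # shape (4,)
s0_keep = A.sum(axis=0, keepdims=True)  # shape (1, 4)

# (4,) is padded on the left to (1, 4), then stretched to (3, 4).
assert np.array_equal(A - s0, A - s0_keep)
assert np.array_equal(A - s0, A - s0[None, :])
print(np.broadcast(A, s0).shape)   # (3, 4)
```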
    

    These operations may make more sense with np.mean, for example when normalizing the array over columns or rows. Either way, keepdims simplifies further arithmetic between the original array and its sum or mean.
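    A minimal sketch of the normalizations mentioned above (the variable names are illustrative only). Each one relies on the reduced axis being kept, so the row or column statistic broadcasts back against A without any manual reshaping.

```python
import numpy as np

# Use floats so the divisions below are not affected by integer dtypes.
A = np.arange(12).reshape(3, 4).astype(float)

# Center each row: subtract that row's mean, a (3, 1) column.
row_centered = A - A.mean(axis=1, keepdims=True)

# Scale each row so its entries sum to 1.
row_normalized = A / A.sum(axis=1, keepdims=True)

# Standardize each column: zero mean and unit standard deviation.
col_standardized = (A - A.mean(axis=0, keepdims=True)) / A.std(axis=0, keepdims=True)

print(row_centered.mean(axis=1))    # each row now averages to 0
print(row_normalized.sum(axis=1))   # each row sums to 1 (up to rounding)
```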
