How does NumPy Sum (with axis) work?

一个人的身影 2020-12-07 09:44

I've taken it upon myself to learn how NumPy works for my own curiosity.

It seems that the simplest function is the hardest to translate to code (I un

2 Answers
  • 2020-12-07 09:54

    Nested loops make explicit what each value of axis does:

    import numpy as np
    
    n = np.array(
    [[[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]],
    
     [[2, 4, 6],
     [8, 10, 12],
     [14, 16, 18]],
    
     [[1, 3, 5],
     [7, 9, 11],
     [13, 15, 17]]])
    
    print(n)
    
    print("============ sum axis=None=============")
    
    total = 0  # named total to avoid shadowing the built-in sum()
    for i in range(3):
      for j in range(3):
        for k in range(3):
          total += n[i][j][k]
    print(total) # 216
    
    print('------------------')
    print(np.sum(n))  # 216
    print("============ sum axis=0 =============") 
    for i in range(3):
      for j in range(3):
        total = 0
        for pos in range(3):  # pos walks along axis 0
          total += n[pos][i][j]
        print(total, end=' ')
      print()
    
    print('------------------')
    print("sum[0][0] = %d" % (n[0][0][0] + n[1][0][0] + n[2][0][0]))
    print("sum[1][1] = %d" % (n[0][1][1] + n[1][1][1] + n[2][1][1]))
    print("sum[2][2] = %d" % (n[0][2][2] + n[1][2][2] + n[2][2][2]))
    print('------------------')
    print(np.sum(n, axis=0)) 
    print("============ sum axis=1 =============") 
    for i in range(3):
      for j in range(3):
        total = 0
        for pos in range(3):  # pos walks along axis 1
          total += n[i][pos][j]
        print(total, end=' ')
      print()
    print('------------------')
    print("sum[0][0] = %d" % (n[0][0][0] + n[0][1][0] + n[0][2][0]))
    print("sum[1][1] = %d" % (n[1][0][1] + n[1][1][1] + n[1][2][1]))
    print("sum[2][2] = %d" % (n[2][0][2] + n[2][1][2] + n[2][2][2]))
    print('------------------')
    print(np.sum(n, axis=1))  
    print("============ sum axis=2 =============") 
    for i in range(3):
      for j in range(3):
        total = 0
        for pos in range(3):  # pos walks along axis 2
          total += n[i][j][pos]
        print(total, end=' ')
      print()
    print('------------------')
    print("sum[0][0] = %d" % (n[0][0][0] + n[0][0][1] + n[0][0][2]))
    print("sum[1][1] = %d" % (n[1][1][0] + n[1][1][1] + n[1][1][2]))
    print("sum[2][2] = %d" % (n[2][2][0] + n[2][2][1] + n[2][2][2]))
    print('------------------')
    print(np.sum(n, axis=2))
    print("============ sum axis=(0,1)) =============") 
    for i in range(3):
      total = 0
      for pos0 in range(3):    # pos0 walks along axis 0
        for pos1 in range(3):  # pos1 walks along axis 1
          total += n[pos0][pos1][i]
      print(total, end=' ')
    
    print()
    print('------------------')
    print("sum[1] = %d" % (n[0][0][1] + n[0][1][1] + n[0][2][1] +
                  n[1][0][1] + n[1][1][1] + n[1][2][1] +
                  n[2][0][1] + n[2][1][1] + n[2][2][1] ))
    print('------------------')
    print(np.sum(n, axis=(0,1)))
    

    Result:

    [[[ 1  2  3]
      [ 4  5  6]
      [ 7  8  9]]
    
     [[ 2  4  6]
      [ 8 10 12]
      [14 16 18]]
    
     [[ 1  3  5]
      [ 7  9 11]
      [13 15 17]]]
    ============ sum axis=None=============
    216
    ------------------
    216
    ============ sum axis=0 =============
    4 9 14 
    19 24 29 
    34 39 44 
    ------------------
    sum[0][0] = 4
    sum[1][1] = 24
    sum[2][2] = 44
    ------------------
    [[ 4  9 14]
     [19 24 29]
     [34 39 44]]
    ============ sum axis=1 =============
    12 15 18 
    24 30 36 
    21 27 33 
    ------------------
    sum[0][0] = 12
    sum[1][1] = 30
    sum[2][2] = 33
    ------------------
    [[12 15 18]
     [24 30 36]
     [21 27 33]]
    ============ sum axis=2 =============
    6 15 24 
    12 30 48 
    9 27 45 
    ------------------
    sum[0][0] = 6
    sum[1][1] = 30
    sum[2][2] = 45
    ------------------
    [[ 6 15 24]
     [12 30 48]
     [ 9 27 45]]
    ============ sum axis=(0,1)) =============
    57 72 87 
    ------------------
    sum[1] = 72
    ------------------
    [57 72 87]
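
The per-axis loops above can be folded into one function that works for any single axis. This is only a minimal sketch for illustration, not NumPy's actual implementation (which runs in compiled code): `sum_axis` is a hypothetical helper name, and it assumes `axis` is a single integer.

```python
import numpy as np

def sum_axis(arr, axis):
    """Sum `arr` over one axis with explicit loops (hypothetical helper)."""
    arr = np.asarray(arr)
    axis = axis % arr.ndim  # allow negative axes such as -1
    # The result has the input's shape with the summed axis removed.
    out_shape = arr.shape[:axis] + arr.shape[axis + 1:]
    out = np.zeros(out_shape, dtype=arr.dtype)
    for idx in np.ndindex(*out_shape):
        total = 0
        for pos in range(arr.shape[axis]):
            # Re-insert the running position at the summed axis.
            full_idx = idx[:axis] + (pos,) + idx[axis:]
            total += arr[full_idx]
        out[idx] = total
    return out

n = np.arange(24).reshape(2, 3, 4)
for ax in range(n.ndim):
    assert np.array_equal(sum_axis(n, ax), np.sum(n, axis=ax))
print(sum_axis(n, 0).shape, sum_axis(n, 1).shape, sum_axis(n, 2).shape)
# (3, 4) (2, 4) (2, 3)
```

Using an array whose axes have different lengths makes it easier to see which axis was removed from the result.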
    
  • 2020-12-07 10:14

    Setup

    Consider the NumPy array a:

    import numpy as np

    a = np.arange(30).reshape(2, 3, 5)
    print(a)
    
    [[[ 0  1  2  3  4]
      [ 5  6  7  8  9]
      [10 11 12 13 14]]
    
     [[15 16 17 18 19]
      [20 21 22 23 24]
      [25 26 27 28 29]]]
    

    Where are the dimensions?

    The dimensions and the positions along each of them are marked in the following diagram:

                p  p  p  p  p
                o  o  o  o  o
                s  s  s  s  s
    
         dim 2  0  1  2  3  4
    
                |  |  |  |  |
      dim 0     ↓  ↓  ↓  ↓  ↓
      ----> [[[ 0  1  2  3  4]   <---- dim 1, pos 0
      pos 0   [ 5  6  7  8  9]   <---- dim 1, pos 1
              [10 11 12 13 14]]  <---- dim 1, pos 2
      dim 0
      ---->  [[15 16 17 18 19]   <---- dim 1, pos 0
      pos 1   [20 21 22 23 24]   <---- dim 1, pos 1
              [25 26 27 28 29]]] <---- dim 1, pos 2
                ↑  ↑  ↑  ↑  ↑
                |  |  |  |  |
    
         dim 2  p  p  p  p  p
                o  o  o  o  o
                s  s  s  s  s
    
                0  1  2  3  4
    

    Dimension examples:

    A few examples make this clearer:

    a[0, :, :] # dim 0, pos 0
    
    [[ 0  1  2  3  4]
     [ 5  6  7  8  9]
     [10 11 12 13 14]]
    

    a[:, 1, :] # dim 1, pos 1
    
    [[ 5  6  7  8  9]
     [20 21 22 23 24]]
    

    a[:, :, 3] # dim 2, pos 3
    
    [[ 3  8 13]
     [18 23 28]]
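
A quick way to check these slices is to look at their shapes: fixing a position along one dimension removes that dimension. The sketch below rebuilds the same array `a` so it runs on its own; the use of `np.take` is my addition, shown only because it takes the dimension as a number.

```python
import numpy as np

a = np.arange(30).reshape(2, 3, 5)

# Fixing one position along a dimension removes that dimension.
print(a[0, :, :].shape)  # (3, 5)
print(a[:, 1, :].shape)  # (2, 5)
print(a[:, :, 3].shape)  # (2, 3)

# np.take makes the same selection with the dimension passed as `axis`.
assert np.array_equal(np.take(a, 0, axis=0), a[0, :, :])
assert np.array_equal(np.take(a, 1, axis=1), a[:, 1, :])
assert np.array_equal(np.take(a, 3, axis=2), a[:, :, 3])
```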
    

    Sum

    Explanation of sum and axis:
    a.sum(0) is the sum of all slices along dim 0, so dim 0 is removed from the result's shape

    a.sum(0)
    
    [[15 17 19 21 23]
     [25 27 29 31 33]
     [35 37 39 41 43]]
    

    same as

    a[0, :, :] + \
    a[1, :, :]
    
    [[15 17 19 21 23]
     [25 27 29 31 33]
     [35 37 39 41 43]]
    

    a.sum(1) is the sum of all slices along dim 1

    a.sum(1)
    
    [[15 18 21 24 27]
     [60 63 66 69 72]]
    

    same as

    a[:, 0, :] + \
    a[:, 1, :] + \
    a[:, 2, :]
    
    [[15 18 21 24 27]
     [60 63 66 69 72]]
    

    a.sum(2) is the sum of all slices along dim 2

    a.sum(2)
    
    [[ 10  35  60]
     [ 85 110 135]]
    

    same as

    a[:, :, 0] + \
    a[:, :, 1] + \
    a[:, :, 2] + \
    a[:, :, 3] + \
    a[:, :, 4]
    
    [[ 10  35  60]
     [ 85 110 135]]
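
The three "same as" comparisons can be checked in one loop. This is a small sketch under the same setup (array `a` rebuilt here so the block is self-contained); the `keepdims` line at the end is an extra not covered above.

```python
import numpy as np

a = np.arange(30).reshape(2, 3, 5)

for axis in range(a.ndim):
    # Collect the slice at every position along `axis` and add them up.
    slices = [np.take(a, pos, axis=axis) for pos in range(a.shape[axis])]
    by_hand = slices[0]
    for s in slices[1:]:
        by_hand = by_hand + s
    assert np.array_equal(by_hand, a.sum(axis=axis))
    print(axis, a.sum(axis=axis).shape)
# 0 (3, 5)
# 1 (2, 5)
# 2 (2, 3)

# keepdims=True keeps the summed axis with length 1 instead of dropping it.
print(a.sum(axis=1, keepdims=True).shape)  # (2, 1, 5)
```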
    

    The default axis is None (not -1; axis=-1 would sum over the last axis only).
    None means all axes, i.e. add up every number in the array.

    a.sum()
    
    435
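
For comparison, here is how omitting `axis` differs from passing `-1` or a tuple of axes. A minimal sketch, again rebuilding `a` so it can be run on its own:

```python
import numpy as np

a = np.arange(30).reshape(2, 3, 5)

total = a.sum()            # axis=None: every element
last = a.sum(axis=-1)      # last axis only; same as axis=2 for a 3-D array
pair = a.sum(axis=(0, 1))  # sum over dims 0 and 1, leaving dim 2

assert total == a.sum(axis=None) == 435
assert np.array_equal(last, a.sum(axis=2))
assert np.array_equal(pair, a.sum(axis=0).sum(axis=0))
print(total, last.shape, pair.shape)  # 435 (2, 3) (5,)
```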
    