Как работает NumPy Sum (с осью)?

Я взял на себя все усилия, чтобы узнать, как NumPy работает для моего собственного любопытства.

Кажется, что простейшая функция сложнее перевести на код (я понимаю по коду). Легко жестко кодировать каждую ось для каждого случая, но я хочу найти динамический алгоритм, который может суммировать любую ось с n-измерениями. Документация на официальном веб-сайте не является полезной (она показывает только результат, а не процесс), и трудно перемещаться по Python/C-коду.

Примечание: Я выяснил, что когда массив суммируется, указанная ось "удаляется", т.е. сумма массива с формой (4, 3, 2) с осью 1 дает ответ массива с формой (4, 2)

Ответ 1

Настройка

рассмотрим массив numpy a

a = np.arange(30).reshape(2, 3, 5)
print(a)

[[[ 0  1  2  3  4]
  [ 5  6  7  8  9]
  [10 11 12 13 14]]

 [[15 16 17 18 19]
  [20 21 22 23 24]
  [25 26 27 28 29]]]

Где размеры?

Размеры и позиции выделены следующим образом

            p  p  p  p  p
            o  o  o  o  o
            s  s  s  s  s

     dim 2  0  1  2  3  4

            |  |  |  |  |
  dim 0     ↓  ↓  ↓  ↓  ↓
  ----> [[[ 0  1  2  3  4]   <---- dim 1, pos 0
  pos 0   [ 5  6  7  8  9]   <---- dim 1, pos 1
          [10 11 12 13 14]]  <---- dim 1, pos 2
  dim 0
  ---->  [[15 16 17 18 19]   <---- dim 1, pos 0
  pos 1   [20 21 22 23 24]   <---- dim 1, pos 1
          [25 26 27 28 29]]] <---- dim 1, pos 2
            ↑  ↑  ↑  ↑  ↑
            |  |  |  |  |

     dim 2  p  p  p  p  p
            o  o  o  o  o
            s  s  s  s  s

            0  1  2  3  4

Примеры размеров:

Это становится более понятным с помощью нескольких примеров

a[0, :, :] # dim 0, pos 0

[[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]]

a[:, 1, :] # dim 1, pos 1

[[ 5  6  7  8  9]
 [20 21 22 23 24]]

a[:, :, 3] # dim 2, pos 3

[[ 3  8 13]
 [18 23 28]]

sum

объяснение sum и axis
a.sum(0) - сумма всех срезов вдоль dim 0

a.sum(0)

[[15 17 19 21 23]
 [25 27 29 31 33]
 [35 37 39 41 43]]

то же, что и

a[0, :, :] + \
a[1, :, :]

[[15 17 19 21 23]
 [25 27 29 31 33]
 [35 37 39 41 43]]

a.sum(1) - сумма всех срезов вдоль dim 1

a.sum(1)

[[15 18 21 24 27]
 [60 63 66 69 72]]

то же, что и

a[:, 0, :] + \
a[:, 1, :] + \
a[:, 2, :]

[[15 18 21 24 27]
 [60 63 66 69 72]]

a.sum(2) - сумма всех срезов вдоль dim 2

a.sum(2)

[[ 10  35  60]
 [ 85 110 135]]

то же, что и

a[:, :, 0] + \
a[:, :, 1] + \
a[:, :, 2] + \
a[:, :, 3] + \
a[:, :, 4]

[[ 10  35  60]
 [ 85 110 135]]

по умолчанию ось -1
это означает все оси. или суммировать все числа.

a.sum()

435

Ответ 2

Я использую операцию вложенного цикла, чтобы объяснить это.

import numpy as np

n = np.array(
[[[1, 2, 3],
 [4, 5, 6],
 [7, 8, 9]],

 [[2, 4, 6],
 [8, 10, 12],
 [14, 16, 18]],

 [[1, 3, 5],
 [7, 9, 11],
 [13, 15, 17]]])

print(n)

print("============ sum axis=None=============")

sum = 0
for i in range(3):
  for j in range(3): 
    for k in range(3):
      sum += n[k][i][j]
print(sum) # 216

print('------------------')
print(np.sum(n))  # 216
print("============ sum axis=0 =============") 
for i in range(3):
  for j in range(3):
    sum = 0
    for axis in range(3):
      sum += n[axis][i][j]
    print(sum,end=' ')
  print()

print('------------------')
print("sum[0][0] = %d" % (n[0][0][0] + n[1][0][0] + n[2][0][0]))
print("sum[1][1] = %d" % (n[0][1][1] + n[1][1][1] + n[2][1][1]))
print("sum[2][2] = %d" % (n[0][2][2] + n[1][2][2] + n[2][2][2]))
print('------------------')
print(np.sum(n, axis=0)) 
print("============ sum axis=1 =============") 
for i in range(3):
  for j in range(3):
    sum = 0
    for axis in range(3):
      sum += n[i][axis][j]
    print(sum,end=' ')
  print()
print('------------------')
print("sum[0][0] = %d" % (n[0][0][0] + n[0][1][0] + n[0][2][0]))
print("sum[1][1] = %d" % (n[1][0][1] + n[1][1][1] + n[1][2][1]))
print("sum[2][2] = %d" % (n[2][0][2] + n[2][1][2] + n[2][2][2]))
print('------------------')
print(np.sum(n, axis=1))  
print("============ sum axis=2 =============") 
for i in range(3):
  for j in range(3):
    sum = 0
    for axis in range(3):
      sum += n[i][j][axis]
    print(sum,end=' ')
  print()
print('------------------')
print("sum[0][0] = %d" % (n[0][0][0] + n[0][0][1] + n[0][0][2]))
print("sum[1][1] = %d" % (n[1][1][0] + n[1][1][1] + n[1][1][2]))
print("sum[2][2] = %d" % (n[2][2][0] + n[2][2][1] + n[2][2][2]))
print('------------------')
print(np.sum(n, axis=2))
print("============ sum axis=(0,1)) =============") 
for i in range(3):
  sum = 0
  for axis1 in range(3):   
    for axis2 in range(3):
      sum += n[axis1][axis2][i]
  print(sum,end=' ')

print()
print('------------------')
print("sum[1] = %d" % (n[0][0][1] + n[0][1][1] + n[0][2][1] +
              n[1][0][1] + n[1][1][1] + n[1][2][1] +
              n[2][0][1] + n[2][1][1] + n[2][2][1] ))
print('------------------')
print(np.sum(n, axis=(0,1)))

результат:

[[[ 1  2  3]
  [ 4  5  6]
  [ 7  8  9]]

 [[ 2  4  6]
  [ 8 10 12]
  [14 16 18]]

 [[ 1  3  5]
  [ 7  9 11]
  [13 15 17]]]
============ sum axis=None=============
216
------------------
216
============ sum axis=0 =============
4 9 14 
19 24 29 
34 39 44 
------------------
sum[0][0] = 4
sum[1][1] = 24
sum[2][2] = 44
------------------
[[ 4  9 14]
 [19 24 29]
 [34 39 44]]
============ sum axis=1 =============
12 15 18 
24 30 36 
21 27 33 
------------------
sum[0][0] = 12
sum[1][1] = 30
sum[2][2] = 33
------------------
[[12 15 18]
 [24 30 36]
 [21 27 33]]
============ sum axis=2 =============
6 15 24 
12 30 48 
9 27 45 
------------------
sum[0][0] = 6
sum[1][1] = 30
sum[2][2] = 45
------------------
[[ 6 15 24]
 [12 30 48]
 [ 9 27 45]]
============ sum axis=(0,1)) =============
57 72 87 
------------------
sum[1] = 72
------------------
[57 72 87]