2 Python的 * 和 NumPy的广播
几天前,一个小伙伴问:Python的 * 和广播机制是一回事吗?它们相似,但实则不同!
1) 先了解下python中的 list * 标量,结果是复制对应的元素,如下所示:
a = [3,8,10]
print(a*3)
[3, 8, 10, 3, 8, 10, 3, 8, 10]
a = [[1,3,2],[6,4,3]]
print(a*2)
[[1, 3, 2], [6, 4, 3], [1, 3, 2], [6, 4, 3]]
list * 标量等于按元素或按行的复制。
2) NumPy 的广播机制,先看一个例子,如下:
x = np.arange(4)
xx = x.reshape(4,1)
print(xx)
y = np.ones(5)
print(y)
print(xx + y)
xx: array([ [0], [1], [2], [3] ])
y : array([ 1, 1, 1, 1, 1 ])
xx + y : array(array([[ 1., 1., 1., 1., 1.],
[ 2., 2., 2., 2., 2.],
[ 3., 3., 3., 3., 3.],
[ 4., 4., 4., 4., 4.]]) )
上面的例子,xx 的 shape是(4,1),y的 shape(5,),NumPy支持的这类操作,被称为广播机制。
3 NumPy广播 通用规则
注意,不是任意形状间的ndarray都能做广播,必须满足一定的约束条件。对两个NumPy的 ndarray 进行操作时,NumPy 会比较形状,开始于最靠后的维度(如5*4*6,最靠后的维度长度是6)。当以下情形出现时,维度是兼容的:
1) 相等
data_2d = np.arange(20).reshape(5,4) # 5 * 4
data_1d = np.array([1,2,3,4]) # 4
print(data_2d * data_1d) # 5 * 4
array([[ 0, 2, 6, 12],
[ 4, 10, 18, 28],
[ 8, 18, 30, 44],
[12, 26, 42, 60],
[16, 34, 54, 76]])
2) 其中一个长度为 1
data_3d = np.arange(20).reshape(5,2,2) # 5 * 2 * 2
data_1d = np.array([3]) # 1
print(data_3d * data_1d) # 5 * 2 * 2
[[[ 0 3]
[ 6 9]]
[[12 15]
[18 21]]
[[24 27]
[30 33]]
[[36 39]
[42 45]]
[[48 51]
[54 57]]]
可以看到,广播是按照右对齐的方式,其中长度为1的维度被自动广播。
4 NumPy广播 好处
先看一个例子。一个ndarray和一个标量相乘,这是广播机制:
a = np.array([1, 2, 3])
b = 2
print(a * b)
array([2, 4, 6])
如果我们不按照广播机制,我们可以这样写:
a = np.array([1, 2, 3])
b = np.array([2, 2, 2])
print(a * b)
array([2, 4, 6])
标量值 b 在计算时被伸展为 与 a 一样的形状,伸展后 b 的每一个元素都是原来标量值的复制。实际上,NumPy 并不需要真的复制这些标量值,所以广播运算在内存和计算效率上更高效。
閱讀更多 Python與算法社區 的文章