揭秘 NumPy 的形

使用 NumPy, TensorFlow, Pytorch ,我们经常会使用数组的 reshape 操作,变化数组为各种 shape.

一个一维数组,长度为 12,为什么能变化为二维 (12,1) 或 (2,6) 等,三维 (12,1,1) 或 (2,3,2) 等,四维 (12,1,1,1) 或 (2,3,1,2) 等,总之,可以变化为任意多维度。

reshape 是如何做到的?使用了什么魔法数据结构和算法吗?

NumPy 作为数据分析和深度学习领域的必备基础库,数值计算效率666得飞起。今天我们就以 NumPy 数组的 reshape 方法为例,一探究竟数据的这种 reshape 变化及背后的实现原理。

这篇文章对于 reshape 方法的原理解释,会很独到,尽可能让朋友们弄明白数组 reshape 的魔法。

如同往常一样,导入 NumPy 包:

<code>import numpy as np
/<code>

创建一个一维数组 a ,从 0 开始,间隔为 2 ,含有 12 个元素的数组:

<code>a = np.arange(0,24,2)
/<code>

打印数组 a

<code>In [48]: a
Out[48]: array([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22])
/<code>

如上数组 a, NumPy 会将其解读成两个结构,一个 buffer ,还有一个 view 。

buffer 的示意图如下所示:

view 是解释 buffer 的一个结构,比如数据类型, flags 信息等:

<code>In [50]: a.dtype
Out[50]: dtype('int32')

In [51]: a.flags
Out[51]:
C_CONTIGUOUS : True
F_CONTIGUOUS : True
OWNDATA : True
WRITEABLE : True
ALIGNED : True
WRITEBACKIFCOPY : False
UPDATEIFCOPY : False
/<code>

使用 a[6] 访问数组 a 中 index 为 6 的元素。从背后实现看, NumPy 会辅助一个轴,轴的取值为 0 到 11 。

从概念上看,它的示意图如下所示:

所以,借助这个轴 i , a[6] 就会被索引到元素 12,如下所示:

至此,大家要建立一个轴的概念。

接下来,做一次 reshape 变化,变化数组 a 的 shape 为 (2,6):

<code>b = a.reshape(2,6)
/<code>

打印 b:

<code>In [53]: b
Out[53]:
array([[ 0, 2, 4, 6, 8, 10],
[12, 14, 16, 18, 20, 22]])
/<code>

此时,NumPy 会建立两个轴,假设为 i , j , i 的取值为 0 到 1, j 的取值为 0 到 5,示意图如下:

使用 b[1][2] 获取元素到 16

<code>In [54]: b[1][2]
Out[54]: 16
/<code>

两个轴的取值分为 1,2,如下图所示,定位到元素 16

平时,有些读者朋友可能会混淆两个 shape,(12,) 和 (12,1) ,其实前者一个轴,后者两个轴,示意图分别如下:

一个轴,取值从 0 到 11

两个轴,i 轴取值从 0 到 11, j 轴取值从 0 到 0

至此,大家要建立两个轴的概念。

并且,通过上面几幅图看到,无论 shape 如何变化,变化的是轴或称作 view,底下的 buffer 始终未变。

接下来,上升到三个轴,变化数组 a 的 shape 为 (2,3,2) :

<code>c = a.reshape(2,3,2)
/<code>

打印 c:

<code>In [55]: c = a.reshape(2,3,2)

In [56]: c
Out[56]:
array([[[ 0, 2],
[ 4, 6],
[ 8, 10]],

[[12, 14],
[16, 18],
[20, 22]]])
/<code>

数组 c 有三个轴,取值分别为 0 到 1, 0 到 2, 0 到 1,示意图如下所示:

读者们注意体会, i , j , k 三个轴,其值的分布规律。如果去掉 i 轴取值为 1 的单元格后,

实际就对应到数组 c 的前半部分元素:

<code>array([[[ 0, 2],
[ 4, 6],
[ 8, 10]],
/<code>

也就是如下的索引组合 :

至此,三个轴的 reshape 已经讲完。

最后,说一个有意思的问题。

还记得,原始的一维数组 a 吗?它一共有 12 个元素,后来,我们变化它为数组 c ,shape 为 (2,3,2),那么如何升级为 4 维或 任意维呢?

4 维可以为:(1,2,3,2),示意图如下:

看到,轴 i 索引取值只有 0,它被称为自由维度,可以任意插入到原数组的任意轴间。

比如,5 维可以为:(1,2,1,3,2):

至此,你应该完全理解 reshape 操作后的魔法:

buffer 是个一维数组,永远不变;变化的 shape 通过 view 传达;取值仅有 0 的自由轴,能变化出任意维度。

最后补充一点:reshape 操作返回的对象仅仅是原数组的视图,是一个引用,并未发生复制操作,因此 reshape 一个高效的操作。

要想学到别人可能没掌握的本领,只有静下心来,踏实下来。通过日复一日的训练,才能到达理想的彼岸,遇见更好的自己。