内容简介:[MXNet代码剖析] NNVM计算图抽象
设计概要
图的抽象
简而言之就是用 Symbol
同时用来抽象计算图中的 Operator
和 Operand
:
- Variable Symbol
- Functor Symbol(AtomicSymbol), Callable语义
在图的表示中,可以用 Operator
来将一些 Operand
给 Compose
为新的一个 Operand
,也就是图中的节点。最后我们拿到一个 Symbol
,从它就可以回溯出整个图,转换为 Graph
对象来完成建图操作。
不同的 Operator
会有自己的属性,通用的属性定义在 op_attr_types.h
里面,基本上就是不同的类型
Python接口
这里的实现非常漂亮,首先 C API
定义的导出接口就很少,Python代码中仅仅定义最核心的 Symbol
类、运算符重载以及必要的 ctypes
胶水代码。至于如何用 Operator
来建图,则是在C++代码中通过静态定义注册进去,然后在 import nnvm
的时候动态注册进Python。具体实现里面有更多的细节。
实现细节
自动求导
自动求导的原理比较简单,首先为每个操作符定义 Gradient
运算:设当前节点为 \(f()\)
,输入为 y
对当前节点的导数(是一个 Tensor
),那么输出就应该是。注意这个运算过程都是 符号运算
而非数值运算。
类型推导
内存分配
首先贴上架构设计文档: Optimizing Memory Consumption in Deep Learning 。
为计算图的所有节点分配内存的问题可以抽象为:给定一个内存块Request/Free的操作序列,要求满足所有分配的需求,并且使得总的分配内存数量最少。所以这是个NP问题么?
内存分配器会有一个参数 match_range_
,用来表示在 [size/match_range_, size*match_range_]
的范围内来寻找内存块。这里面的 trick
在于先试图分配大的内存块,然后找不到的话再试图分配小的内存块。当然这里并不是真的在分配内存,而只是预先规划我要分配怎样的内存,如果找到小的内存块,肯定不满足我们的需求,我们此时把它放大到我们想要的 size
即可,我们现在是在记录需求,反正最终运行的时候才会真的分配内存。
接下来分析实现细节。整体上分为初始化阶段和按照拓扑序遍历计算图的阶段。
初始化阶段
-
内存分配阶段要依赖于
Shape
和Type
的Inference
,这是显然的,不然分配个毛啊。在注册这个Pass
的时候,会指定这种依赖关系。 -
然后要计算所有非
Variable
节点的出度,作为refcount
;有些操作符具有FIgnoreInputs
属性,并不需要输入数据(只要shape
),比如zeroslike
这样的操作符,所以遍历的时候不要算这部分的引用计数。 - 输出节点要额外加一个引用计数(出度+1),保证在计算图执行到结束的时候也不会回收这些内存。这一点很重要,我就踩过坑。
拓扑序遍历阶段
这一阶段直接是一个 for
循环,以拓扑序遍历整个计算图,循环体内所做的事情如下:
-
首先是检查是否能做
in-place
运算优化。Operator
可以设置自己支持inplace
操作来显式优化内存分配,所以内存分配的时候是先处理能够inplace
的情况,然后再操作正常的内存分配。另外inplace
优化实际上可能是一对多的关系,就是说运算符可以指定一个输入节点的内存可能被复用给多个输出节点
,因为可能有的输出节点只需要shape
信息,不需要数据本身,根本不用给他分配空间。最后,inplace
优化需要满足一个挺复杂的条件:- 输入节点只对应一个输出(出度为1)
- 输出节点有被其他节点引用(否则就不需要为它分配内存,因为根本不用算它)
- 输出节点尚未分配内存
- 输入节点已分配内存(拓扑序遍历的话,这个条件应该是默认满足的)
- 数据类型、大小匹配
- 接下来就开始遍历当前节点的 所有输出 了,把所有还没分配内存的节点记录下来排个序,从小到大依次向内存分配器请求内存即可。
-
然后我们就可以更新引用计数了:把所有 输入节点
(排除
FIgnoreInputs
节点)的refcount - 1
,如果refcount == 0
,就可以释放这个节点的内存。另外这时会遇到有些节点出度本来就是零,这是因为inplace
优化导致的,跳过就行了。 - 最后我们还需要遍历一遍 输出节点 ,把那些出度为零的节点的内存释放掉,标记为 不需要分配内存 ,因为他们根本不被用到,对用户来说处于“不可见状态”。
设备分配
代码风格
-
Operator
的注册机制有点滥用全局状态的感觉,而且有些优化trick
显得意义不大,宏接口设计的倒是比较漂亮。 -
计算图优化的部分采用一个
pass
一个编译单元的结构,代码质量一般,还是觉得在建图的过程中做了一些用处不大的优化trick
,可读性不是很好。 -
Python接口与
ctypes
部分写得很赞,与NNVM_REGISTER_OP
的协作非常漂亮,使得运算符注册能够在import nnvm
的时候自动搞定。就是……有点绕:joy:
以上所述就是小编给大家介绍的《[MXNet代码剖析] NNVM计算图抽象》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!
猜你喜欢:- STL — vector源代码剖析
- 【Java集合源码剖析】ArrayList源码剖析
- Java集合源码剖析:TreeMap源码剖析
- 【剖析 | SOFARPC 框架】系列之 SOFARPC 优雅关闭剖析
- 【剖析 | SOFARPC 框架】系列之 SOFARPC 注解支持剖析
- 【剖析 | SOFARPC 框架】系列之 SOFARPC 泛化调用实现剖析
本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们。