内容简介:我们在去年12月发布了Deep Graph Library (DGL)的首个公开版本。在过去的几个版本的更新中,DGL主要注重框架的易用性,比如怎样设计一系列灵活易用的接口,如何便于大家实现各式各样的图神经网络(GNN)模型,以及怎样和主流深度学习框架(如PyTorch,MXNet等)集成。因为这些设计,让DGL快速地获得了社区的认可和接受。然而天下没有免费的午餐,不同的框架对于相同的运算支持程度不同,并且普遍缺乏图层面上的计算原语,导致了计算速度上的不足。随着DGL接口的逐渐稳定,我们终于可以腾出手来解决
我们在去年12月发布了Deep Graph Library (DGL)的首个公开版本。在过去的几个版本的更新中,DGL主要注重框架的易用性,比如怎样设计一系列灵活易用的接口,如何便于大家实现各式各样的图神经网络(GNN)模型,以及怎样和主流深度学习框架(如PyTorch,MXNet等)集成。因为这些设计,让DGL快速地获得了社区的认可和接受。然而天下没有免费的午餐,不同的框架对于相同的运算支持程度不同,并且普遍缺乏图层面上的计算原语,导致了计算速度上的不足。随着DGL接口的逐渐稳定,我们终于可以腾出手来解决性能问题。即将发布的DGL v0.3版本中,性能问题将得到全面而系统地改善。
相比当前的DGL稳定版本v0.2,DGL v0.3在性能上取得了显著提升。相比v0.2, DGL v0.3训练速度提高了19倍,并且大幅度降低了内存使用量,使得单GPU上能训练的图的大小提高到原来的8倍。比起PyG等其他框架,DGL不但训练更快,而且能够在巨大的图上(5亿节点,250亿边)训练图神经网络。
接下来,我们将介绍DGL v0.3的重要特性之一 — 消息融合(Fused Message Passing)。我们会逐一解释,为什么普通的消息传递无法拓展到大图上以及消息融合是怎么解决这一问题的。更多细节可以参考我们被 ICLR’19 的 RLGM workshop 所收录的论文[1]。
大图训练的性能瓶颈
绝大多数图神经网络模型遵循消息传递的计算范式,用户需要提供两个函数:
-
消息函数:在边上触发,定义了如何计算发送给相邻节点的消息。
-
累和函数:在点上触发,定义了如果在点上累和收到的消息。
下图中,用户自定义的消息函数用 表示。消息函数将点 i 和 j 上的特征 , 以及边i->j上的特征 作为输入,生成边上的消息(黄色方框)。在每个节点上,用户定义的累和函数将消息累和,然后调用另一个用户定义的更新函数 更新节点的特征。
普通的消息传递很容易在DGL中实现:首先,我们通过 send 接口调用消息函数,然后通过recv 接口调用累和函数。下面的例子实现了目前流行的图卷积网络Graph Convolution Network(GCN)。
# 使用自定义消息函数和累和函数计算图卷积 G.update_all(lambda edges: {'m' : edges.src['h']}, lambda nodes: {'h' : sum(nodes.mailbox['m'], axis=1)})
以上的代码非常简洁易懂,但性能却不佳。原因在于消息传递的过程中实际生成了消息张量(message tensor)。消息张量的大小正比于图中边的数量,因而当图增大时,消息张量消耗的内存空间也会显著上升。以 GraphSage 论文中的 Reddit 数据集(23.2万节点,1.14亿边)为例,如果我们用上述代码训练 GCN ,点上的特征会被拷贝成边上的信息,这会导致内存使用量骤增500倍。除了浪费内存,该做法还使得访存变得更为频繁,进而导致 GPU 的利用率降低。
消息融合解决大图训练难题
为了避免生成消息张量带来的额外开销,DGL实现了消息融合技术。DGL将 send 和 recv 接口合并成 send_and_recv(见下图)。DGL的后端通过自己的CUDA代码,在每个GPU线程中将源节点特征载入其本地内存并计算消息函数,然后将计算结果直接累和到目标节点,从而避免生成消息张量。
为实现消息融合,DGL提供了一系列预先定义好的内建函数。尽管这限制了用户对消息函数和累和函数的选择,但DGL提供了非常丰富的内建函数以实现绝大多数GNN模型。当然,用户也可以选择自己定义消息函数和累和函数,这种情况下,DGL不会进行消息融合优化。
另外在 反向传播 中,由于消息张量没有保存,因此需要被重新计算。实际操作中,许多消息函数的求导都不需要使用到消息张量(比如拷贝源节点特征到边上),而我们的实现也利用了这一特性。
在DGL中使用消息融合
使用消息融合非常简单。比如,我们可以用copy_src内建消息函数和sum内建累和函数改写先前的GCN实现:
import dgl.function as fn G = ... # 任意图结构 # 将源节点的特征h拷贝为消息,并在目标节点累和生成新的特征h。 G.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'h'))
图注意力模型 Graph Attention Network (GAT) 则可以用 src_mul_edge 内建消息函数和 sum内建累和函数组合实现:
# 这里假设注意力分数为边上特征e G.update_all(fn.src_mul_edge('h', 'e', 'm'), fn.sum('m', 'h'))
DGL v0.3 将支持以下内建函数:
-
消息函数可以是从源节点特征、边特征、目标节点特征三者中选任意两个进行加、减、乘、除运算。
-
DGL支持特征维度上的广播语义(broadcasting semantics)。这在多头注意力模块中非常常见。
-
累和函数可以是sum, max, min, prod。
我们推荐用户尽可能多的使用DGL的内建函数来定义 图神经网络 ,这样DGL可以利用消息融合来提高性能。虽然这在上手上会有些门槛,但它对性能的提升是非常显著的(详见下一章节)。
性能测试
为了理解消息传递融合带来的性能提升,我们对DGL v0.3和DGL v0.2以及PyG(Pytorch Geometric v1.2.0)进行比较。其中PyG使用了普通的消息传递实现,因此在整个过程中会生成消息张量。
我们首先在主流的数据集上测试了GCN和GAT模型的性能,所有的实验使用了模型论文中的参数设定。实验在AWSp3.2xlarge instance上进行,该机器配备有NVIDIA V100 GPU (16GB 显存)。
从表中可见,即将发布的DGL v0.3在性能上有显著提升,尤其在GAT模型上,训练速度提升了19倍,而这都是因为使用了消息融合技术。在小图上(比如Cora,CiteSeer和PubMed),训练的计算量和内存使用量几乎不随图的大小发生变化,和PyG相比,DGL有微小且固定的额外开销。然而,当在相对较大的图上(比如从Reddit抽取出来的图)训练时,PyG很快便耗尽了内存,而DGL则可以轻松地将数据存储在GPU上进行计算。
我们使用合成的图进一步分析DGL的性能:
我们首先固定图的密度(0.0008),通过调节图的节点数来观察GCN和GAT的训练速度。从图中可见,DGL可以在多达50万节点的图上训练GCN模型,比PyG的最大容量高出一倍。此外,DGL的训练速度比PyG快了3.4倍。
然后我们固定图的节点数,通过调节图的密度来观察训练速度。对GCN和GAT模型,相较PyG,DGL可以支持8倍多的边,并且训练快7.5倍。
我们还在一个中等大小的图上(3.2万节点,密度0.0008)通过调节隐含层的大小来观察训练速度。对于GCN模型,尽管PyG能够支1024个隐含单元,但其训练速度比DGL慢了4倍。对于GAT模型,PyG最多只能支持32个隐含单元,而DGL可以支持到256个。
最后,我们想测试DGL的性能极限,了解DGL在单机情况下能够支持的最大的图的规模。我们在AWSx1.32xlarge (2TB 内存)上用CPU训练GCN。实验表明,DGL可以支持到5亿节点250亿边的图。
接下来期待什么
DGL团队正在积极开发其设计路线图上的功能特性。实际上,DGL项目开始之初,团队成员就考虑到了绝大多数性能优化,比如,DGL一直提倡使用其内建函数而非自定义函数,尽管内建函数只有在消息融合时才能发挥出其优势。以下是DGL团队正在积极拓展的方向:
-
撰写更详细的博客介绍如何在算力强大的CPU机器上复现大图的实验结果。
-
支持异构图结构。
-
用GPU加速图上的遍历和访问。
DGL一直努力接近用户和社区,并且渴望得到用户的反馈。如果您想要尽早尝试即将在v0.3版本发布的新特性,可以克隆DGL的GitHub仓库,切换到kernel分支,然后从源代码编译DGL项目。
1.https://rlgm.github.io/papers/49.pdf
以上就是本文的全部内容,希望本文的内容对大家的学习或者工作能带来一定的帮助,也希望大家多多支持 码农网
猜你喜欢:- 神经网络 – 序列预测LSTM神经网络落后
- 聊聊从脑神经到神经网络
- 神经网络历史以及浅析神经网络与感知机
- 【神经网络】11行Python代码实现的神经网络
- 常见的五种神经网络(三):循环神经网络(上篇)
- 常见的五种神经网络(三):循环神经网络(中篇)
本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们。
解决网页设计一定会遇到的210个问题
2006-4 / 42.00元
如何选择适合、简单、方便、快速的方法来解决您的网页设计问题?不会HTML、JavaScript、CSS也可轻易完成许多网页功能与特效。本书包含上百种HTML、JavaScript、CSS使用应用技巧与盲点解说,包含10个常用表单资料判断函数与特殊技巧,不必修改就可用于任何网页。本书现有的多数网页设计书籍相辅相成,让您事半功倍地完成工作。 许多计算机书籍都是从某个语言或者某个软件的......一起来看看 《解决网页设计一定会遇到的210个问题》 这本书的介绍吧!