内容简介:在MXNet中,其实如果看开源代码数量的话,MXNet已经显得式微,远不如TensorFlow,PyTorch也早已经后来居上。不过据了解,很多公司内部都有基于MXNet自研的框架或平台工具。下面这张图来自LinkedIn上的一个
在MXNet中, Module
提供了训练模型的方便接口。使用 symbol
将计算图建好之后,用 Module
包装一下,就可以通过 fit()
方法对其进行训练。当然,官方提供的接口一般只适合用来训练分类任务,如果是其他任务(如detection, segmentation等),单纯使用 fit()
接口就不太合适。这里把 fit()
代码梳理一下,也是为了后续方便在其基础上实现扩展,更好地用在自己的任务。
其实如果看开源代码数量的话,MXNet已经显得式微,远不如TensorFlow,PyTorch也早已经后来居上。不过据了解,很多公司内部都有基于MXNet自研的框架或平台工具。下面这张图来自LinkedIn上的一个 Slide分享 ,姑且把它贴在下面,算是当前流行框架的一个比较(应该可以把Torch换成PyTorch)。
准备工作
首先,需要将数据绑定到计算图上,并初始化模型的参数,并初始化求解器。这些是求解模型必不可少的。
其次,还会建立训练的metric,方便我们掌握训练进程和当前模型在训练任务的表现。
这些是在为后续迭代进行梯度下降更新做准备。
迭代更新
使用SGD进行训练的时候,我们需要不停地从数据迭代器中获取包含data和label的batch,并将其feed到网络模型中。进行forward computing后进行bp,获得梯度,并根据具体的优化方法(SGD, SGD with momentum, RMSprop等)进行参数更新。
这部分可以抽成:
# in an epoch while not end_epoch: batch = next(train_iter) m.forward_backward(batch) m.update() try: next_batch = next(data_iter) m.prepare(next_batch) except StopIteration: end_epoch = True
metric
在训练的时候,观察输出的各种metric是必不可少的。我们对训练过程的把握就是通过metric给出的信息。通常在分类任务中常用到的metric有Accuracy,TopK-Accuracy以及交叉熵损失等,这些已经在MXNet中有了现成的实现。而在 fit
中,调用了 m.update_metric(eval_metric, data_batch.label)
实现。这里的 eval_metric
就是我们指定的metric,而 label
是batch提供的label。注意,在MXNet中,label一般都是以 list
的形式给出(对应于多任务学习),也就是说这里的label是 list of NDArray
。当自己魔改的时候要注意。
logging
计算了eval_metric等信息,我们需要将其在屏幕上打印出来。MXNet中可以通过callback实现。另外,保存模型checkpoint这样的功能也是通过callback实现的。一种常用的场景是每过若干个batch,做一次logging,打印当前的metric信息,如交叉熵损失降到多少了,准确率提高到多少了等。MXNet会将以下信息打包成 BatchEndParam
类型(其实是一个自定义的 namedtuple
)的变量,包括当前epoch,当前迭代次数,评估的metric。如果你需要更多的信息或者更自由的logging监控,也可以参考代码自己实现。
我们以常用的 Speedometer
看一下如何使用这些信息,其功能如下,将训练的速度和metric打印出来。
Logs training speed and evaluation metrics periodically
PS:这里有个隐藏的坑。MXNet中的 Speedometer
每回调一次,会把 metric
的内容清除。这在训练的时候当然没问题。但是如果是在validation上跑,就会有问题了。这样最终得到的只是最后一个回调周期那些batch的metric,而不是整个验证集上的。如果在 fit
方法中传入了 eval_batch_end_callback
参数就要注意这个问题了。解决办法一是在 Speedometer
实例初始化时传入 auto_reset=False
,另一种干脆就不要加这个参数,默认为 None
好了。同样的问题也发生在调用 Module.score()
方法来获取模型在验证集上metric的时候。
可以在 Speedometer
代码中寻找下面这几行,会更清楚:
if param.eval_metric is not None: name_value = param.eval_metric.get_name_value() if self.auto_reset: param.eval_metric.reset()
在验证集上测试
当在训练集上跑过一个epoch后,如果提供了验证集的迭代器,会在验证集上对模型进行测试。这里,MXNet直接封装了 score()
方法。在 score
中,基本流程和 fit()
相同,只是我们只需要forward computing即可。
附
用了一段时间的MXNet,给我的最大的感觉是MXNet就像一个写计算图的前端,提供了很方便的 python 接口生成静态图,以及很多“可插拔”的插件(虽然可能不是很全,更像是一份guide而不是拿来即用的tool),如上文中的metric等,使其更适合做成流程化的基础DL平台,供给更上层方便地配置使用。缺点就是隐藏了比较多的实现细节(当然,你完全可以从代码中自己学习,比如从 fit()
代码了解神经网络的大致训练流程)。至于MXNet宣扬的诸如速度快,图优化,省计算资源等优点,因为我没有过数据对比,就不说了。
缺点就是写图的时候有时不太灵活(可能也是我写的看的还比较少),即使是和TensorFlow这种同为静态图的DL框架比。另外,貌似MXNet中很多东西都没有跟上最新的论文等,比如Cosine的learning rate decay就没有。Model Zoo也比较少(gluon可能会好一点,Gluon-CV和Gluon-NLP貌似是在搞一些论文复现的工作)。对开发来讲,很多东西都需要阅读代码才能知道是怎么回事,只是读文档的话容易踩坑。
说到这里,感觉MXNet的python训练接口(包括module,optimizer,metric等)更像是一份example代码,是在教你怎么去用MXNet,而不像一个灵活地强大的 工具 箱。当然,很多东西不能得兼,希望MXNet越来越好。
以上所述就是小编给大家介绍的《MXNet fit介绍》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!
猜你喜欢:- ASP.NET Core模块化前后端分离快速开发框架介绍之3、数据访问模块介绍
- 简编漫画介绍WebAssembly
- CGroup 介绍
- CGroup 介绍
- vue初步介绍
- Microbit MicroPython 介绍
本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们。
神经网络与机器学习(原书第3版)
[加] Simon Haykin / 申富饶、徐烨、郑俊、晁静 / 机械工业出版社 / 2011-3 / 79.00元
神经网络是计算智能和机器学习的重要分支,在诸多领域都取得了很大的成功。在众多神经网络著作中,影响最为广泛的是Simon Haykin的《神经网络原理》(第3版更名为《神经网络与机器学习》)。在本书中,作者结合近年来神经网络和机器学习的最新进展,从理论和实际应用出发,全面、系统地介绍了神经网络的基本模型、方法和技术,并将神经网络和机器学习有机地结合在一起。 本书不但注重对数学分析方法和理论的探......一起来看看 《神经网络与机器学习(原书第3版)》 这本书的介绍吧!