内容简介:在《在那篇文章中,我还提到了另外一种迁移学习:微调网络,这篇文章就来谈谈微调网络。所谓微调网络,相当于给预训练的模型做一个“换头术”,即“切掉”最后的全连接层(可以想象为卷积神经网络的“头部”),然后接上一个新的参数随机初始化的全连接层,接下来我们在这个动过“手术”的卷积神经网络上用我们比较小的数据集进行训练。
在《 站在巨人的肩膀上:迁移学习 》一文中,我们谈到了一种迁移学习方法:将预训练的卷积神经网络作为特征提取器,然后使用一个标准的机器学习分类模型(比如Logistic回归),以所提取的特征进行训练,得到分类器,这个过程相当于用预训练的网络取代上一代的手工特征提取方法。这种迁移学习方法,在较小的数据集(比如17flowers)上也能取得不错的准确率。
在那篇文章中,我还提到了另外一种迁移学习:微调网络,这篇文章就来谈谈微调网络。
所谓微调网络,相当于给预训练的模型做一个“换头术”,即“切掉”最后的全连接层(可以想象为卷积神经网络的“头部”),然后接上一个新的参数随机初始化的全连接层,接下来我们在这个动过“手术”的卷积神经网络上用我们比较小的数据集进行训练。
在新模型上进行训练,有几点需要注意:
- 开始训练时,“头部”以下的层(也就是没有被替换的网络层)的参数需要固定(frozen),也就是进行前向计算,但反向传递时不更新参数,训练过程只更新新替换上的全连接层的参数。
- 使用一个非常小的学习率进行训练,比如0.001
- 最后,作为可选,在全连接层的参数学习得差不多的时候,我们可以将“头部”以下的层解冻(unfrozen),再整体训练整个网络。
特征提取和微调网络
对照一下上一篇文章中的特征提取,我们以直观的图形来展现它们之间的不同:
如果我们在VGG16预训练模型上进行特征提取,其结构如下图所示:
对比原模型结构,从最后一个卷积池化层直接输出,即特征提取。而微调网络则如下图所示:
通常情况下,新替换的全连接层参数要比原来的全连接层参数要少,因为我们是在比较小的数据集上进行训练。训练过程通常分两个阶段,第一阶段固定卷积层的参数,第二阶段则全放开:
相比特征提取这种迁移学习方法,网络微调通常能得到更高的准确度。但记住,天下没有免费的午餐这个原则,微调网络需要做更多的工作:
首先训练时间很长,相比特征提取只做前向运算,然后训练一个简单的Logisitic回归算法,速度很快,微调网络因为是在很深的网络模型上训练,特别是第二阶段要进行全面的反向传递,耗时更长。在我的GTX 960显卡上用17flowers数据集,训练了几个小时,还没有训练完,结果我睡着了:(,也不知道最终花了多长时间。
其次需要调整的超参数比较多,比如选择多大的学习率,全连接网络设计多少个节点比较合适,这个依赖经验,但在某些特别的数据集上,可能需要多尝试几次才能得到比较好的结果。
网络层及索引
在“动手术”之前,我们需要了解模型的结构,最起码我们需要知道层的名称及索引,没有这些信息,就如同盲人拿起手术刀。在keras中,要了解层的信息非常简单:
print("[INFO] loading network ...") model = VGG16(weights="imagenet", include_top=args["include_top"] > 0) print("[INFO] showing layers ...") for (i, layer) in enumerate(model.layers): print("[INFO] {}\t{}".format(i, layer.__class__.__name__)) 复制代码
VGG16的模型结构如下:
可以看到第20 ~ 22层为全连接层,这也是微调网络要替换的层。
网络“换头术”
首先,我们定义一组全连接层:INPUT => FC => RELU => DO => FC => SOFTMAX。相比VGG16中的全连接层,这个更加简单,参数更少。然后将基本模型的输出作为模型的输入,完成拼接。这个拼接在keras中也相当简单。
class FCHeadNet: @staticmethod def build(base_model, classes, D): # initialize the head model that will be placed on top of # the base, then add a FC layer head_model = base_model.output head_model = Flatten(name="flatten")(head_model) head_model = Dense(D, activation="relu")(head_model) head_model = Dropout(0.5)(head_model) # add a softmax layer head_model = Dense(classes, activation="softmax")(head_model) return head_model 复制代码
因为在VGG16的构造函数中有一个include_top参数,可以决定是否包含头部的全连接层,所以这个“换头”步骤相当简单:
base_model = VGG16(weights="imagenet", include_top=False, input_tensor=Input(shape=(224, 224, 3))) # initialize the new head of the network, a set of FC layers followed by a softmax classifier head_model = FCHeadNet.build(base_model, len(class_names), 256) # place the head FC model on top of the base model -- this will become the actual model we will train model = Model(inputs=base_model.input, outputs=head_model) 复制代码
这时得到的model就是经过“换头术”的网络模型。
训练
微调网络的训练和之前谈到的模型训练过程差不多,只是多了一个freeze层的动作,实际上是进行两个训练过程。如何固定层的参数呢?一句话就可以搞定:
for layer in base_model.layers: layer.trainable = False 复制代码
“解冻”类似,只是layer.trainable值设为True。
为了更快的收敛,尽快的学习到全连接层的参数,在第一阶段建议采用RMSprop优化器。但学习率需要选择一个比较小的值,例如0.001。
在经过一个相当长时间的训练之后,新模型在17flowers数据集上的结果如下:
[INFO] evaluating after fine-tuning ... precision recall f1-score support bluebell 0.95 0.95 0.95 20 buttercup 1.00 0.90 0.95 20 coltsfoot 0.95 0.91 0.93 22 cowslip 0.93 0.87 0.90 15 crocus 1.00 1.00 1.00 23 daffodil 0.92 1.00 0.96 23 daisy 1.00 0.94 0.97 16 dandelion 0.94 0.94 0.94 16 fritillary 1.00 0.95 0.98 21 iris 0.96 0.96 0.96 27 lilyvalley 0.94 0.89 0.91 18 pansy 0.90 0.95 0.92 19 snowdrop 0.86 0.95 0.90 20 sunflower 0.95 1.00 0.98 20 tigerlily 0.96 0.96 0.96 23 tulip 0.70 0.78 0.74 18 windflower 1.00 0.95 0.97 19 micro avg 0.94 0.94 0.94 340 macro avg 0.94 0.93 0.94 340 weighted avg 0.94 0.94 0.94 340 复制代码
相比特征提取这种迁移学习,准确率有了相当可观的提升。
小结
网络微调是一项非常强大的技术,我们无需从头开始训练整个网络。相反,我们可以利用预先存在的网络架构,例如在ImageNet数据集上训练的最先进模型,该模型由丰富的过滤器组成。使用这些过滤器,我们可以“快速启动”我们的学习,使我们能够进行网络手术,最终得到更高精度的迁移学习模型,而不是从头开始训练,而且工作量少。
以上实例均有完整的代码,点击阅读原文,跳转到我在github上建的示例代码。
另外,我在阅读《Deep Learning for Computer Vision with Python》这本书,在微信公众号后台回复“计算机视觉”关键字,可以免费下载这本书的电子版。
以上所述就是小编给大家介绍的《再谈迁移学习:微调网络》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!
猜你喜欢:- layui-v2.4.5 兼容性微调
- 微软再次修改 Visual Studio 图标,并微调用户界面
- 微软再次修改 Visual Studio 图标,并微调用户界面
- Qt 5.9 Beta 发布,将微调 Qt 的发布流程
- 连载二:PyCon2018|用slim微调PNASNet模型(附源码)
- 【火炉炼AI】深度学习007-Keras微调进一步提升性能
本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们。