理解并实现 ResNet(Keras)

栏目: 数据库 · 发布时间: 5年前

内容简介:本文为 AI 研习社编译的技术博客,原标题 :Understanding and Coding a ResNet in Keras

理解并实现 ResNet(Keras)

本文为 AI 研习社编译的技术博客,原标题 :

Understanding and Coding a ResNet in Keras

作者 |  Priya Dwivedi @ Deep Learning Analytics

翻译 | linlh、通夜   编辑 | 邓普斯•杰弗、Pita

原文链接:

https://towardsdatascience.com/understanding-and-coding-a-resnet-in-keras-446d7ff84d33

注:本文的相关链接请访问文末【阅读原文】

ResNet 是残差网络(Residual Network)的缩写,是一种作为许多计算机视觉任务主干的经典神经网络。这个模型是2015年ImageNet挑战赛的获胜者,ResNet最根本的突破在于它使得我们可以训练成功非常深的神经网路,如150+层的网络。在ResNet之前,由于梯度消失(vanishing gradients)的问题,训练非常深的神经网络是非常困难的。

AlexNet,2012年ImageNet的获胜者,这个模型就明显开始关注解决仅有8个卷积层的深度学习,VGG网络有19层,Inception或者GoogleNet有22层,ResNet 152有152层。在这篇文章中,我们会编写一个ResNet-50的网络,ResNet 152的小型版本,经常在开始的时候用在迁移学习上。

理解并实现 ResNet(Keras)

深度革命

但是,提升网络的深度并不是简单的将网络层堆叠起来。深层网络很难训练的原因,是因为非常烦人的梯度消失问题——随着梯度反向传播回前面的网络层,重复的乘积操作会使得梯度变得非常小。结果呢,随着网络越来越深,它的性能就变得饱和了,并开始迅速下降。

我是在Andrew Ng的 DeepLearning.AI 课程上学习到关于编写ResNet的内容的,非常推荐大家观看这个课程。

在我的Github repo上,我分享了两个Jupyter Notebook,一个是如DeepLearning.AI中所述,从头开始编码ResNet,另一个在Keras中使用预训练的模型。希望你可以把代码下载下来,并自己试一试。

    残差连接(Skip Connection)——ResNet的强项

ResNet是第一个提出残差连接的概念。下面的图阐述了残差连接。左边的图演示了网络层的堆叠,一层接着一层。在右边的图中,我们仍然看了之前网络层的堆叠,但是我们还将原始的输入添加到卷层单元的输出。

理解并实现 ResNet(Keras)

残差连接示意图 (来自 DeepLearning.AI)

可以写成下面两行代码:

X_shortcut = X # Store the initial value of X in a variable
## Perform convolution + batch norm operations on X

X = Add()([X, X_shortcut]) # SKIP Connection

代码是非常简单,但是这里有一个非常重要的考虑因素——上面的X,X_shortcut是两个矩阵,只有在他们是相同的形状时,你才可以相加。因此,如果卷积+批量规范(batch norm)操作以输出形状相同的方式完成,那么我们可以简单地添加它们,如下所示。

理解并实现 ResNet(Keras)

当 x 和 x_shortcut 是相同的形状

否则,x_shortcut通过选定的卷积层,使得它的输出与卷积块的输出相同,如下所示:

理解并实现 ResNet(Keras)

X_shortcut 通过卷积单元

在Github的Notebook上,identity_block 和convolution_block 两个函数实现了上面的内容。这些函数使用Keras来实现带有ReLU激活函数的Convolution和Batch Norm层。残差连接实现上就是这行代码: X = Add()([X, X_shortcut])。

这里需要注意的一件重要的事情是残差连接是应用在ReLU激活函数之前,正如上图所示。研究人员发现这样可以得到最好的结果。

    为什么要跳过连接?  

这是个有趣的问题。我认为在这里跳过连接有两个原因:

  1. 他们通过允许梯度通过这条可选的捷径来缓解梯度消失的问题

  2. 它们允许模型学习一个恒等函数,该函数确保高层的性能至少与低层一样好,而不是更差。

事实上,由于ResNet跳过连接被用于更多的模型架构中,比如全卷积网络(FCN)和U-Net。它们用于将信息从模型中的较早层传递到较晚层。在这些体系结构中,它们用于将信息从下采样层传递到上采样层。

    测试我们构建的ResNet模型  

然后将笔记本中编码的恒等和卷积块组合起来,创建一个ResNet-50模型,其架构如下:

理解并实现 ResNet(Keras)

ResNet-50模型 

ResNet-50模型由5个阶段组成,每个阶段都有一个卷积和恒等块。每个卷积块有3个卷积层每个单位块也有3个卷积层。ResNet-50有超过2300万个可训练参数。

我已经在我的Github repo中包含的signs数据集上测试了这个模型。这个数据集有对应于6个类的手动图像。我们有1080张火车图像和120张测试图像。

理解并实现 ResNet(Keras)

符号数据集 

我们的ResNet-50经过25个阶段的训练,测试精度达到86%。不错!

    在Keras中用预训练库构建ResNet

我喜欢自己编写ResNet模型,因为它让我更好地理解了我经常在与图像分类,对象定位,分割等相关的许多迁移学习任务中使用的网络。

但是,对于更为常用的做法,在Keras中预训练的ResNet-50模型更快。Keras拥有许多这些骨干模型,其库中提供了Imagenet权重。

理解并实现 ResNet(Keras)

Keras 预训练的模型

我上传了一个Notebook放在Github上,使用的是Keras去加载预训练的模型ResNet-50。你可以用一行的代码来加载这个模型:

base_model = applications.resnet50.ResNet50(weights= None, include_top=False, input_shape= (img_height,img_width,3))

在这里weights=None,因为我想用随机权重初始化模型,就像我在ResNet-50 I编码时所做的那样。或者也可以加载预训练的ImageNet的权重。设置include_top=False,表示不包含原始模型中最后的池化层(pooling)和全连接层(fully connected)。我在ResNet50模型中添加了全局平均池化层(global average pooling)和密集输出层(dense output)。

x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dropout(0.7)(x)
predictions = Dense(num_classes, activation= 'softmax')(x)
model = Model(inputs = base_model.input, outputs = predictions)

从上面的可以看到,Keras提供非常方便的接口去加载预训练模型,但重要的是至少要对ResNet自己编码一次,这样你才能理解这个概念,并且可以将这种学习应用到你正在创建的另一个新架构中。

这个Keras ResNet模型在使用了Adam优化器和0.0001的学习率,训练了100个epoch之后得到75%的正确率。这个正确率比我自己编码的模型要低一些,我想这应该和权重初始化有关。

Keras也提供了非常简单的数据增强(data augmentation)的接口,所以如果有机会,在数据集上试试增强,看看结果能不能得到更好的性能。

   总结

  • ResNet是非常强大的骨干模型(backbone model),经常在许多计算机视觉任务中使用

  • ResNet 使用残差连接(skip connection)将较早的网络层的输出添加到更后面网络层。这有助于缓解梯度消失的问题

  • 你可以使用Keras加载预训练的ResNet-50模型或者使用我分享的代码来自己编写ResNet模型。

我有自己深度学习的咨询工作,喜欢研究有趣的问题。我帮助许多初创公司部署基于AI的创新解决方案。 请访问 http://deeplearninganalytics.org/查看我们。

你也可以在medium上查看我的其他文章:

https://medium.com/@priya.dwivedi

    参考

  • DeepLearning.AI

  • Keras

  • ReNet Paper

想要继续查看该篇文章相关链接和参考文献?

点击底部 【阅读原文】 即可访问:

https://ai.yanxishe.com/page/TextTranslation/1643

滑动查看更多内容

理解并实现 ResNet(Keras)

每天进步一点点

扫码参与每日一题

理解并实现 ResNet(Keras)

今天距离CVPR 2019开幕还有11天

今天距离端午假期还有 2 天 理解并实现 ResNet(Keras)

理解并实现 ResNet(Keras)

扫码查看

行人重识别相关资源大列表

理解并实现 ResNet(Keras)

扫码查看

强化学习自然语言处理资源大列表

<<  滑动查看更多栏目  >>

理解并实现 ResNet(Keras) 点击  阅读原文   ,查看本文更多内容


以上所述就是小编给大家介绍的《理解并实现 ResNet(Keras)》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

首席增长官

首席增长官

张溪梦 / 机械工业出版社 / 2017-11-1 / 69.9

增长是企业永恒的主题,是商业的本质。 人口红利和流量红利的窗口期正在关闭,曾经“流量为王”所带来的成功经验正在失效,所造成的思维逻辑和方法论亟待更新。在互联网下半场,企业要如何保持增长?传统企业是否能跟上数字化转型的脚步,找到新兴业务的增长模式?为什么可口可乐公司用首席增长官取代了首席营销官职位? 数据驱动增长正在成为企业发展的必需理念,首席增长官、增长团队和增长黑客将是未来商业的趋势......一起来看看 《首席增长官》 这本书的介绍吧!

JS 压缩/解压工具
JS 压缩/解压工具

在线压缩/解压 JS 代码

随机密码生成器
随机密码生成器

多种字符组合密码

MD5 加密
MD5 加密

MD5 加密工具