内容简介:在《对于Google、Facebook来说,收集几百万张图片,训练超大规模的深度学习模型,自然不在话下。但是对于个人或者小型企业而言,收集现实世界的数据,特别是带标签的数据,将是一件非常费时费力的事。本文探讨一种技术,在现有数据集的基础上,进行数据增强(data augmentation),增加参与模型训练的数据量,从而提升模型的性能。所谓数据增强,就是采用在原有数据上随机增加抖动和扰动,从而生成新的训练样本,新样本的标签和原始数据相同。这个也很好理解,对于一张标签为“狗”的图片,做一定的模糊、裁剪、变形等
在《 提高模型性能,你可以尝试这几招… 》一文中,我们给出了几种提高模型性能的方法,但这篇文章是在训练数据集不变的前提下提出的优化方案。其实对于深度学习而言,数据量的多寡通常对模型性能的影响更大,所以扩充数据规模一般情况是一个非常有效的方法。
对于Google、Facebook来说,收集几百万张图片,训练超大规模的深度学习模型,自然不在话下。但是对于个人或者小型企业而言,收集现实世界的数据,特别是带标签的数据,将是一件非常费时费力的事。本文探讨一种技术,在现有数据集的基础上,进行数据增强(data augmentation),增加参与模型训练的数据量,从而提升模型的性能。
什么是数据增强
所谓数据增强,就是采用在原有数据上随机增加抖动和扰动,从而生成新的训练样本,新样本的标签和原始数据相同。这个也很好理解,对于一张标签为“狗”的图片,做一定的模糊、裁剪、变形等处理,并不会改变这张图片的类别。数据增强也不仅局限于图片分类应用,比如有如下图所示的数据,数据满足正态分布:
我们在数据集的基础上,增加一些扰动处理,数据分布如下:
数据就在原来的基础上增加了几倍,但整体上仍然满足正态分布。有人可能会说,这样的出来的模型不是没有原来精确了吗?考虑到现实世界的复杂性,我们采集到的数据很难完全满足正态分布,所以这样增加数据扰动,不仅不会降低模型的精确度,然而增强了泛化能力。
对于图片数据而言,能够做的数据增强的方法有很多,通常的方法是:
-
平移
-
旋转
-
缩放
-
裁剪
-
切变(shearing)
-
水平/垂直翻转
-
…
上面几种方法,可能切变(shearing)比较难以理解,看一张图就明白了:
我们要亲自编写这些数据增强算法吗?通常不需要,比如keras就提供了批量处理图片变形的方法。
keras中的数据增强方法
keras中提供了ImageDataGenerator类,其构造方法如下:
ImageDataGenerator(featurewise_center=False, samplewise_center=False, featurewise_std_normalization = False, samplewise_std_normalization = False, zca_whitening = False, rotation_range = 0., width_shift_range = 0., height_shift_range = 0., shear_range = 0., zoom_range = 0., channel_shift_range = 0., fill_mode = 'nearest', cval = 0.0, horizontal_flip = False, vertical_flip = False, rescale = None, preprocessing_function = None, data_format = K.image_data_format(), )
参数很多,常用的参数有:
-
rotation_range: 控制随机的度数范围旋转。
-
width_shift_range和height_shift_range: 分别用于水平和垂直移位。
-
zoom_range: 根据[1 - zoom_range,1 + zoom_range]范围均匀将图像“放大”或“缩小”。
-
horizontal_flip:控制是否水平翻转。
完整的参数说明请参考keras文档。
下面一段代码将1张给定的图片扩充为10张,当然你还可以扩充更多:
image = load_img(args["image"]) image = img_to_array(image) image = np.expand_dims(image, axis=0) aug = ImageDataGenerator(rotation_range=30, width_shift_range=0.1, height_shift_range=0.1, shear_range=0.2, zoom_range=0.2, horizontal_flip=True, fill_mode="nearest") aug.fit(image) imageGen = aug.flow(image, batch_size=1, save_to_dir=args["output"], save_prefix=args["prefix"], save_format="jpeg") total = 0 for image in imageGen: # increment out counter total += 1 if total == 10: break
需要指出的是,上述代码的最后一个迭代是必须的,否在不会在output目录下生成图片,另外output目录必须存在,否则会出现一下错误:
Traceback (most recent call last): File "augmentation_demo.py", line 35, in <module> for image in imageGen: File "/data/ai/anaconda3/envs/keras/lib/python3.6/site-packages/keras_preprocessing/image.py", line 1526, in __next__ return self.next(*args, **kwargs) File "/data/ai/anaconda3/envs/keras/lib/python3.6/site-packages/keras_preprocessing/image.py", line 1704, in next return self._get_batches_of_transformed_samples(index_array) File "/data/ai/anaconda3/envs/keras/lib/python3.6/site-packages/keras_preprocessing/image.py", line 1681, in _get_batches_of_transformed_samples img.save(os.path.join(self.save_to_dir, fname)) File "/data/ai/anaconda3/envs/keras/lib/python3.6/site-packages/PIL/Image.py", line 1947, in save fp = builtins.open(filename, "w+b") FileNotFoundError: [Errno 2] No such file or directory: 'output/image_0_1091.jpeg'
如下一张狗狗的图片:
经过数据增强技术处理之后,可以得到如下10张形态稍微不同的狗狗的图片,这相当于在原有数据集上增加了10倍的数据,其实我们还可以扩充得最多:
数据增强之后的比较
我们以MiniVGGNet模型为例,说明在其在17flowers数据集上进行训练的效果。17flowers是一个非常小的数据集,包含17中品类的花卉图案,每个品类包含80张图片,这对于深度学习而言,数据量实在是太小了。一般而言,要让深度学习模型有一定的精确度,每个类别的图片至少需要1000~5000张。这样的数据集可以很好的说明数据增强技术的必要性。
从网站上下载的17flowers数据,所有的图片都放在一个目录下,而我们通常训练时的目录结构为:
{类别名}/{图片文件}
为此我写了一个 organize_flowers17.py 脚本。
在没有使用数据增强的情况下,在训练数据集和验证数据集上精度、损失随着训练轮次的变化曲线图:
可以看到,大约经过十几轮的训练,在训练数据集上的准确率很快就达到了接近100%,然而在验证数据集上的准确率却无法再上升,只能达到60%左右。这个图可以明显的看出模型出现了非常严重的过拟合。
如果采用数据增强技术呢?曲线图如下:
从图中可以看到,虽然在训练数据集上的准确率有所下降,但在验证数据集上的准确率有比较明显的提升,说明模型的泛化能力有所增强。
也许在我们看来,准确率从60%多增加到70%,只有10%的提升,并不是什么了不得的成绩。但要考虑到我们采用的数据集样本数量实在是太少,能够达到这样的提升已经是非常难得,在实际项目中,有时为了提升1%的准确率,都会花费不少的功夫。
总结
数据增强技术在一定程度上能够提高模型的泛化能力,减少过拟合,但在实际中,我们如果能够收集到更多真实的数据,还是要尽量使用真实数据。另外,数据增强只需应用于训练数据集,验证集上则不需要,毕竟我们希望在验证集上测试真实数据的准确。
以上实例均有完整的代码,点击阅读原文,跳转到我在github上建的示例代码。
另外,我在阅读《Deep Learning for Computer Vision with Python》这本书,在微信公众号后台回复“计算机视觉”关键字,可以免费下载这本书的电子版。
参考阅读
以上所述就是小编给大家介绍的《使用数据增强技术提升模型泛化能力》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!
猜你喜欢:- 深度学习中的Lipschitz约束:泛化与生成模型
- 模型的泛化能力仅和Hessian谱有关吗?
- 深度学习中的Lipschitz约束:泛化与生成模型
- ICLR2020 | 谷歌最新研究:用“复合散度”量化模型合成泛化能力
- 神经网络并不是尚方宝剑,我们需要正视深度 NLP 模型的泛化问题
- 量化深度强化学习算法的泛化能力
本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们。