提高模型准确率:组合模型

栏目: 数据库 · 发布时间: 5年前

内容简介:各位朋友,新年好! 随着春节假期的结束,想必大家陆陆续续返回工作岗位,开始新的一年的拼搏。我也会继续努力,争取在深度学习方面更进一步,接下来,我将继续聊一聊深度学习在计算机视觉中的应用。在前面的《从字面上理解,组合模型并不难解释,简单说,就是为深度学习建立多个模型,然后用多个模型来预测,采取投票或平均法来决定最后的预测结果。稍微想一想,似乎比较好理解,俗话说,

各位朋友,新年好! 随着春节假期的结束,想必大家陆陆续续返回工作岗位,开始新的一年的拼搏。我也会继续努力,争取在深度学习方面更进一步,接下来,我将继续聊一聊深度学习在计算机视觉中的应用。

在前面的《 站在巨人的肩膀上:迁移学习 》和《 再谈迁移学习:微调网络 》两篇文章中,我们介绍了迁移学习的强大之处。然而,人们探索新知识总是永无止境,在提高深度学习模型准确率方面,仍在孜孜不倦的追求着。这篇文章将介绍一种提升模型准确率的方法:组合模型。

从字面上理解,组合模型并不难解释,简单说,就是为深度学习建立多个模型,然后用多个模型来预测,采取投票或平均法来决定最后的预测结果。稍微想一想,似乎比较好理解,俗话说, 三个臭皮匠,顶个诸葛亮 。多个模型投票的结果,应该好于单个模型的准确率。当然,机器学习看起来有些不靠谱(拿概率说事),但还是建立在严密的理论基础之上,组合模型提高准确率如果仅仅建立在一条谚语之上,不足以说服人,也没办法让人接受。

事实上,组合模型是建立在一个称为 琴生不等式 (Jensen’s inequality)之上,该公式以丹麦数学家 约翰·琴生 (Johan Jensen)命名,给出了积分的凸函数值和凸函数的积分值间的关系。有兴趣的同学可以去Google一下,看看这个公式有何神奇之处,我也找了一些资料,然而…没看懂,只了解其大意是说,可能某个的模型的误差低于所有模型的平均值,但由于我们没有可以用来“选择”此模型的标准,所以我们可以确信所有模型的平均值不会比随机选择一个模型差。是不是还是有些晕乎?嗯,这个不重要,我们用实践来检验一下是不是有效吧。

接下来,我们就要准备训练多个机器学习模型。我们也不用把问题复杂化,设计多种网络结构的模型,最简单的方法是,采用相同的网络结构,甚至使用相同的超参数,但训练出不同的参数。闲话少说,直接上代码:

((trainX, trainY), (testX, testY)) = cifar10.load_data()
trainX = trainX.astype("float") / 255.0
testX = testX.astype("float") / 255.0

lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)

labelNames = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "trunk"]

aug = ImageDataGenerator(rotation_range=10, width_shift_range=0.1, height_shift_range=0.1, horizontal_flip=True,
                         fill_mode="nearest")

for i in np.arange(0, args["num"]):
  print("[INFO] training model {}/{}".format(i + 1, args["num"]))
  opt = SGD(lr=0.01, decay=0.01/40, momentum=0.9, nesterov=True)
  model = MiniVGGNet.build(width=32, height=32, depth=3, classes=10)
  model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
  H = model.fit_generator(aug.flow(trainX, trainY, batch_size=64), validation_data=(testX, testY), epochs=40,
                          steps_per_epoch=len(trainX)//64, verbose=1)
  p = [args["models"], "model_{}.model".format(i)]
  model.save(os.path.sep.join(p))

代码比较容易理解,采用cifar10数据集训练,10种类别标签,对输入数据进行了数据扩充(data augmentation),这个数据扩充是随机实时进行,加上训练数据集和验证数据集也是随机划分,所以最后训练出的网络参数有所不同,训练完成之后,将模型序列化到文件,供后面使用。循环num遍,就产生了num个模型。

接下来就是依次加载个模型文件,每个模型分别进行预测,然后取均值:

(testX, testY) = cifar10.load_data()[1]
testX = testX.astype("float") / 255.0
labelNames = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "trunk"]
lb = LabelBinarizer()
testY = lb.fit_transform(testY)

model_paths = os.path.sep.join(args["models"], "*.model")
model_paths = list(glob.glob(model_paths))
models = []

for (i, model_path) in enumerate(model_paths):
  print("[INFO] loading model {}/{}".format(i + 1, len(model_paths)))
  models.append(load_model(model_path))

print("[INFO] evaluating ensemble ...")
predictions = []
for model in models:
  predictions.append(model.predict(testX, batch_size=64))

predictions = np.average(predictions, axis=0)
print(classification_report(testY.argmax(axis=1), predictions.argmax(axis=1), target_names=labelNames))

上面的代码,代码量相对于单个模型而言,并没有增加多少,只是多了一个循环,多了一个取平均值,但训练过程却多了num倍,我用电脑训练5个模型,结果花了一晚上还没有训练完:(

最后测试的结果如何呢?通过组合多个网络的输出,成功将准确度从83%提高到84%,即使这些网络使用完全相同的超参数在同一数据集上进行训练。有数据表明,采用组合模型,通常准确度有1-5%的提升。

看到这儿,你可能会有些失望,费了这么大的劲,好像也没啥提升,但是别忘了,在医疗领域、自动驾驶领域,即使费上好大的力气,准确率能够提升小数点后面几位,都是值得的。就像每年度的kaggle竞赛,人们依然在孜孜不倦的追求着准确率的提升。

以上实例均有完整的代码,点击阅读原文,跳转到我在github上建的示例代码。

另外,我在阅读《Deep Learning for Computer Vision with Python》这本书,在微信公众号后台回复“计算机视觉”关键字,可以免费下载这本书的电子版。

往期回顾


以上所述就是小编给大家介绍的《提高模型准确率:组合模型》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

测试驱动开发的艺术

测试驱动开发的艺术

Lasse Koskela / 李贝 / 人民邮电出版社 / 20101023 / 59.00元

在传统的软件开发中,开发人员对于代码是否正确心中无底,一切依赖于后期的测试环节。极限编程反其道而行之,主张采用测试驱动开发(TDD)的方法,即通过测试定义所要开发的功能的接口,然后实现功能的开发过程。TDD通过不断地测试推动代码的开发,既简化了代码,又保证了软件质量。 本书采用“手把手”的教学方式,通过大量实例来解释TDD,还专门用几章的篇幅来讲解如何为难于测试的技术编写单元测试。全书内容循......一起来看看 《测试驱动开发的艺术》 这本书的介绍吧!

HTML 压缩/解压工具
HTML 压缩/解压工具

在线压缩/解压 HTML 代码

HEX HSV 转换工具
HEX HSV 转换工具

HEX HSV 互换工具

HSV CMYK 转换工具
HSV CMYK 转换工具

HSV CMYK互换工具