一个例子了解迁移学习

栏目: Python · 发布时间: 5年前

内容简介:对于传统机器学习而言,要求训练样本与测试样本满足独立同分布,而且必须要有足够多的训练样本。而迁移学习能把一个领域(即源领域)的知识,迁移到另外一个领域(即目标领域),目标领域往往只有少量有标签样本,使得目标领域能够取得更好的学习效果。这里基于预训练的卷积神经网络训练一组新参数,然后将其用于分类任务,这样就能共享模型参数,避免了从头开始训练模型的参数,大大减少训练时间。在示例中使用flower17数据集,它是一个包含17种花卉类别的数据集,每个类别有80张图像。收集的花都是英国一些常见的花,这些图像具有大比

迁移学习

对于传统机器学习而言,要求训练样本与测试样本满足独立同分布,而且必须要有足够多的训练样本。而迁移学习能把一个领域(即源领域)的知识,迁移到另外一个领域(即目标领域),目标领域往往只有少量有标签样本,使得目标领域能够取得更好的学习效果。

迁移方式

  • 样本迁移,在源领域中找出与目标领域相似的样本,增加该样本的权重,使其在预测目标与的比重加大。

  • 特征迁移,源领域与目标领域包含共同的交叉特征,通过特征变换将源领域和目标领域的的特征变换到相同空间,使它们具有相同分布。

  • 模型迁移,源领域和目标领域共享模型参数,将源领域已训练好的网络模型应用到目标领域的新问题上。

  • 关系迁移,源领域和目标领域具有某种相似关系,可以将源领域的逻辑关系应用到目标领域中。

模型迁移

这里基于预训练的卷积神经网络训练一组新参数,然后将其用于分类任务,这样就能共享模型参数,避免了从头开始训练模型的参数,大大减少训练时间。

数据集

在示例中使用flower17数据集,它是一个包含17种花卉类别的数据集,每个类别有80张图像。收集的花都是英国一些常见的花,这些图像具有大比例、不同姿态和光线变化等性质。

使用水仙花和款冬这两类花,并且在预训练的VGG16网络之上构建分类器。

一个例子了解迁移学习
image
一个例子了解迁移学习
image

实现

首先导入所有必需的库,包括应用程序、预处理、模型检查点以及相关对象,cv2库和NumPy库用于图像处理和数值的基本操作。

from keras import applications
from keras.preprocessing.image import ImageDataGenerator
from keras import optimizers
from keras.models import Model
from keras.layers import Dropout, Flatten, Dense
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.applications.vgg16 import preprocess_input
import cv2
import numpy as np

定义输入、数据源及与训练参数相关的所有变量。

img_width, img_height = 224, 224
train_data_dir = "data/train"
validation_data_dir = "data/validation"
nb_train_samples = 300
nb_validation_samples = 100
batch_size = 16
epochs = 1

调用VGG16预训练模型,其中不包括顶部的平整化层。冻结不参与训练的层,这里我们冻结前五层,然后添加自定义层,从而创建最终的模型。

model = applications.VGG16(weights="imagenet", include_top=False, input_shape=(img_width, img_height, 3))
for layer in model.layers[:5]:
    layer.trainable = False
x = model.output
x = Flatten()(x)
x = Dense(1024, activation="relu")(x)
x = Dropout(0.5)(x)
x = Dense(1024, activation="relu")(x)
predictions = Dense(2, activation="softmax")(x)
model_final = Model(inputs=model.input, output=predictions)

接着开始编译模型,并为训练、测试数据集创建图像数据增强生成器。

model_final.compile(loss="categorical_crossentropy", optimizer=optimizers.SGD(lr=0.0001, momentum=0.9),
                    metrics=["accuracy"])
train_datagen = ImageDataGenerator(rescale=1. / 255, horizontal_flip=True, fill_mode="nearest", zoom_range=0.3,
                                   width_shift_range=0.3, height_shift_range=0.3, rotation_range=30)
test_datagen = ImageDataGenerator(rescale=1. / 255, horizontal_flip=True, fill_mode="nearest", zoom_range=0.3,
                                  width_shift_range=0.3, height_shift_range=0.3, rotation_range=30)

生成增强后新的数据,根据情况保存模型。

train_generator = train_datagen.flow_from_directory(train_data_dir, target_size=(img_height, img_width),
                                                    batch_size=batch_size, class_mode="categorical")
validation_generator = test_datagen.flow_from_directory(validation_data_dir, target_size=(img_height, img_width),
                                                        class_mode="categorical")
checkpoint = ModelCheckpoint("vgg16_1.h5", monitor='val_acc', verbose=1, save_best_only=True, save_weights_only=False,
                             mode='auto', period=1)
early = EarlyStopping(monitor='val_acc', min_delta=0, patience=10, verbose=1, mode='auto')

开始对模型中新的网络层进行拟合。

model_final.fit_generator(train_generator, samples_per_epoch=nb_train_samples, nb_epoch=epochs,
                          validation_data=validation_generator, nb_val_samples=nb_validation_samples,
                          callbacks=[checkpoint, early])

练完成后用水仙花图像测试这个新模型,输出的正确值应该为接近[1.,0.]的数组。

im = cv2.resize(cv2.imread('data/test/gaff2.jpg'), (img_width, img_height))
im = np.expand_dims(im, axis=0).astype(np.float32)
im = preprocess_input(im)
out = model_final.predict(im)
print(out)
print(np.argmax(out))
 1/18 [>.............................] - ETA: 16:43 - loss: 0.9380 - acc: 0.3750
 2/18 [==>...........................] - ETA: 13:51 - loss: 0.8720 - acc: 0.4062
 3/18 [====>.........................] - ETA: 12:32 - loss: 0.8382 - acc: 0.4167
 4/18 [=====>........................] - ETA: 10:53 - loss: 0.8103 - acc: 0.4663
 5/18 [=======>......................] - ETA: 10:00 - loss: 0.8208 - acc: 0.4606
 6/18 [=========>....................] - ETA: 9:12 - loss: 0.8083 - acc: 0.4567 
 7/18 [==========>...................] - ETA: 8:24 - loss: 0.7891 - acc: 0.4718
 8/18 [============>.................] - ETA: 7:37 - loss: 0.7994 - acc: 0.4832
 9/18 [==============>...............] - ETA: 6:51 - loss: 0.7841 - acc: 0.4850Epoch 00001: val_acc improved from -inf to 0.40000, saving model to vgg16_1.h5

 9/18 [==============>...............] - ETA: 7:16 - loss: 0.7841 - acc: 0.4850 - val_loss: 0.0000e+00 - val_acc: 0.0000e+00[[0.2213877  0.77861226]]

github

https://github.com/sea-boat/DeepLearning-Lab/blob/master/transfer_learning.py

--------------------------------------

跟我交流:

一个例子了解迁移学习

-------------推荐阅读------------

我的开源项目汇总(机器&深度学习、NLP、网络IO、AIML、 mysql 协议、chatbot)

为什么写《Tomcat内核设计剖析》

2017文章汇总——机器学习篇

2017文章汇总——Java及中间件

2017文章汇总——深度学习篇

2017文章汇总——JDK源码篇

2017文章汇总——自然语言处理篇

2017文章汇总——Java并发篇


以上所述就是小编给大家介绍的《一个例子了解迁移学习》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

中国制造2025:产业互联网开启新工业革命

中国制造2025:产业互联网开启新工业革命

夏妍娜、赵胜 / 机械工业出版社 / 2016-2-22 / 49.00

过去20年,是中国消费互联网肆意生长的"黄金20年",诞生了诸如BAT等互联网巨头,而时至今日,风口正逐渐转向了产业互联网。互联网这一摧枯拉朽的飓风,在改造了消费服务业之后,正快速而坚定地横扫工业领域,拉开了产业互联网"关键30年"的大幕。 "中国制造2025"规划,恰是中国政府在新一轮产业革命浪潮中做出的积极举措,是在"新常态"和"供给侧改革"的背景下,强调制造业在中国经济中的基础作用,认......一起来看看 《中国制造2025:产业互联网开启新工业革命》 这本书的介绍吧!

JSON 在线解析
JSON 在线解析

在线 JSON 格式化工具

RGB转16进制工具
RGB转16进制工具

RGB HEX 互转工具

RGB CMYK 转换工具
RGB CMYK 转换工具

RGB CMYK 互转工具