探索 YOLO v3 实现 - 训练 1

栏目: 编程工具 · 发布时间: 6年前

内容简介:YOLO,即You Only Look Once(你只看一次)的缩写,是一个基于卷积神经网络(CNN)的物体检测算法。而YOLO v3是YOLO的第3个版本(即YOLO、YOLO是一句美国的俗语,You Only Live Once,人生苦短,及时行乐。本文介绍如何实现YOLO v3算法。这是第1篇,训练,当然还有第2篇。

YOLO,即You Only Look Once(你只看一次)的缩写,是一个基于卷积神经网络(CNN)的物体检测算法。而YOLO v3是YOLO的第3个版本(即YOLO、 YOLO 9000 、YOLO v3),检测效果,更准更强。

YOLO是一句美国的俗语,You Only Live Once,人生苦短,及时行乐。

本文介绍如何实现YOLO v3算法。这是第1篇,训练,当然还有第2篇。

1. 参数

模型的训练参数,5个参数:

(1) 已标注框的图片数据集,格式如下:

图片的位置 框的4个坐标(xmin,ymin,xmax,ymax,label_id)...
dataset/image.jpg 788,351,832,426,0 805,208,855,270,0
复制代码

(2) 标注框类别的汇总,即数据集所标注物体的类别列表,如下:

aeroplane
bicycle
bird
...
复制代码

(3) 预训练模型,用于迁移学习(Transfer Learning)中的微调(Fine Tune),支持使用已训练完成的COCO模型参数,即:

pretrained_path = 'model_data/yolo_weights.h5'
复制代码

(4) 预测特征图(Prediction Feature Map)的anchor框(anchor box)集合:

  • 3个尺度(scale)的特征图,每个特征图3个anchor框,共9个框,从小到大排列;
  • 1~3是大尺度(52x52)特征图所使用的,4~6是中尺度(26x26),7~9是小尺度(13x13);
  • 大尺度特征图检测小物体,小尺度检测大物体;
  • 9个anchor来源于边界框(Bounding Box)的k-means聚类。

其中,COCO的anchors,如下:

10,13,  16,30,  33,23,  30,61,  62,45,  59,119,  116,90,  156,198,  373,326
复制代码

(5) 图片输入尺寸,默认为416x416。

  • 图片尺寸满足 32 的倍数,在DarkNet网络中,含有5次步长为2的降采样卷积( 32=2^5 ),其中,卷积操作如下:
x = DarknetConv2D_BN_Leaky(num_filters, (3, 3), strides=(2, 2))(x)
复制代码
  • 在最底层时,特征图尺寸需要满足为 奇数 ,如13,以保证中心点落在唯一框中。当为偶数时,则导致中心点落在中心的4个框中。

2. 创建模型

创建YOLOv3的网络模型,输入:

  • input_shape:图片尺寸;
  • anchors:9个anchor box;
  • num_classes:类别数;
  • freeze_body:网络冻结模式,1是冻结DarkNet53模型,2是全部冻结保留最后3层;
  • weights_path:预训练模型的权重

实现:

model = create_model(input_shape, anchors, num_classes,
                     freeze_body=2,
                     weights_path=pretrained_path)
复制代码

冻结:网络的最后3层,通过1x1的卷积,输出3个尺度的预测值。

实现:

out_filters = num_anchors * (num_classes + 5)
// ...
DarknetConv2D(out_filters, (1, 1)
复制代码

即:

conv2d_59 (Conv2D)      (None, 13, 13, 18)   18450       leaky_re_lu_58[0][0]    
conv2d_67 (Conv2D)      (None, 26, 26, 18)   9234        leaky_re_lu_65[0][0]    
conv2d_75 (Conv2D)      (None, 52, 52, 18)   4626        leaky_re_lu_72[0][0]    
复制代码

3. 样本数量

样本洗牌(shuffle),将数据集拆分为10份,训练9份,验证1份。

实现:

val_split = 0.1  # 训练和验证的比例
with open(annotation_path) as f:
    lines = f.readlines()
np.random.seed(47)
np.random.shuffle(lines)
np.random.seed(None)
num_val = int(len(lines) * val_split)  # 验证集数量
num_train = len(lines) - num_val  # 训练集数量
复制代码

4. 第1阶段训练

第1阶段,冻结部分网络,训练底层参数。

  • 优化器使用常见的Adam;
  • 损失函数,直接使用,模型的输出 y_pred ,忽略真值 y_true

实现:

model.compile(optimizer=Adam(lr=1e-3), loss={
    # 使用定制的 yolo_loss Lambda层
    'yolo_loss': lambda y_true, y_pred: y_pred})  # 损失函数
复制代码

对于损失函数 yolo_loss ,以及 y_truey_pred

y_true 当成一个输入,构成多输入模型,把loss写成层(Lambda层),作为最后的输出。这样,构建模型的时候,就只需要将模型的输出(output)定义为loss即可。而编译(compile)的时候,直接将loss设置为 y_pred ,因为模型的输出就是loss,即 y_pred 就是loss,因而无视 y_true 。训练的时候,随便添加一个符合形状的 y_true 数组即可。

关于 Python 的Lambda表达式:

f = lambda y_true, y_pred: y_pred
print(f(1, 2))  # 输出2
复制代码

模型fit数据,使用数据生成包装器( data_generator_wrapper ),按批次生成训练和验证数据。最终,模型model存储权重。实现如下:

batch_size = 32  # batch
model.fit_generator(data_generator_wrapper(lines[:num_train], batch_size, input_shape, anchors, num_classes),
                    steps_per_epoch=max(1, num_train // batch_size),
                    validation_data=data_generator_wrapper(
                        lines[num_train:], batch_size, input_shape, anchors, num_classes),
                    validation_steps=max(1, num_val // batch_size),
                    epochs=50,
                    initial_epoch=0,
                    callbacks=[logging, checkpoint])
# 存储最终的参数,再训练过程中,也通过回调存储
model.save_weights(log_dir + 'trained_weights_stage_1.h5')  
复制代码

在训练过程中,也会存储模型的参数,只存储权重( save_weights_only ),只存储最优结果( save_best_only ),每隔3个epoch存储一次( period ),即:

checkpoint = ModelCheckpoint(log_dir + 'ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5',
                             monitor='val_loss', save_weights_only=True,
                             save_best_only=True, period=3)  # 只存储weights权重
复制代码

5. 第2阶段训练

第2阶段,使用第1阶段已训练的网络权重,继续训练:

  • 将全部的参数都设置为可训练,而第1阶段则是冻结(freeze)部分参数;
  • 优化器,仍是Adam,只是学习率(lr)有所下降,从1e-3减少至1e-4,细腻地学习最优参数;
  • 损失函数,仍是只使用 y_pred ,忽略 y_true

实现:

for i in range(len(model.layers)):
    model.layers[i].trainable = True

model.compile(optimizer=Adam(lr=1e-4),
              loss={'yolo_loss': lambda y_true, y_pred: y_pred})
复制代码

第2阶段的模型fit数据,与第1阶段类似,从第50个epoch开始,一直训练到第100个epoch,触发条件则提前终止。额外增加了两个回调 reduce_lrearly_stopping

  • reduce_lr :当评价指标不在提升时,减少学习率,每次减少10%(factor),当学习率3次未减少(patience)时,终止训练。
  • early_stopping :验证集准确率,连续增加小于0( min_delta )时,持续10个epoch( patience ),则终止训练。

实现:

reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=3, verbose=1)  # 当评价指标不在提升时,减少学习率
early_stopping = EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=1)  # 验证集准确率,下降前终止

batch_size = 32
model.fit_generator(data_generator_wrapper(lines[:num_train], batch_size, input_shape, anchors, num_classes),
                    steps_per_epoch=max(1, num_train // batch_size),
                    validation_data=data_generator_wrapper(lines[num_train:], batch_size, input_shape, anchors,
                                                           num_classes),
                    validation_steps=max(1, num_val // batch_size),
                    epochs=100,
                    initial_epoch=50,
                    callbacks=[logging, checkpoint, reduce_lr, early_stopping])
model.save_weights(log_dir + 'trained_weights_final.h5')
复制代码

至此,在第2阶段训练完成之后,输出的网络参数,就是最终的模型参数。

补充1. K-Means

K-Means算法是聚类算法,将一组数据划分为多个组(group),每个组都含有一个中心。

YOLOv3,获取数据集中的anchor box,与这个k-means实例类似,即将各种框聚类为9类,获取9个聚类中心,作为9个anchor box,从小到大排列。

模拟K-Means算法:

  1. 创建测试点,X是数据,y是标签,如X:(300,2), y:(300,);
  2. 将数据聚类为9类;
  3. 输入数据X,训练;
  4. 预测X的类别,为 y_kmeans
  5. 使用scatter绘制散点图,颜色范围是viridis;
  6. 获取聚类中心 cluster_centers_ ,以黑色(black)点表示;

源码:

import matplotlib.pyplot as plt
import seaborn as sns
sns.set()  # for plot styling
from sklearn.cluster import KMeans
from sklearn.datasets.samples_generator import make_blobs


def test_of_k_means():
    # 创建测试点,X是数据,y是标签,X:(300,2), y:(300,)
    X, y_true = make_blobs(n_samples=300, centers=9, cluster_std=0.60, random_state=0)
    kmeans = KMeans(n_clusters=9)  # 将数据聚类
    kmeans.fit(X)  # 数据X
    y_kmeans = kmeans.predict(X)  # 预测

    # 颜色范围viridis: https://matplotlib.org/examples/color/colormaps_reference.html
    plt.scatter(X[:, 0], X[:, 1], c=y_kmeans, s=20, cmap='viridis')  # c是颜色,s是大小

    centers = kmeans.cluster_centers_  # 聚类的中心
    plt.scatter(centers[:, 0], centers[:, 1], c='black', s=40, alpha=0.5)  # 中心点为黑色

    plt.show()  # 展示


if __name__ == '__main__':
    test_of_k_means()
复制代码

输出

![K-Means](quiver-image-url/89B03173B750CE7B3462ECA578F6C40E.jpg =400x480)

补充2. EarlyStopping

EarlyStopping是Callback(回调类)的子类,Callback用于指定在每个阶段开始和结束的时候执行的操作。在Callback中,有一些已经实现的简单接口,如 accval_acclossval_loss 等,还有一些复杂接口,如 ModelCheckpoint (用于存储模型参数)和 TensorBoard (用于画图)。

常见的Callback回调接口:

def on_epoch_begin(self, epoch, logs=None):
def on_epoch_end(self, epoch, logs=None):
def on_batch_begin(self, batch, logs=None):
def on_batch_end(self, batch, logs=None):
def on_train_begin(self, logs=None):
def on_train_end(self, logs=None):
复制代码

EarlyStopping则是用于提前停止训练的Callback。具体地,当训练或验证集中的loss不再减小,即减小的程度小于某个阈值时,停止训练,提高调参效率,避免浪费资源。

在model的fit数据中,设置callbacks回调,列表形式,支持设置多个,如:

callbacks=[logging, checkpoint, reduce_lr, early_stopping]
复制代码

EarlyStopping的参数:

  • monitor:监控数据的类型,支持acc、 val_acc 、loss、 val_loss 等;
  • min_delta :停止的阈值,与mode参数配合,增加或下降最少的阈值;
  • mode:min是最少,max是最多,auto是自动,与 min_delta 配合;
  • patience:达到阈值之后,能够容忍的epoch数,避免停止过早;
  • verbose:日志的繁杂程度,值越大,输出的信息越多。

min_delta 和patience需要相互配合,避免模型停止在抖动过程中,在设置的时候,需要相互协调。一般而言, min_delta 降低,patience适当减少; min_delta 增加,则patience适当延长。

实例:

early_stopping = EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=1)
复制代码

OK, that's all! Enjoy it!


以上所述就是小编给大家介绍的《探索 YOLO v3 实现 - 训练 1》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

HTML5

HTML5

Matthew David / Focal Press / 2010-07-29 / USD 39.95

Implement the powerful new multimedia and interactive capabilities offered by HTML5, including style control tools, illustration tools, video, audio, and rich media solutions. Understand how HTML5 is ......一起来看看 《HTML5》 这本书的介绍吧!

在线进制转换器
在线进制转换器

各进制数互转换器

XML、JSON 在线转换
XML、JSON 在线转换

在线XML、JSON转换工具

UNIX 时间戳转换
UNIX 时间戳转换

UNIX 时间戳转换