内容简介:YOLO,即You Only Look Once(你只看一次)的缩写,是一个基于卷积神经网络(CNN)的物体检测算法。而YOLOv3是YOLO的第3个版本(即YOLO、YOLO是一句美国的俗语,You Only Live Once,人生苦短,及时行乐。本文介绍如何实现YOLO v3算法,Keras框架。这是第2篇,模型。当然还有第3篇,至第n篇,这是一个完整版 :)
YOLO,即You Only Look Once(你只看一次)的缩写,是一个基于卷积神经网络(CNN)的物体检测算法。而YOLOv3是YOLO的第3个版本(即YOLO、 YOLO 9000 、YOLO v3),检测效果,更准更强。
YOLO是一句美国的俗语,You Only Live Once,人生苦短,及时行乐。
本文介绍如何实现YOLO v3算法,Keras框架。这是第2篇,模型。当然还有第3篇,至第n篇,这是一个完整版 :)
第1篇训练:https://juejin.im/post/5b63c0f8518825631e21d6ea
本文的 源码 :https://github.com/SpikeKing/keras-yolo3-detection
模型
在训练中,调用 create_model
方法创建模型,其中,重点分析 create_model
的逻辑。
在 create_model
方法中,创建YOLO v3的网络结构,其中参数:
input_shape anchors num_classes load_pretrained freeze_body weights_path
如下:
def create_model(input_shape, anchors, num_classes, load_pretrained=True, freeze_body=2, weights_path='model_data/yolo_weights.h5'): 复制代码
将参数进行处理:
image_input num_anchors y_true
如下:
h, w = input_shape # 尺寸 image_input = Input(shape=(w, h, 3)) # 图片输入格式 num_anchors = len(anchors) # anchor数量 # YOLO的三种尺度,每个尺度的anchor数,类别数+边框4个+置信度1 y_true = [Input(shape=(h // {0: 32, 1: 16, 2: 8}[l], w // {0: 32, 1: 16, 2: 8}[l], num_anchors // 3, num_classes + 5)) for l in range(3)] 复制代码
其中,真值 y_true
,真值即Ground Truth:
“//”是 Python 语法中的整除符号,通过循环创建3个Input层,组成列表,作为 y_true
,格式如下:
Tensor("input_2:0", shape=(?, 13, 13, 3, 6), dtype=float32) Tensor("input_3:0", shape=(?, 26, 26, 3, 6), dtype=float32) Tensor("input_4:0", shape=(?, 52, 52, 3, 6), dtype=float32) 复制代码
其中,第1位是样本数,第2~3位是特征图的尺寸13x13,第4位是每个图的anchor数,第5位是:类别(n)+4个框值(x,y,w,h)+框的置信度(是否含有物体)。
通过图片输入Input层 image_input
、每个尺度的anchor数 num_anchors//3
、类别数 num_classes
,创建YOLO v3的网络结构,即:
model_body = yolo_body(image_input, num_anchors // 3, num_classes) 复制代码
接着,加载预训练模型:
- 根据预训练模型的地址
weights_path
,加载模型,按名称对应by_name
,略过不匹配skip_mismatch
; - 选择冻结模式:
model_body.layers[i].trainable=True
实现:
if load_pretrained: # 加载预训练模型 model_body.load_weights(weights_path, by_name=True, skip_mismatch=True) if freeze_body in [1, 2]: # Freeze darknet53 body or freeze all but 3 output layers. num = (185, len(model_body.layers) - 3)[freeze_body - 1] for i in range(num): model_body.layers[i].trainable = False # 将其他层的训练关闭 复制代码
接着,设置模型损失层 model_loss
:
- Lambda是Keras的自定义层,输入为
(model_body.output + y_true)
,输出为output_shape=(1,)
; - 层的名字name为
yolo_loss
; - 参数为anchors锚框、类别数
num_classes
,ignore_thresh
是物体置信度损失(object confidence loss)的IoU(Intersection over Union,重叠度)阈值; -
yolo_loss
是核心的损失函数。
实现:
model_loss = Lambda(yolo_loss, output_shape=(1,), name='yolo_loss', arguments={'anchors': anchors, 'num_classes': num_classes, 'ignore_thresh': 0.5} )(model_body.output + y_true) 复制代码
接着,创建最终模型:
- 模型的输入:
model_body
的输入层,即image_input
,和y_true
; - 模型的输出:
model_loss
的输出,一个值,output_shape=(1,)
; - 保存模型的网络图
plot_model
,和打印网络model.summary()
;
即:
model = Model(inputs=[model_body.input] + y_true, outputs=model_loss) # 模型 plot_model(model, to_file=os.path.join('model_data', 'model.png'), show_shapes=True, show_layer_names=True) # 存储网络结构 model.summary() # 打印网络 复制代码
其中, model_body.input
是任意(?)个 (416,416,3)
的图片,即:
Tensor("input_1:0", shape=(?, 416, 416, 3), dtype=float32) 复制代码
y_true
是已标注数据转换的真值结构,即:
[Tensor("input_2:0", shape=(?, 13, 13, 3, 6), dtype=float32), Tensor("input_3:0", shape=(?, 26, 26, 3, 6), dtype=float32), Tensor("input_4:0", shape=(?, 52, 52, 3, 6), dtype=float32)] 复制代码
补 IoU
IoU,即Intersection over Union,用于计算两个图的重叠度,用于计算两个标注框(bounding box)之间的相关度,值越高,相关度越高。在NMS(Non-Maximum Suppression,非极大值抑制)或计算mAP(mean Average Precision)中,都会使用IoU判断两个框的相关性。
如图:
实现:
def bb_intersection_over_union(boxA, boxB): boxA = [int(x) for x in boxA] boxB = [int(x) for x in boxB] xA = max(boxA[0], boxB[0]) yA = max(boxA[1], boxB[1]) xB = min(boxA[2], boxB[2]) yB = min(boxA[3], boxB[3]) interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1) boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1) boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1) iou = interArea / float(boxAArea + boxBArea - interArea) return iou 复制代码
OK, that's all! Enjoy it!
以上就是本文的全部内容,希望本文的内容对大家的学习或者工作能带来一定的帮助,也希望大家多多支持 码农网
猜你喜欢:- 搜狗获全球口语翻译大赛冠军 多模型融合细节打磨成制胜法宝
- 从一道 iOS 面试题到 Swift 对象模型和运行时细节——「iOS 面试之道」勘误
- MQTT Essential 细节笔记总结(深入理解MQTT细节)
- MetInfo 7.0.0 20200326 细节优化补丁,主要优化商城相关细节
- MetInfo7.0.0 20200407 细节优化补丁,修复编辑及手机端细节
- php 的小细节
本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们。