如果你还不了解 GBDT,不妨看看这篇文章

栏目: 数据库 · 发布时间: 5年前

内容简介:作者:Freemanzxp简介:中科大研二在读,目前在微软亚洲研究院实习,主要研究方向是机器学习。原文:https://blog.csdn.net/zpalyq110/article/details/79527653

作者:Freemanzxp

简介:中科大研二在读,目前在微软亚洲研究院实习,主要研究方向是机器学习。

原文:https://blog.csdn.net/zpalyq110/article/details/79527653

Github:https://github.com/Freemanzxp/GBDT_Simple_Tutorial

本文已授权,未经原作者允许,不得二次转载

这是来自读者的一篇投稿,因为公众号对 Latex 公式支持不是很好,所以可以点击文末 “阅读原文“ 进行阅读。同时也希望觉得有帮助的欢迎到作者的 Github 上 star !

写在前面:去年学习 GBDT 之初,为了加强对算法的理解,整理了一篇笔记形式的文章,发出去之后发现阅读量越来越多,渐渐也有了评论,评论中大多指出来了笔者理解或者编辑的错误,故重新编辑一版文章,内容更加翔实,并且在 GitHub 上实现了和本文一致的 GBDT 简易版(包括回归、二分类、多分类以及可视化),供大家交流探讨。感谢各位的点赞和评论,希望继续指出错误~

Github:https://github.com/Freemanzxp/GBDT_Simple_Tutorial

简介:

GBDT的全称是 Gradient Boosting Decision Tree,梯度提升树,在传统机器学习算法中,GBDT 算得上 TOP3 的算法。想要理解 GBDT 的真正意义,那就必须理解 GBDT 中的 Gradient Boosting 和 Decision Tree 分别是什么?

1. Decision Tree:CART回归树

首先, GBDT 使用的决策树是 CART 回归树 ,无论是处理回归问题还是二分类以及多分类,GBDT 使用的决策树通通都是都是 CART 回归树。

为什么不用 CART 分类树呢?因为 GBDT 每次迭代要拟合的是 梯度值 ,是 连续值 所以要用回归树。

对于回归树算法来说最重要的是 寻找最佳的划分点 ,那么回归树中的可划分点包含了所有特征的所有可取的值。在 分类树中最佳划分点的判别标准是熵或者基尼系数 ,都是用纯度来衡量的,但是在回归树中的样本标签是连续数值,所以再使用熵之类的指标不再合适,取而代之的是 平方误差 ,它能很好的评判拟合程度。

回归树生成算法:

输入:训练数据集 D :

输出:回归树 f(x) .

在训练数据集所在的输入空间中,递归的将每个区域划分为两个子区域并决定每个子区域上的输出值,构建二叉决策树:

(1)选择最优切分变量 j 与切分点  s ,求解

如果你还不了解 GBDT,不妨看看这篇文章

遍历变量 j ,对固定的切分变量  j 扫描切分点  s ,选择使得上式达到最小值的对  (j,s) .

(2)用选定的对 (j,s) 划分区域并决定相应的输出值:

如果你还不了解 GBDT,不妨看看这篇文章

(3)继续对两个子区域调用步骤(1)和(2),直至满足停止条件。

(4)将输入空间划分为 M 个区域 

如果你还不了解 GBDT,不妨看看这篇文章

,生成决策树:

如果你还不了解 GBDT,不妨看看这篇文章

2. Gradient Boosting:拟合负梯度

梯度提升树(Grandient Boosting)是提升树(Boosting Tree)的一种改进算法,所以在讲梯度提升树之前先来说一下提升树。

先来个通俗理解:

假如有个人30岁,我们首先用20岁去拟合,发现损失有10岁,这时我们用6岁去拟合剩下的损失,发现差距还有4岁,第三轮我们用3岁拟合剩下的差距,差距就只有一岁了。

如果我们的迭代轮数还没有完,可以继续迭代下面,每一轮迭代,拟合的岁数误差都会减小。

最后将每次拟合的岁数加起来便是模型输出的结果。

提升树算法:

(1)初始化  如果你还不了解 GBDT,不妨看看这篇文章

(2)对 如果你还不了解 GBDT,不妨看看这篇文章

(a)计算残差

如果你还不了解 GBDT,不妨看看这篇文章

(b)拟合残差学习一个回归树,得到 如果你还不了解 GBDT,不妨看看这篇文章

(c)更新

如果你还不了解 GBDT,不妨看看这篇文章

(3)得到回归问题提升树

如果你还不了解 GBDT,不妨看看这篇文章

上面伪代码中的 残差 是什么?

在提升树算法中,假设我们前一轮迭代得到的强学习器是

如果你还不了解 GBDT,不妨看看这篇文章

损失函数是

如果你还不了解 GBDT,不妨看看这篇文章

我们本轮迭代的目标是找到一个弱学习器

如果你还不了解 GBDT,不妨看看这篇文章

最小化让本轮的损失

如果你还不了解 GBDT,不妨看看这篇文章

当采用平方损失函数时

如果你还不了解 GBDT,不妨看看这篇文章

这里,

如果你还不了解 GBDT,不妨看看这篇文章

是当前模型拟合数据的残差(residual)

所以,对于提升树来说只需要简单地拟合当前模型的残差。

回到我们上面讲的那个通俗易懂的例子中,第一次迭代的残差是10岁,第二 次残差4岁……

当损失函数是平方损失和指数损失函数时,梯度提升树每一步优化是很简单的,但是对于一般损失函数而言,往往每一步优化起来不那么容易,针对这一问题,Freidman 提出了梯度提升树算法,这是利用最速下降的近似方法, 其关键是利用损失函数的负梯度作为提升树算法中的残差的近似值。

那么负梯度长什么样呢?

第 t 轮的第 i 个样本的损失函数的负梯度为:

如果你还不了解 GBDT,不妨看看这篇文章

此时不同的损失函数将会得到不同的负梯度,如果选择平方损失

如果你还不了解 GBDT,不妨看看这篇文章

负梯度为

如果你还不了解 GBDT,不妨看看这篇文章

此时我们发现 GBDT 的 负梯度就是残差 ,所以说对于回归问题, 我们要拟合的就是残差

那么对于分类问题呢?二分类和多分类的损失函数都是 log(loss)本文以回归问题为例进行讲解

3. GBDT算法原理

上面两节分别将 Decision Tree 和 Gradient Boosting 介绍完了,下面将这两部分组合在一起就是我们的 GBDT 了。

GBDT算法:

(1)初始化弱学习器

如果你还不了解 GBDT,不妨看看这篇文章

(2)对 如果你还不了解 GBDT,不妨看看这篇文章 有:

(a)对每个样本

如果你还不了解 GBDT,不妨看看这篇文章

,计算负梯度,即残差

如果你还不了解 GBDT,不妨看看这篇文章

(b)将上步得到的残差作为样本新的真实值,并将数据

如果你还不了解 GBDT,不妨看看这篇文章

作为下棵树的训练数据,得到一颗新的回归树

如果你还不了解 GBDT,不妨看看这篇文章

其对应的叶子节点区域为 如果你还不了解 GBDT,不妨看看这篇文章 。其

中 J 为回归树 t 的叶子节点的个数。

(c)对叶子区域

如果你还不了解 GBDT,不妨看看这篇文章

计算最佳拟合值

如果你还不了解 GBDT,不妨看看这篇文章

(d)更新强学习器

如果你还不了解 GBDT,不妨看看这篇文章

(3)得到最终学习器

如果你还不了解 GBDT,不妨看看这篇文章

4. 实例详解

本人用 python 以及 pandas 库实现 GBDT 的简易版本,在下面的例子中用到的数据都在 github 可以找到,大家可以结合代码和下面的例子进行理解,欢迎 star~

Github:https://github.com/Freemanzxp/GBDT_Simple_Tutorial

数据介绍:

如下表所示:一组数据,特征为年龄、体重,身高为标签值。共有5条数据,前四条为训练样本,最后一条为要预测的样本。

如果你还不了解 GBDT,不妨看看这篇文章

训练阶段:

参数设置:

  • 学习率:learning_rate=0.1

  • 迭代次数:n_trees=5

  • 树的深度:max_depth=3

1.初始化弱学习器:

如果你还不了解 GBDT,不妨看看这篇文章

损失函数为平方损失,因为平方损失函数是一个凸函数,直接求导,倒数等于零,得到 c

如果你还不了解 GBDT,不妨看看这篇文章

令导数等于0

如果你还不了解 GBDT,不妨看看这篇文章

所以初始化时,c取值为所有训练样本标签值的均值。

c=(1.1+1.3+1.7+1.8)/4=1.475,此时得到初始学习器 如果你还不了解 GBDT,不妨看看这篇文章

如果你还不了解 GBDT,不妨看看这篇文章

2.对迭代轮数m=1,2,…,M:

由于我们设置了迭代次数:n_trees=5,这里的 M=5

计算负梯度,根据上文损失函数为平方损失时,负梯度就是残差残差,再直白一点就是 y 与上一轮得到的学习器 如果你还不了解 GBDT,不妨看看这篇文章 的差值

如果你还不了解 GBDT,不妨看看这篇文章

残差在下表列出:

如果你还不了解 GBDT,不妨看看这篇文章

此时将残差作为样本的真实值来训练弱学习器 如果你还不了解 GBDT,不妨看看这篇文章 ,即下表数据

如果你还不了解 GBDT,不妨看看这篇文章

接着,寻找回归树的最佳划分节点,遍历每个特征的每个可能取值。 从年龄特征的5开始,到体重特征的 70 结束,分别计算分裂后两组数据的平方损失(Square Error), 左节点平方损失, 如果你还不了解 GBDT,不妨看看这篇文章 右节点平方损失,找到使平方损失和 如果你还不了解 GBDT,不妨看看这篇文章 最小的那个划分节点,即为最佳划分节点。

例如:以年龄 7 为划分节点,将小于 7 的样本划分为到左节点,大于等于 7 的样本划分为右节点。左节点包括 x0,右节点包括样本

如果你还不了解 GBDT,不妨看看这篇文章

如果你还不了解 GBDT,不妨看看这篇文章

,所有可能划分情况如下表所示:

如果你还不了解 GBDT,不妨看看这篇文章

以上划分点是的总平方损失最小为 0.025 有两个划分点:年龄21和体重60,所以随机选一个作为划分点,这里我们选 年龄21

现在我们的第一棵树长这个样子:

如果你还不了解 GBDT,不妨看看这篇文章

我们设置的参数中树的深度 max_depth=3 ,现在树的深度只有 2,需要再进行一次划分,这次划分要对左右两个节点分别进行划分:

对于 左节点 ,只含有 0,1 两个样本,根据下表我们选择 年龄7 划分

如果你还不了解 GBDT,不妨看看这篇文章

对于 右节点 ,只含有 2,3 两个样本,根据下表我们选择 年龄30 划分(也可以选 体重70

如果你还不了解 GBDT,不妨看看这篇文章

现在我们的第一棵树长这个样子:

如果你还不了解 GBDT,不妨看看这篇文章

此时我们的树深度满足了设置,还需要做一件事情,给这每个叶子节点分别赋一个参数 γ,来拟合残差。

如果你还不了解 GBDT,不妨看看这篇文章

这里其实和上面初始化学习器是一个道理,平方损失,求导,令导数等于零,化简之后得到每个叶子节点的参数 γ,其实就是标签值的均值。这个地方的标签值不是原始的 y,而是本轮要拟合的标残差 如果你还不了解 GBDT,不妨看看这篇文章 .

根据上述划分结果,为了方便表示,规定从左到右为第 如果你还不了解 GBDT,不妨看看这篇文章 个叶子结点

如果你还不了解 GBDT,不妨看看这篇文章 如果你还不了解 GBDT,不妨看看这篇文章 如果你还不了解 GBDT,不妨看看这篇文章

如果你还不了解 GBDT,不妨看看这篇文章

此时的树长这个样子:

如果你还不了解 GBDT,不妨看看这篇文章

此时可更新强学习器,需要用到参数学习率:learning_rate=0.1,用 lr 表示。

如果你还不了解 GBDT,不妨看看这篇文章

为什么要用学习率呢?这是 Shrinkage 的思想,如果每次都全部加上(学习率为1)很容易一步学到位导致过拟合。

重复此步骤,直到 如果你还不了解 GBDT,不妨看看这篇文章 结束,最后生成5棵树。

下面将展示每棵树最终的结构,这些图都是GitHub上的代码生成的,感兴趣的同学可以去一探究竟

https://github.com/Freemanzxp/GBDT_Simple_Tutorial

第一棵树:

如果你还不了解 GBDT,不妨看看这篇文章

第二棵树:

如果你还不了解 GBDT,不妨看看这篇文章

第三棵树:

如果你还不了解 GBDT,不妨看看这篇文章

第四棵树:

如果你还不了解 GBDT,不妨看看这篇文章

第五棵树:

如果你还不了解 GBDT,不妨看看这篇文章

4.得到最后的强学习器:

如果你还不了解 GBDT,不妨看看这篇文章

5.预测样本5:

如果你还不了解 GBDT,不妨看看这篇文章

如果你还不了解 GBDT,不妨看看这篇文章 中,样本4的年龄为25,大于划分节点21岁,又小于30岁,所以被预测为 0.2250

如果你还不了解 GBDT,不妨看看这篇文章 中,样本4的…此处省略…所以被预测为 0.2025

为什么是 0.2025

这是根据第二颗树得到的,可以 GitHub 简单运行一下代码

如果你还不了解 GBDT,不妨看看这篇文章 中,样本4的…此处省略…所以被预测为 0.1823

如果你还不了解 GBDT,不妨看看这篇文章 中,样本4的…此处省略…所以被预测为 0.1640

如果你还不了解 GBDT,不妨看看这篇文章 中,样本4的…此处省略…所以被预测为 0.1476

最终预测结果:

5. 总结

本文章从GBDT算法的原理到实例详解进行了详细描述,但是目前只写了回归问题,GitHub 上的代码也是实现了回归、二分类、多分类以及树的可视化,希望大家继续批评指正,感谢各位的关注。

Github:

https://github.com/Freemanzxp/GBDT_Simple_Tutorial

参考资料

  1. 李航 《统计学习方法》

  2. Friedman J H . Greedy Function Approximation: A Gradient Boosting Machine[J]. The Annals of Statistics, 2001, 29(5):1189-1232.

欢迎关注我的微信公众号--机器学习与计算机视觉,或者扫描下方的二维码,大家一起交流,学习和进步!

如果你还不了解 GBDT,不妨看看这篇文章


以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

ES6标准入门(第3版)

ES6标准入门(第3版)

阮一峰 / 电子工业出版社 / 2017-9 / 99.00

ES6是下一代JavaScript语言标准的统称,每年6月发布一次修订版,迄今为止已经发布了3个版本,分别是ES2015、ES2016、ES2017。本书根据ES2017标准,详尽介绍了所有新增的语法,对基本概念、设计目的和用法进行了清晰的讲解,给出了大量简单易懂的示例。本书为中级难度,适合那些已经对JavaScript语言有一定了解的读者,可以作为学习这门语言最新进展的工具书,也可以作为参考手册......一起来看看 《ES6标准入门(第3版)》 这本书的介绍吧!

JSON 在线解析
JSON 在线解析

在线 JSON 格式化工具

图片转BASE64编码
图片转BASE64编码

在线图片转Base64编码工具

HTML 编码/解码
HTML 编码/解码

HTML 编码/解码