内容简介:作者:Freemanzxp简介:中科大研二在读,目前在微软亚洲研究院实习,主要研究方向是机器学习。原文:https://blog.csdn.net/zpalyq110/article/details/79527653
作者:Freemanzxp
简介:中科大研二在读,目前在微软亚洲研究院实习,主要研究方向是机器学习。
原文:https://blog.csdn.net/zpalyq110/article/details/79527653
Github:https://github.com/Freemanzxp/GBDT_Simple_Tutorial
本文已授权,未经原作者允许,不得二次转载
这是来自读者的一篇投稿,因为公众号对 Latex 公式支持不是很好,所以可以点击文末 “阅读原文“ 进行阅读。同时也希望觉得有帮助的欢迎到作者的 Github 上 star !
写在前面:去年学习 GBDT 之初,为了加强对算法的理解,整理了一篇笔记形式的文章,发出去之后发现阅读量越来越多,渐渐也有了评论,评论中大多指出来了笔者理解或者编辑的错误,故重新编辑一版文章,内容更加翔实,并且在 GitHub 上实现了和本文一致的 GBDT 简易版(包括回归、二分类、多分类以及可视化),供大家交流探讨。感谢各位的点赞和评论,希望继续指出错误~
Github:https://github.com/Freemanzxp/GBDT_Simple_Tutorial
简介:
GBDT的全称是 Gradient Boosting Decision Tree,梯度提升树,在传统机器学习算法中,GBDT 算得上 TOP3 的算法。想要理解 GBDT 的真正意义,那就必须理解 GBDT 中的 Gradient Boosting 和 Decision Tree 分别是什么?
1. Decision Tree:CART回归树
首先, GBDT 使用的决策树是 CART 回归树 ,无论是处理回归问题还是二分类以及多分类,GBDT 使用的决策树通通都是都是 CART 回归树。
为什么不用 CART 分类树呢?因为 GBDT 每次迭代要拟合的是 梯度值 ,是 连续值 所以要用回归树。
对于回归树算法来说最重要的是 寻找最佳的划分点 ,那么回归树中的可划分点包含了所有特征的所有可取的值。在 分类树中最佳划分点的判别标准是熵或者基尼系数 ,都是用纯度来衡量的,但是在回归树中的样本标签是连续数值,所以再使用熵之类的指标不再合适,取而代之的是 平方误差 ,它能很好的评判拟合程度。
回归树生成算法:
输入:训练数据集 D
:
输出:回归树 f(x)
.
在训练数据集所在的输入空间中,递归的将每个区域划分为两个子区域并决定每个子区域上的输出值,构建二叉决策树:
(1)选择最优切分变量 j
与切分点 s
,求解
遍历变量 j
,对固定的切分变量 j
扫描切分点 s
,选择使得上式达到最小值的对 (j,s)
.
(2)用选定的对 (j,s)
划分区域并决定相应的输出值:
(3)继续对两个子区域调用步骤(1)和(2),直至满足停止条件。
(4)将输入空间划分为 M
个区域
,生成决策树:
2. Gradient Boosting:拟合负梯度
梯度提升树(Grandient Boosting)是提升树(Boosting Tree)的一种改进算法,所以在讲梯度提升树之前先来说一下提升树。
先来个通俗理解:
假如有个人30岁,我们首先用20岁去拟合,发现损失有10岁,这时我们用6岁去拟合剩下的损失,发现差距还有4岁,第三轮我们用3岁拟合剩下的差距,差距就只有一岁了。
如果我们的迭代轮数还没有完,可以继续迭代下面,每一轮迭代,拟合的岁数误差都会减小。
最后将每次拟合的岁数加起来便是模型输出的结果。
提升树算法:
(1)初始化
(2)对
(a)计算残差
(b)拟合残差学习一个回归树,得到
(c)更新
(3)得到回归问题提升树
上面伪代码中的 残差 是什么?
在提升树算法中,假设我们前一轮迭代得到的强学习器是
损失函数是
我们本轮迭代的目标是找到一个弱学习器
最小化让本轮的损失
当采用平方损失函数时
这里,
是当前模型拟合数据的残差(residual)
所以,对于提升树来说只需要简单地拟合当前模型的残差。
回到我们上面讲的那个通俗易懂的例子中,第一次迭代的残差是10岁,第二 次残差4岁……
当损失函数是平方损失和指数损失函数时,梯度提升树每一步优化是很简单的,但是对于一般损失函数而言,往往每一步优化起来不那么容易,针对这一问题,Freidman 提出了梯度提升树算法,这是利用最速下降的近似方法, 其关键是利用损失函数的负梯度作为提升树算法中的残差的近似值。
那么负梯度长什么样呢?
第 t 轮的第 i 个样本的损失函数的负梯度为:
此时不同的损失函数将会得到不同的负梯度,如果选择平方损失
负梯度为
此时我们发现 GBDT 的 负梯度就是残差 ,所以说对于回归问题, 我们要拟合的就是残差 。
那么对于分类问题呢?二分类和多分类的损失函数都是 log(loss)
, 本文以回归问题为例进行讲解 。
3. GBDT算法原理
上面两节分别将 Decision Tree 和 Gradient Boosting 介绍完了,下面将这两部分组合在一起就是我们的 GBDT 了。
GBDT算法:
(1)初始化弱学习器
(2)对 有:
(a)对每个样本
,计算负梯度,即残差
(b)将上步得到的残差作为样本新的真实值,并将数据
作为下棵树的训练数据,得到一颗新的回归树
其对应的叶子节点区域为 。其
中 J 为回归树 t 的叶子节点的个数。
(c)对叶子区域
计算最佳拟合值
(d)更新强学习器
(3)得到最终学习器
4. 实例详解
本人用 python 以及 pandas 库实现 GBDT 的简易版本,在下面的例子中用到的数据都在 github 可以找到,大家可以结合代码和下面的例子进行理解,欢迎 star~
Github:https://github.com/Freemanzxp/GBDT_Simple_Tutorial
数据介绍:
如下表所示:一组数据,特征为年龄、体重,身高为标签值。共有5条数据,前四条为训练样本,最后一条为要预测的样本。
训练阶段:
参数设置:
-
学习率:learning_rate=0.1
-
迭代次数:n_trees=5
-
树的深度:max_depth=3
1.初始化弱学习器:
损失函数为平方损失,因为平方损失函数是一个凸函数,直接求导,倒数等于零,得到 c 。
令导数等于0
所以初始化时,c取值为所有训练样本标签值的均值。
c=(1.1+1.3+1.7+1.8)/4=1.475,此时得到初始学习器
2.对迭代轮数m=1,2,…,M:
由于我们设置了迭代次数:n_trees=5,这里的 M=5
。
计算负梯度,根据上文损失函数为平方损失时,负梯度就是残差残差,再直白一点就是 y 与上一轮得到的学习器 的差值
残差在下表列出:
此时将残差作为样本的真实值来训练弱学习器 ,即下表数据
接着,寻找回归树的最佳划分节点,遍历每个特征的每个可能取值。 从年龄特征的5开始,到体重特征的 70 结束,分别计算分裂后两组数据的平方损失(Square Error), 左节点平方损失, 右节点平方损失,找到使平方损失和 最小的那个划分节点,即为最佳划分节点。
例如:以年龄 7 为划分节点,将小于 7 的样本划分为到左节点,大于等于 7 的样本划分为右节点。左节点包括 x0,右节点包括样本
,
,所有可能划分情况如下表所示:
以上划分点是的总平方损失最小为 0.025 有两个划分点:年龄21和体重60,所以随机选一个作为划分点,这里我们选 年龄21
现在我们的第一棵树长这个样子:
我们设置的参数中树的深度 max_depth=3
,现在树的深度只有 2,需要再进行一次划分,这次划分要对左右两个节点分别进行划分:
对于 左节点 ,只含有 0,1 两个样本,根据下表我们选择 年龄7 划分
对于 右节点 ,只含有 2,3 两个样本,根据下表我们选择 年龄30 划分(也可以选 体重70 )
现在我们的第一棵树长这个样子:
此时我们的树深度满足了设置,还需要做一件事情,给这每个叶子节点分别赋一个参数 γ,来拟合残差。
这里其实和上面初始化学习器是一个道理,平方损失,求导,令导数等于零,化简之后得到每个叶子节点的参数 γ,其实就是标签值的均值。这个地方的标签值不是原始的 y,而是本轮要拟合的标残差 .
根据上述划分结果,为了方便表示,规定从左到右为第 个叶子结点
此时的树长这个样子:
此时可更新强学习器,需要用到参数学习率:learning_rate=0.1,用 lr 表示。
为什么要用学习率呢?这是 Shrinkage 的思想,如果每次都全部加上(学习率为1)很容易一步学到位导致过拟合。
重复此步骤,直到 结束,最后生成5棵树。
下面将展示每棵树最终的结构,这些图都是GitHub上的代码生成的,感兴趣的同学可以去一探究竟
https://github.com/Freemanzxp/GBDT_Simple_Tutorial
第一棵树:
第二棵树:
第三棵树:
第四棵树:
第五棵树:
4.得到最后的强学习器:
5.预测样本5:
在 中,样本4的年龄为25,大于划分节点21岁,又小于30岁,所以被预测为 0.2250 。
在 中,样本4的…此处省略…所以被预测为 0.2025
为什么是 0.2025 ?
这是根据第二颗树得到的,可以 GitHub 简单运行一下代码
在 中,样本4的…此处省略…所以被预测为 0.1823
在 中,样本4的…此处省略…所以被预测为 0.1640
在 中,样本4的…此处省略…所以被预测为 0.1476
最终预测结果:
5. 总结
本文章从GBDT算法的原理到实例详解进行了详细描述,但是目前只写了回归问题,GitHub 上的代码也是实现了回归、二分类、多分类以及树的可视化,希望大家继续批评指正,感谢各位的关注。
Github:
https://github.com/Freemanzxp/GBDT_Simple_Tutorial
参考资料
-
李航 《统计学习方法》
-
Friedman J H . Greedy Function Approximation: A Gradient Boosting Machine[J]. The Annals of Statistics, 2001, 29(5):1189-1232.
欢迎关注我的微信公众号--机器学习与计算机视觉,或者扫描下方的二维码,大家一起交流,学习和进步!
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持 码农网
猜你喜欢:- CRF用过了,不妨再了解下更快的MEMM?
- Unity引擎模块分析—注重效率的你不妨这么做!
- 隐私安全岌岌可危?不妨试试这16个“密门暗器”
- 玩机不知从何下手,不妨收下这份通用「刷机」指南
- Java 11 将至,不妨了解一下 Oracle JDK 之外的版本
- Java 11 将至,不妨了解一下 Oracle JDK 之外的版本
本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们。