梯度提升算法决策过程的逐步可视化
大家好,我们又见面了啊~本文《梯度提升算法决策过程的逐步可视化》的内容中将会涉及到等等。如果你正在学习科技周边相关知识,欢迎关注我,以后会给大家带来更多科技周边相关文章,希望我们能一起进步!下面就开始本文的正式内容~
梯度提升算法是最常用的集成机器学习技术之一,该模型使用弱决策树序列来构建强学习器。这也是XGBoost和LightGBM模型的理论基础,所以在这篇文章中,我们将从头开始构建一个梯度增强模型并将其可视化。
梯度提升算法介绍
梯度提升算法(Gradient Boosting)是一种集成学习算法,它通过构建多个弱分类器,然后将它们组合成一个强分类器来提高模型的预测准确率。
梯度提升算法的原理可以分为以下几个步骤:
- 初始化模型:一般来说,我们可以使用一个简单的模型(比如说决策树)作为初始的分类器。
- 计算损失函数的负梯度:计算出每个样本点在当前模型下的损失函数的负梯度。这相当于是让新的分类器去拟合当前模型下的误差。
- 训练新的分类器:用这些负梯度作为目标变量,训练一个新的弱分类器。这个弱分类器可以是任意的分类器,比如说决策树、线性模型等。
- 更新模型:将新的分类器加入到原来的模型中,可以用加权平均或者其他方法将它们组合起来。
- 重复迭代:重复上述步骤,直到达到预设的迭代次数或者达到预设的准确率。
由于梯度提升算法是一种串行算法,所以它的训练速度可能会比较慢,我们以一个实际的例子来介绍:
假设我们有一个特征集Xi和值Yi,要计算y的最佳估计
我们从y的平均值开始
每一步我们都想让F_m(x)更接近y|x。
在每一步中,我们都想要F_m(x)一个更好的y给定x的近似。
首先,我们定义一个损失函数
然后,我们向损失函数相对于学习者Fm下降最快的方向前进:
因为我们不能为每个x计算y,所以不知道这个梯度的确切值,但是对于训练数据中的每一个x_i,梯度完全等于步骤m的残差:r_i!
所以我们可以用弱回归树h_m来近似梯度函数g_m,对残差进行训练:
然后,我们更新学习器
这就是梯度提升,我们不是使用损失函数相对于当前学习器的真实梯度g_m来更新当前学习器F_{m},而是使用弱回归树h_m来更新它。
也就是重复下面的步骤
1、计算残差:
2、将回归树h_m拟合到训练样本及其残差(x_i, r_i)上
3、用步长alpha更新模型
看着很复杂对吧,下面我们可视化一下这个过程就会变得非常清晰了
决策过程可视化
这里我们使用sklearn的moons 数据集,因为这是一个经典的非线性分类数据
import numpy as np import sklearn.datasets as ds import pandas as pd import matplotlib.pyplot as plt import matplotlib as mpl from sklearn import tree from itertools import product,islice import seaborn as snsmoonDS = ds.make_moons(200, noise = 0.15, random_state=16) moon = moonDS[0] color = -1*(moonDS[1]*2-1) df =pd.DataFrame(moon, columns = ['x','y']) df['z'] = color df['f0'] =df.y.mean() df['r0'] = df['z'] - df['f0'] df.head(10)
让我们可视化数据:
下图可以看到,该数据集是可以明显的区分出分类的边界的,但是因为他是非线性的,所以使用线性算法进行分类时会遇到很大的困难。
那么我们先编写一个简单的梯度增强模型:
def makeiteration(i:int): """Takes the dataframe ith f_i and r_i and approximated r_i from the features, then computes f_i+1 and r_i+1""" clf = tree.DecisionTreeRegressor(max_depth=1) clf.fit(X=df[['x','y']].values, y = df[f'r{i-1}']) df[f'r{i-1}hat'] = clf.predict(df[['x','y']].values) eta = 0.9 df[f'f{i}'] = df[f'f{i-1}'] + eta*df[f'r{i-1}hat'] df[f'r{i}'] = df['z'] - df[f'f{i}'] rmse = (df[f'r{i}']**2).sum() clfs.append(clf) rmses.append(rmse)
上面代码执行3个简单步骤:
将决策树与残差进行拟合:
clf.fit(X=df[['x','y']].values, y = df[f'r{i-1}']) df[f'r{i-1}hat'] = clf.predict(df[['x','y']].values)
然后,我们将这个近似的梯度与之前的学习器相加:
df[f'f{i}'] = df[f'f{i-1}'] + eta*df[f'r{i-1}hat']
最后重新计算残差:
df[f'r{i}'] = df['z'] - df[f'f{i}']
步骤就是这样简单,下面我们来一步一步执行这个过程。
第1次决策
Tree Split for 0 and level 1.563690960407257
第2次决策
Tree Split for 1 and level 0.5143677890300751
第3次决策
Tree Split for 0 and level -0.6523728966712952
第4次决策
Tree Split for 0 and level 0.3370491564273834
第5次决策
Tree Split for 0 and level 0.3370491564273834
第6次决策
Tree Split for 1 and level 0.022058885544538498
第7次决策
Tree Split for 0 and level -0.3030575215816498
第8次决策
Tree Split for 0 and level 0.6119407713413239
第9次决策
可以看到通过9次的计算,基本上已经把上面的分类进行了区分
我们这里的学习器都是非常简单的决策树,只沿着一个特征分裂!但整体模型在每次决策后边的越来越复杂,并且整体误差逐渐减小。
plt.plot(rmses)
这也就是上图中我们看到的能够正确区分出了大部分的分类
如果你感兴趣可以使用下面代码自行实验:
https://github.com/trenaudie/GradientBoostingVisualized/blob/main/fromScratch.ipynb
今天关于《梯度提升算法决策过程的逐步可视化》的内容就介绍到这里了,是不是学起来一目了然!想要了解更多关于机器学习,梯度提升算法的内容请关注golang学习网公众号!

- 上一篇
- 利用ChatGPT热潮,企业软件初创公司竞相融资

- 下一篇
- 超5800亿美元!微软谷歌神仙打架,让英伟达市值飙升,约为5个英特尔
-
- 科技周边 · 人工智能 | 2小时前 |
- PerplexityAI能分析地壳运动吗?
- 325浏览 收藏
-
- 科技周边 · 人工智能 | 2小时前 |
- Android集成MLKit,AI功能实战教程
- 319浏览 收藏
-
- 科技周边 · 人工智能 | 2小时前 |
- AI剪辑10分钟生成短视频全解析
- 425浏览 收藏
-
- 科技周边 · 人工智能 | 2小时前 |
- 2025上半年自主品牌销量排名小米SU7第五
- 351浏览 收藏
-
- 科技周边 · 人工智能 | 2小时前 |
- Deepseek+Descript,专业剪辑新体验
- 413浏览 收藏
-
- 科技周边 · 人工智能 | 2小时前 |
- HuggingFace模型使用与加载教程
- 142浏览 收藏
-
- 科技周边 · 人工智能 | 2小时前 | 视觉设计 DecktopusAI 活动报名率 邀请页 智能内容生成
- DecktopusAI如何提升邀请页转化率
- 390浏览 收藏
-
- 科技周边 · 人工智能 | 2小时前 |
- 7月汽车产销超259万,新能源车出口领先
- 234浏览 收藏
-
- 前端进阶之JavaScript设计模式
- 设计模式是开发人员在软件开发过程中面临一般问题时的解决方案,代表了最佳的实践。本课程的主打内容包括JS常见设计模式以及具体应用场景,打造一站式知识长龙服务,适合有JS基础的同学学习。
- 542次学习
-
- GO语言核心编程课程
- 本课程采用真实案例,全面具体可落地,从理论到实践,一步一步将GO核心编程技术、编程思想、底层实现融会贯通,使学习者贴近时代脉搏,做IT互联网时代的弄潮儿。
- 511次学习
-
- 简单聊聊mysql8与网络通信
- 如有问题加微信:Le-studyg;在课程中,我们将首先介绍MySQL8的新特性,包括性能优化、安全增强、新数据类型等,帮助学生快速熟悉MySQL8的最新功能。接着,我们将深入解析MySQL的网络通信机制,包括协议、连接管理、数据传输等,让
- 498次学习
-
- JavaScript正则表达式基础与实战
- 在任何一门编程语言中,正则表达式,都是一项重要的知识,它提供了高效的字符串匹配与捕获机制,可以极大的简化程序设计。
- 487次学习
-
- 从零制作响应式网站—Grid布局
- 本系列教程将展示从零制作一个假想的网络科技公司官网,分为导航,轮播,关于我们,成功案例,服务流程,团队介绍,数据部分,公司动态,底部信息等内容区块。网站整体采用CSSGrid布局,支持响应式,有流畅过渡和展现动画。
- 484次学习
-
- 千音漫语
- 千音漫语,北京熠声科技倾力打造的智能声音创作助手,提供AI配音、音视频翻译、语音识别、声音克隆等强大功能,助力有声书制作、视频创作、教育培训等领域,官网:https://qianyin123.com
- 151次使用
-
- MiniWork
- MiniWork是一款智能高效的AI工具平台,专为提升工作与学习效率而设计。整合文本处理、图像生成、营销策划及运营管理等多元AI工具,提供精准智能解决方案,让复杂工作简单高效。
- 143次使用
-
- NoCode
- NoCode (nocode.cn)是领先的无代码开发平台,通过拖放、AI对话等简单操作,助您快速创建各类应用、网站与管理系统。无需编程知识,轻松实现个人生活、商业经营、企业管理多场景需求,大幅降低开发门槛,高效低成本。
- 157次使用
-
- 达医智影
- 达医智影,阿里巴巴达摩院医疗AI创新力作。全球率先利用平扫CT实现“一扫多筛”,仅一次CT扫描即可高效识别多种癌症、急症及慢病,为疾病早期发现提供智能、精准的AI影像早筛解决方案。
- 150次使用
-
- 智慧芽Eureka
- 智慧芽Eureka,专为技术创新打造的AI Agent平台。深度理解专利、研发、生物医药、材料、科创等复杂场景,通过专家级AI Agent精准执行任务,智能化工作流解放70%生产力,让您专注核心创新。
- 159次使用
-
- GPT-4王者加冕!读图做题性能炸天,凭自己就能考上斯坦福
- 2023-04-25 501浏览
-
- 单块V100训练模型提速72倍!尤洋团队新成果获AAAI 2023杰出论文奖
- 2023-04-24 501浏览
-
- ChatGPT 真的会接管世界吗?
- 2023-04-13 501浏览
-
- VR的终极形态是「假眼」?Neuralink前联合创始人掏出新产品:科学之眼!
- 2023-04-30 501浏览
-
- 实现实时制造可视性优势有哪些?
- 2023-04-15 501浏览