引言
XGBoost(eXtreme Gradient Boosting)是一种高效的梯度提升算法,在数据科学和机器学习领域广泛应用,尤其擅长处理结构化数据的回归和分类问题。以下从数学原理、发展历程、使用方法和竞争算法四个方面进行详细介绍:
一、数学原理
XGBoost 是梯度提升框架(Gradient Boosting)的优化实现,其核心思想是迭代地训练多个弱学习器(通常是决策树),并将它们组合成一个强学习器。
具体数学原理如下:
1.目标函数
XGBoost 的目标函数由两部分组成:
损失函数:衡量模型预测值与真实值之间的误差(如均方误差 MSE、对数损失 log loss 等)。 正则化项:控制模型复杂度,防止过拟合,包括树的数量、深度、叶子节点数等。
目标函数表达式:
\[ \text{Obj}(\theta) = \sum_{i=1}^n L(y_i, \hat{y}_i) + \sum_{k=1}^K \Omega(f_k) \]
其中:
- \(L(y_i, \hat{y}_i)\) 是样本 \(i\) 的损失函数,\(\hat{y}_i\) 是预测值,\(y_i\) 是真实值。
- \(\Omega(f_k)\) 是第 \(k\) 棵树的正则化项,通常包括树的复杂度(如叶子节点数、叶子节点权重的 L2 正则化)。
2.梯度提升迭代
XGBoost 通过迭代方式逐步优化目标函数:
初始化模型为常数预测: \(\hat{y}_i^{(0)} = \text{argmin}_c \sum_{i=1}^n L(y_i, c)\) 。
对于每一轮迭代 \(t\) :
- 计算损失函数关于当前预测值的梯度(一阶导数)和海森矩阵(二阶导数)。
- 训练一棵新的决策树 \(f_t\) ,拟合这些梯度(即预测梯度方向)。
- 更新预测值:\(\hat{y}_i^{(t)} = \hat{y}_i^{(t-1)} + \eta f_t(x_i)\) ,其中 \(\eta\) 是学习率(控制每次迭代的步长)。
3.决策树优化
XGBoost 在构建决策树时采用了以下优化策略:
- 精确贪心算法:通过枚举所有特征的所有可能分割点,找到最优分割。
- 近似算法:对于大规模数据,使用直方图近似快速找到分割点。
- 缺失值处理:自动学习缺失值的最优分裂方向。
- 正则化剪枝:通过正则化项控制树的复杂度,避免过拟合。
二、发展历程
2014 年:陈天奇(Tianqi Chen)在华盛顿大学读博期间开发了 XGBoost 的原型,最初作为研究项目用于解决大规模机器学习问题。
2015 年:XGBoost 作为开源项目发布,因其在 Kaggle 竞赛中的优异表现迅速走红,成为数据科学领域的主流算法之一。
2016 年:陈天奇加入微软,带领团队进一步优化 XGBoost,支持分布式计算和云计算平台。
2017 年至今:XGBoost 持续更新,增加了更多功能(如 DART、GPU 加速、特征重要性分析等),并被集成到多种编程语言和工具中(如 Python、R、Scala 等)。
三、如何使用
我们将用联邦基金有效利率和失业率为输入参数,按照 xgboost 模型对纳斯达克 100 指数月度数据进行建模。
# 加载必要的R包
library(tidyquant) # 获取数据
## Registered S3 method overwritten by 'quantmod':
## method from
## as.zoo.data.frame zoo
## ── Attaching core tidyquant packages ────────────────── tidyquant 1.0.11.9000 ──
## ✔ PerformanceAnalytics 2.0.8 ✔ TTR 0.24.4
## ✔ quantmod 0.4.28 ✔ xts 0.14.1
## ── Conflicts ────────────────────────────────────────── tidyquant_conflicts() ──
## ✖ zoo::as.Date() masks base::as.Date()
## ✖ zoo::as.Date.numeric() masks base::as.Date.numeric()
## ✖ PerformanceAnalytics::legend() masks graphics::legend()
## ✖ quantmod::summary() masks base::summary()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(tidyverse) # 数据处理与可视化的核心包集合
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr 1.1.4 ✔ readr 2.1.5
## ✔ forcats 1.0.0 ✔ stringr 1.5.1
## ✔ ggplot2 3.5.2 ✔ tibble 3.3.0
## ✔ lubridate 1.9.4 ✔ tidyr 1.3.1
## ✔ purrr 1.0.4
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::first() masks xts::first()
## ✖ dplyr::lag() masks stats::lag()
## ✖ dplyr::last() masks xts::last()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(tidymodels) # 机器学习建模框架
## ── Attaching packages ────────────────────────────────────── tidymodels 1.3.0 ──
## ✔ broom 1.0.8 ✔ rsample 1.3.0
## ✔ dials 1.4.0 ✔ tune 1.3.0
## ✔ infer 1.0.9 ✔ workflows 1.2.0
## ✔ modeldata 1.4.0 ✔ workflowsets 1.1.1
## ✔ parsnip 1.3.2 ✔ yardstick 1.3.2
## ✔ recipes 1.3.1
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard() masks purrr::discard()
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::first() masks xts::first()
## ✖ recipes::fixed() masks stringr::fixed()
## ✖ dplyr::lag() masks stats::lag()
## ✖ dplyr::last() masks xts::last()
## ✖ dials::momentum() masks TTR::momentum()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step() masks stats::step()
library(modeltime) # 时间序列建模与预测
##
## Attaching package: 'modeltime'
##
## The following object is masked from 'package:TTR':
##
## growth
library(timetk) # 时间序列数据处理工具
##
## Attaching package: 'timetk'
##
## The following object is masked from 'package:tidyquant':
##
## FANG
# 获取失业率数据 (UNRATE)
df_unrate <-
tq_get("UNRATE", get = "economic.data") %>%
select(date, unrate = price) # 重命名价格列为unrate
# 获取联邦基金有效利率数据 (FEDFUNDS)
df_fedfunds <-
tq_get("FEDFUNDS", get = "economic.data") %>%
select(date, fedfunds = price) # 重命名价格列为fedfunds
# 获取纳斯达克100指数数据
df_nasdaq <-
tq_get("^NDX") %>%
tq_transmute(select = close, # 选择收盘价
mutate_fun = to.monthly, # 转换为月度数据
col_rename = "nasdaq") %>% # 重命名列
mutate(date = as.Date(date)) # 确保日期格式正确
# 合并三个数据集
df_merged <-
df_unrate %>%
left_join(df_fedfunds) %>%
left_join(df_nasdaq) %>%
drop_na() # 删除包含缺失值的行
## Joining with `by = join_by(date)`
## Joining with `by = join_by(date)`
# 时间序列数据分割
splits <-
time_series_split(
df_merged,
assess = "1 year", # 测试集为1年数据
cumulative = TRUE # 训练集包含所有历史数据
)
## Using date_var: date
df_train <- training(splits) # 提取训练集
df_test <- testing(splits) # 提取测试集
# 创建机器学习预处理配方
recipe_ml <-
recipe(nasdaq ~ ., df_train) %>% # 以nasdaq为目标变量
step_date(date, features = "month", ordinal = FALSE) %>% # 从日期提取月份特征
step_dummy(all_nominal_predictors(), one_hot = TRUE) %>% # 将分类变量转为虚拟变量
step_mutate(date_num = as.numeric(date)) %>% # 将日期转换为数值型
step_normalize(all_numeric_predictors()) %>% # 标准化数值预测变量
step_rm(date) # 移除原始日期列
# 定义XGBoost提升树模型规格,设置超参数为待调优
mod_spec <-
boost_tree(trees = tune(), # 树的数量
min_n = tune(), # 节点最小样本量
tree_depth = tune(), # 树的最大深度
learn_rate = tune()) %>% # 学习率
set_engine("xgboost") %>% # 使用xgboost引擎
set_mode("regression") # 设置为回归模式
# 提取模型的超参数集
mod_param <- extract_parameter_set_dials(mod_spec)
# 设置随机种子确保结果可复现
set.seed(1234)
# 创建随机超参数网格(50种组合)
model_tbl <-
mod_param %>%
grid_random(size = 50) %>%
create_model_grid(
f_model_spec = boost_tree,
engine_name = "xgboost",
mode = "regression"
)
# 从网格中提取模型列表
model_list <- model_tbl$.models
# 创建工作流集,将预处理配方与不同模型规格组合
model_wfset <-
workflow_set(
preproc = list(recipe_ml), # 使用之前定义的预处理配方
models = model_list, # 使用超参数网格生成的模型列表
cross = TRUE # 交叉组合配方和模型
)
# 使用并行计算加速模型训练
model_parallel_tbl <-
model_wfset %>%
modeltime_fit_workflowset(
data = df_train, # 使用训练集数据
control = control_fit_workflowset(
verbose = TRUE, # 显示训练进度
allow_par = TRUE # 允许并行计算
)
)
## ℹ Fitting Model: 1
## ✔ Model Successfully Fitted: 1
## ℹ Fitting Model: 2
## ✔ Model Successfully Fitted: 2
## ℹ Fitting Model: 3
## ✔ Model Successfully Fitted: 3
## ℹ Fitting Model: 4
## ✔ Model Successfully Fitted: 4
## ℹ Fitting Model: 5
## ✔ Model Successfully Fitted: 5
## ℹ Fitting Model: 6
## ✔ Model Successfully Fitted: 6
## ℹ Fitting Model: 7
## ✔ Model Successfully Fitted: 7
## ℹ Fitting Model: 8
## ✔ Model Successfully Fitted: 8
## ℹ Fitting Model: 9
## ✔ Model Successfully Fitted: 9
## ℹ Fitting Model: 10
## ✔ Model Successfully Fitted: 10
## ℹ Fitting Model: 11
## ✔ Model Successfully Fitted: 11
## ℹ Fitting Model: 12
## ✔ Model Successfully Fitted: 12
## ℹ Fitting Model: 13
## ✔ Model Successfully Fitted: 13
## ℹ Fitting Model: 14
## ✔ Model Successfully Fitted: 14
## ℹ Fitting Model: 15
## ✔ Model Successfully Fitted: 15
## ℹ Fitting Model: 16
## ✔ Model Successfully Fitted: 16
## ℹ Fitting Model: 17
## ✔ Model Successfully Fitted: 17
## ℹ Fitting Model: 18
## ✔ Model Successfully Fitted: 18
## ℹ Fitting Model: 19
## ✔ Model Successfully Fitted: 19
## ℹ Fitting Model: 20
## ✔ Model Successfully Fitted: 20
## ℹ Fitting Model: 21
## ✔ Model Successfully Fitted: 21
## ℹ Fitting Model: 22
## ✔ Model Successfully Fitted: 22
## ℹ Fitting Model: 23
## ✔ Model Successfully Fitted: 23
## ℹ Fitting Model: 24
## ✔ Model Successfully Fitted: 24
## ℹ Fitting Model: 25
## ✔ Model Successfully Fitted: 25
## ℹ Fitting Model: 26
## ✔ Model Successfully Fitted: 26
## ℹ Fitting Model: 27
## ✔ Model Successfully Fitted: 27
## ℹ Fitting Model: 28
## ✔ Model Successfully Fitted: 28
## ℹ Fitting Model: 29
## ✔ Model Successfully Fitted: 29
## ℹ Fitting Model: 30
## ✔ Model Successfully Fitted: 30
## ℹ Fitting Model: 31
## ✔ Model Successfully Fitted: 31
## ℹ Fitting Model: 32
## ✔ Model Successfully Fitted: 32
## ℹ Fitting Model: 33
## ✔ Model Successfully Fitted: 33
## ℹ Fitting Model: 34
## ✔ Model Successfully Fitted: 34
## ℹ Fitting Model: 35
## ✔ Model Successfully Fitted: 35
## ℹ Fitting Model: 36
## ✔ Model Successfully Fitted: 36
## ℹ Fitting Model: 37
## ✔ Model Successfully Fitted: 37
## ℹ Fitting Model: 38
## ✔ Model Successfully Fitted: 38
## ℹ Fitting Model: 39
## ✔ Model Successfully Fitted: 39
## ℹ Fitting Model: 40
## ✔ Model Successfully Fitted: 40
## ℹ Fitting Model: 41
## ✔ Model Successfully Fitted: 41
## ℹ Fitting Model: 42
## ✔ Model Successfully Fitted: 42
## ℹ Fitting Model: 43
## ✔ Model Successfully Fitted: 43
## ℹ Fitting Model: 44
## ✔ Model Successfully Fitted: 44
## ℹ Fitting Model: 45
## ✔ Model Successfully Fitted: 45
## ℹ Fitting Model: 46
## ✔ Model Successfully Fitted: 46
## ℹ Fitting Model: 47
## ✔ Model Successfully Fitted: 47
## ℹ Fitting Model: 48
## ✔ Model Successfully Fitted: 48
## ℹ Fitting Model: 49
## ✔ Model Successfully Fitted: 49
## ℹ Fitting Model: 50
## ✔ Model Successfully Fitted: 50
## Total time | 22.065 seconds
# 评估模型在测试集上的准确性
model_parallel_tbl %>%
modeltime_calibrate(new_data = df_test) %>% # 校准模型
modeltime_accuracy() %>% # 计算准确性指标
table_modeltime_accuracy() # 以表格形式展示结果
## Warning: There were 12 warnings in `dplyr::mutate()`.
## The first warning was:
## ℹ In argument: `.nested.col = purrr::map(...)`.
## Caused by warning:
## ! A correlation computation is required, but `estimate` is constant and has 0
## standard deviation, resulting in a divide by 0 error. `NA` will be returned.
## ℹ Run `dplyr::last_dplyr_warnings()` to see the 11 remaining warnings.
# 选择表现最好的模型(这里假设是RECIPE_BOOST_TREE_27)
# 并在测试集上进行校准
calibration_tbl <-
model_parallel_tbl %>%
filter(.model_desc == "RECIPE_BOOST_TREE_27") %>%
modeltime_calibrate(df_test)
# 生成预测并绘制预测区间
calibration_tbl %>%
modeltime_forecast(
new_data = df_test, # 在测试集上进行预测
actual_data = df_merged %>%
filter(date >= as.Date("2024-07-01")) # 实际数据从2024-07-01开始
) %>%
plot_modeltime_forecast(
.interactive = FALSE, # 生成静态图表
.legend_show = FALSE, # 不显示图例
.line_size = 1.5, # 设置线条粗细
.color_lab = "", # 设置颜色标签
.title = "NASDAQ 100" # 设置图表标题
) +
# 添加副标题(使用markdown格式)
labs(subtitle = "<span style = 'color:dimgrey;'>预测区间</span><br><span style = 'color:red;'>机器学习模型</span>") +
# 设置x轴(日期)格式
scale_x_date(
expand = expansion(mult = c(.1, .15)), # 设置x轴扩展
labels = scales::label_date(format = "%b'%y") # 设置日期标签格式
) +
# 设置y轴(纳斯达克指数)格式为货币格式
scale_y_continuous(labels = scales::label_currency()) +
# 设置图表主题
theme_minimal(base_family = "Roboto Slab", base_size = 20) +
theme(
legend.position = "none", # 不显示图例
plot.background = element_rect(fill = "azure", color = "azure"), # 设置背景颜色
plot.title = element_text(face = "bold"), # 设置标题字体为粗体
axis.text = element_text(face = "bold"), # 设置坐标轴标签字体为粗体
plot.subtitle = ggtext::element_markdown(face = "bold") # 设置副标题为markdown格式并加粗
)
## Warning in max(ids, na.rm = TRUE): no non-missing arguments to max; returning
## -Inf
四、竞争算法
XGBoost 的主要竞争对手包括:
1. LightGBM
开发者:微软(2017 年发布)。
特点:
- 基于直方图的决策树构建,速度比 XGBoost 更快。
- 支持 Leaf-wise(按叶子生长)的树生长策略,而不是 XGBoost 的 Level-wise(按层生长)。
- 内存占用更小,适合大规模数据集。
- 适用场景:超大规模数据集、需要快速迭代的场景。
2. CatBoost
开发者:Yandex(2017 年发布)。
特点:
- 自动处理类别特征(无需手动编码)。
- 采用对称树结构,减少过拟合风险。
- 支持 GPU 加速和分布式训练。
适用场景:包含大量类别特征的数据集(如推荐系统、广告点击预测)。
3. Random Forest
特点:
- 基于 Bagging 思想的集成学习算法,并行训练多棵决策树。
- 对异常值和过拟合更鲁棒。
- 计算复杂度较低,但预测精度可能低于 Boosting 算法。
适用场景:快速 baseline 模型、对解释性要求较高的场景。
4. Neural Networks(神经网络)
特点:
- 适用于非结构化数据(如图像、文本、语音)。
- 需要大量数据才能表现良好。
- 计算资源消耗大,训练时间长。
适用场景:深度学习擅长的领域(如计算机视觉、自然语言处理)。
五、算法选择建议
- 数据规模较小:可尝试 Random Forest 或 CatBoost(处理类别特征更方便)。
- 数据规模中等:XGBoost 和 LightGBM 均可,XGBoost 更成熟,LightGBM 速度更快。
- 包含大量类别特征:优先使用 CatBoost。
- 追求极致速度:选择 LightGBM。
- 非结构化数据:考虑神经网络(如 CNN、Transformer)。
XGBoost 的优势在于其高效性、灵活性和广泛的应用场景,尤其在 Kaggle 竞赛和工业界的结构化数据分析中表现突出。