0%

统计学习方法|决策树模型原理详解与实现

决策树模型一种基本的分类与回归方法,是shallow learning的Adaboost、XGBoost、Light GBM、catBoost等树模型的基础,对于理解这些模型大有裨益。这篇博客将详细地讲解基本的决策树模型,主要会侧重回归树的讲解,因为这是Adaboost、XGBoost、Light GBM、catBoost等树模型的核心组成部分。并采用python与scikit-learn来对其进行实现。

决策树模型介绍

决策树模型是比较简单的模型。它的三个核心问题是:特征选择、决策树的生成、决策树的剪枝

特征选择

所谓的特征选择,我们可以这么想:在没有构建决策树之前,只看训练数据集,是很混乱的,因为我们无法根据训练数据集,直接判断新的实例的类别,也就是说,训练集是没有分类能力的。那么,我们就需要构建一套规则,当我们应用这套规则的时候,我们能够得到实例的类别。那么,问题来了:我们怎么选择分类的特征,才能使得分类的效果最好呢?这就是特征选择需要做的事情。常见的应用与特征选择的准则有:信息增益与信息增益比。

信息增益

信息增益表示在已知特征$A$的条件下,从而使得数据集的不确定性减少的程度。看到不确定性,很自然地就会联想到熵!因为熵正是用来度量随机变量不确定性的程度。那么,下面给出熵的定义:

假设离散随机变量$X$的概率分布是:$P(X=i)=p_i,i=1,2,…,n$,那么随机变量$X$的熵如下:

在给定$X$的情况下,随机变量$Y$的条件概率分布$P(Y|X)$的条件熵如下:

当其中的概率是由数据估计(MLE)得到的时候,就称为经验条件熵。

在知道熵的概念之后,那么信息增益的定义如下:

其中,$D$表示训练数据集,$A$表示特征。即:特征A对于训练数据集$D$的信息增益就等于$D$的经验熵与给定特征$A$的情况下$D$的经验条件熵之差。那么对于决策树模型来说,特征选择的准则是:选择信息增益大的特征,因为信息增益大的特征具有更强的分类能力。具体过程如下:

信息增益比

那么有了信息增益,为啥还要有信息增益比呢?原因在于:使用信息增益准则会导致决策树会更加偏向于特征取值数目多的特征。因为,选取特征取值数目多的特征,会让训练集的信息增益增大,也就是整个训练集的不纯度降低。但是,这样以来,会导致构建的决策树模型容易过拟合。因此,就有了信息增益比。

决策树的生成

决策树的生成算法有3中:ID3、C4.5、CART。在这里,我将讲解ID3、C4.5。CART将单独讲,因为其是后来那些大火的集成模型的核心部分。ID3与C4.5其实差不多,但是它们之间的区别在于特征选择的准则不同:ID3使用了信息增益,C4.5使用了信息增益比。在这里,我放上《统计学习方法》中关于C4.5的算法过程。

决策树的剪枝

当我们构建好了决策树之后,我们会发现这样构建的决策树很容易发生过拟合。原因在于:我们在构建决策树的时候,尽可能地去拟合训练数据,从而得到了过于复杂的决策树。那么一种很自然的想法就是:对决策树进行剪枝,从而让决策树不那么复杂。当然,这样以来,就会使得决策树的准确率下降,所以,我们就需要在模型复杂度与对训练集的预测误差之间做一个tradeoff。我们所用的损失函数如下:

其中,$T$表示决策树,$\alpha$是参数,用来平衡模型复杂度与预测误差之间的关系。$C(T)$表示模型对训练数据的误差,$|T|$表示模型的复杂度。当$\alpha$越大,模型越简单。(我们可以这么记:当$\alpha$为0的时候,决策树是过拟合的,所以增大$\alpha$,会让决策树变得简单。)

CART

CART,全名叫作:分类与回归树。所以,正如名字一样,它既可以用于分类,也可以用于回归问题。在分类问题中,使用的特征选择的准则是:基尼指数(Gini)最小化;对于回归问题,使用的生成方法是:平方误差最小化。分类问题我就不介绍了,只介绍一下用于回归问题的回归树。

给定训练数据集$X=\{(x_1,y_1),(x_2,y_2),…,(x_N,y_N)\}$,那么回归问题就是我们要构造一个函数$f(x)$,能够使得训练数据集的MSE最小,即:

假设,我们将输入空间划分为$R_1,R_2,…,R_M$,并且在每一个区域有一个输出常数$c_m$。那么目标函数可以表示为:

而其中$c_m$的最优值为就是为该区域的均值,如下:

所以,回归树就可以表示为:

那么关键是,怎么对输入空间进行划分呢?方法:比那里所有的切分变量与切分点,找到使得MSE最小的划分。假设首先随机选择第$j$个特征做为划分变量,其对应的值为$s$,那么我们就讲空间划分为两个,如下:

接下来,我们需要通过如下函数$m(s)$,从而找到最优的划分变量与划分值,如下:

其中,$\hat c_m=ave(y_i|x_i\in R_m),m=1,2$。那么,遍历所有的变量,找到最小$m(s)$的$j,s,c_m$,那么就可以得到最后的回归树了。(在这里附上《统计学习方法》的图)。具体的计算例子,参考:GBDT计算,这里虽然讲解的是提升树,但是由于回归问题的提升树是以回归树做为基本分类器,所以其中也涉及到了回归树的构建过程。

OK,理论部分就讲完了🎉~

决策树模型的实现

把模型实现一遍才算是真正的吃透了这个模型呀。在这里,我采取了两种方式来实现决策树模型:python实现以及调用scikit-learn库来实现。我的github里面可以下载到所有的代码,欢迎访问我的github,也欢迎大家star和fork。附上GitHub地址: 《统计学习方法》及常规机器学习模型实现。具体代码如下:

python实现

scikit-learn实现

Would you like to buy me a cup of coffee☕️~