0%

统计学习方法|EM算法原理详解与实现

EM算法,中文名叫期望最大算法,主要用来求解含有隐变量的概率模型的参数。在高斯混合模型(GMM)、隐马尔可夫模型(HMM)中都有应用(在HMM模型中叫做Baum-Welch算法)。本篇博客将对EM算法的原理以及高斯混合模型(GMM)进行详细地讲解,并对其采用python与scikit-learn库这两种方式进行实现。

为什么要有EM算法?

首先我们要了解为什么要有EM算法?对于普通的估计概率模型的参数,我们可以直接通过求他们的极大似然估计$P(y|\theta)$。譬如说,掷一枚硬币正面朝上的概率为$q$,连续掷5次,结果为:正面、正面、反面、正面、反面,求$q$是多少?那么很显然,其中$q$就是我们需要估计的参数,我们可以根据极大似然估计:$q=argmax_q \{q^3(1-q)^2\}$。我们只要求其求导,就可以得到,$q=\frac 35$。这相信,对熟悉MLE的同学来说,是非常容易的。但是,如果说加入隐变量呢?具体如下:

假设有三枚硬币A、B和C,A正面朝上的概率是$q_1$,B正面朝上的概率是$q_2$,C正面朝上的概率是$q_3$。当投掷A硬币正面朝上的时候,我们再投掷B,记录B的结果;当投掷A硬币反面朝上的时候,我们再投掷C,记录C的结果。假设只能观察到结果,中间是选择了B或者C,都不知道,重复五次实验,结果为:正面、正面、反面、正面、反面。求$q_1,q_2,q_3$?(例子来源于《统计学习方法》)

那么,非常显然,我们不能像估计普通的概率模型的参数的方式来估计三个硬币正面朝上的概率。因为,中间选择B或者C的结果我们并不知道,也就是根据观察到的结果,我们不知道每次A的结果是正面还是反面。这就说明问题中含有了隐变量。这个时候,就需要新的算法来估计像这种含有隐变量的概率模型的参数。所以,EM算法应运而生。

EM算法的推导

首先,我直接给出整个EM算法,然后再给出推导过程。

此外,收敛条件如下:

下面给出推导过程(推导过程可能与《统计学习方法》不一样,其实道理是一样的):

OK,那么,我们现在得到了如下式子:

那么,要让$L(\theta)$最大化,也就是让下界ELBO最大化即可。也就是说:

由于${P(Z|Y,\theta^{(i)})}$是一个常数,不影响$\theta^{(i+1)}$的更新,所以:

我们对其整理一下:

是不是与最开始给出的EM算法一摸一样了🤩?到此,推导完毕🎉~

还有另外一种推导方法,使用KL散度来做(个人认为这种推导过程会更好理解一些),具体推导过程如下:

有几个需要注意的地方:EM算法对初值敏感,也就是说取不同的初值$\theta^{(0)}$,那么结果会不一样;另外,EM算法不能保证全局最优。

EM算法的收敛性证明

其实,知道EM算法的推导就差不多了。但是,实际上,还有一个问题就是:这样迭代更新,是否能够保证对数似然函数$L(\theta)$能收敛呢?证明过程如下:

OK,证明完毕🎉~当然,EM算法的一个典型应用就是GMM模型,这个之后有时间再写一篇文章吧~

EM算法的实现

把模型实现一遍才算是真正的吃透了这个模型呀。在这里,我采取了python来实现EM算法,以及scikit-learn库来实现GMM模型。我的github里面可以下在到所有的代码,欢迎访问我的github,也欢迎大家star和fork。附上GitHub地址: 《统计学习方法》及常规机器学习模型实现。具体代码如下

EM算法的python实现

GMM模型的scikit-learn实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
#coding:utf-8
#Author:codewithzichao
#Date:2020-1-10
#E-mail:lizichao@pku.edu.cn

import numpy as np
import matplotlib.pyplot as plt
import time
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture

if __name__=="__main__":
X, y_true = make_blobs(n_samples=400, centers=4,
cluster_std=0.60, random_state=0)
X = X[:, ::-1] # 交换列是为了方便画图

gmm = GaussianMixture(n_components=4).fit(X)
labels = gmm.predict(X)
plt.scatter(X[:, 0], X[:, 1], c=labels, s=40, cmap='viridis')
plt.show()
Would you like to buy me a cup of coffee☕️~