由于需要,最近开始看强化学习的东西啦,大概一周时间搞完DRL基础,然后开始探索DRL在NLP中的应用。🤩这篇博客主要是记录关于GAN的原理、其背后的的理论以及其各种变体模型(especially for NLP)的内容,感觉还是挺有意思的。
生成对抗网络(GAN)
目前来看,GAN模型不是一个新鲜的技术了,它在神经风格迁移,人脸变换等等方面都有很大的应用。先来讲解什么是GAN。
本篇博客的大部分内容来源于李宏毅教授的GANs课程。
在GAN中,有两个部分:Generator与Discriminator。原论文使用的是银行与印钞团伙的例子。在这里,我用一个比较peace的比喻就是:Generator好比于学生,它输出的是学生自己完成的作业,DIscriminator好比于老师,来判断学生的作业是否达到标准要求。将标准的作业与学生的作业都做为Discriminator的输入,老师希望自己判别标准作业以及学生的作业的能力越来越强,而学生则希望自己完成的作业越来越靠近标准的作业,以此来通过老师的判定。最终,当老师无法判定标准作业以及学生的作业之间的区别的时候,就说明学生自己的作业就达到了标准作业的要求,通过了老师的判定。
具体来说,Generator用来生成”假样本“,将“假样本”与真实样本同时输入Discriminator中,Discriminator需要判断哪些是真实样本,哪些是“假样本“。如果是真实样本,我们给它记为1,如果是”假样本“,我们记为0。这样不断循环,Generator不断生成更加能够“骗过”Discriminator的“假样本”,而Discriminator则不断提高自己识别真实样本与”假样本“的能力,最终,知道Discriminator无法辨别真实样本与”假样本“的时候,就结束。我们最终使用的是Generator。下面给出公式化的表达,如下:
GAN的原理
放图~
原始的GAN实际上就是在minimizeGenerator生成的data的分布$P_G$与真实的data分布$P_{data}$的JS divergence。下面具体分析一下~
- for Discriminator:它的作用是衡量得到$P_G$与$P_{data}$的JS divergence。它优化的目标是要maximize真实data的概率,同时minimize Gennerator生成的data的概率,这实际上就是和LR是一样的。我们可以将真实data看到positive sample,Generator生成的data看作negative sample,可以看到与LR是等价的。所以,Discriminator的loss就等价于 minimize CE。用公式来表示的话,就是:
- For Generator:它的作用是minimize 得到的JS divergence。用公式来表示就是:
另外,我们在更新Generator的时候,我们会使用$V=\frac 1m\sum_{i=1}^{m}-log(D(G(z_i)))$,因为相比于$log(1-D(G(z_i)))$,在训练的开始,梯度下降的比较快。
fGAN——任意的Divergence
在原始的GAN中,我们是去衡量并minimize Generator生成的data的分布与真实data的分布的JS divergence,那我们很容易想到以下两个问题:
- GAN中是否可以换成的其他的divergence呢?
- 换成的其他的divergence,是否会有提高呢?也就是说JS divergence是否是最优的divergence呢?
对于第一个问题,答案是可以的!结果就是这一小节讲的f-divergence;对于第二个问题,答案是:不知道,因为换用其他的divergence,效果似乎没有更好,但是也没有更差,很多GAN的变体中用到了其他的divergence,之后会讨论。
WGAN
VAE-GAN
BIGAN
GAN的实现
实现代码
在这里,我采用tensorflow2来实现DCGAN模型。需要注意:在Generator部分,我们是需要从低维转到高维,所以使用反卷积操作,但是需要注意的是:如果使用反卷积,势必要使用填充,这样的话,如果我们要生成满意的图片的话,必须要保留位置信息,所以,在实现的时候,我们要抛弃会损失位置信息的池化层以及最后的全连接层,此外,我们要使用多个反卷积,从而使得宽度与高度不断变大,深度不断变小,从而达到我们预期的维度,除此之外,由于叠加多层反卷积层,所以我们需要添加BatchNorm,来改善模型的训练效果,同时我们还需要是用leakyRelu,来使得梯度保持稳定。对于Disciriminator,我们使用多个卷积层,其他的与Generator一样,注意,我们最后我们仍然不使用全连接层,而是扁平化后直接送入sigmoid层。具体如下:
首先是模型定义部分:
1 | import tensorflow as tf |
最后是训练过程:
1 | # coding: utf-8 |
结果
这是第一次epoch训练结束后,使用测试数据得到的结果:
这是第五次epoch训练结束后,使用测试数据得到的结果:
可以看到结果有很大的提升,当然了,我这里只训练了5个epochs,所以最后的结果也不是很好,可以试着训练几千轮,应该效果会好很多。到此为止,基础的GAN的知识就讲到这。其实GAN实现起来简单,但是GAN存在的很多的问题,需要关注,关于GAN背后的理论知识还是需要多去了解一点,才能够真正的用到其他的领域,譬如:Discriminator的本质是得到generator distribution与real distribution之间的某种divergence(原始的GAN是得到JS divergence),而Generator的本质是去minimize 得到的divergence的值等等,这些理论都是需要去了解的,包括GAN需要注意的地方,譬如:优化饱和,mode collapse,model dropping问题。之后有时间再更新关于GAN的部分吧~