Generative Adversarial Networks
-
GAN框架
GAN框架是有两个对象(discriminator,generator)的对抗游戏。generator是一个生成器,generator产生来自和训练样本一样的分布的样本。discriminator是一个判别器,判别是真实数据还是generator产生的伪造数据。discriminator使用传统的监督学习技术进行训练,将输入分成两类(真实的或者伪造的)。generator训练的目标就是欺骗判别器。
游戏中的两个参与对象由两个函数表示,每个都是关于输入和参数的可微分函数。discriminator是一个以 x 作为输入和使用θ(D) 为参数的函数D,D(x)是指判断输入样本x是真实样本的概率 ;generator由一个以z为输入使用 θ(G) 为参数的函数G,G(z)是指输入样本z产生一个新的样本,这个新样本希望接近真实样本的分布。
discriminator与generator都用两个参与对象的参数定义的代价函数。discriminator希望仅控制住θ(D)情形下最小化 J(D)(θ(D), θ(G))。generator希望在仅控制θ(D) 情形下最小化 J(G)(θ(D),θ(G))。因为每个参与对象的代价依赖于其他参与对象的参数,但是每个参与对象不能控制别人的参数,这个场景其实更为接近一个博弈而非优化问题。优化问题的解是一个局部最小,这是参数空间的点其邻居有着不小于它的代价。而对一个博弈的解释一个纳什均衡。在这样的设定下,Nash 均衡是一个元组,( θ(D), θ(G)) 既是关于θ(D)的 J(D) 的局部最小值和也是关于θ(G)的 J(G) 局部最小值。
图 1 GAN两种场景
如图 1所示GAN有两种场景,第一种场景(左图),discriminator对象随机从样本集中取一个元素X作为输入,discriminator对象的目标是以真实样本X作为输入时,尽量判断D(x)为1;而第二种场景(右图),具有discriminator和generator两个对象的参与,generator对象以噪声变量z作为输入,然后产生一个样本G(z),discriminator对象以G(z)作为输入并尽量判断D(G(z) )为0;而generator对象的目标是尽量让discriminator对象计算D(G(z) )为1。最后这个游戏是达到纳什均衡(Nash equilibrium),即G(z)产生的数据样本分布与真实数据样本分布一样,即对于所有的输入x,D(x) 的计算结果为0.5。
-
ANN函数
GAN是由一个判别模型(discriminator)和生成模(generator)型组成。其中discriminator和generator可以由任何可微函数来描述,如图 4所示是采用两个多层的神经网络来描述discriminator和generator模型,即图中的G和D函数。
图 2
generator是一个可微分函数 G。当 z 从某个简单的先验分布中采样出来时,G(z) 产生一个从 pmodel 中的样本。一般来说, GAN对于generator神经网络只有很少的限制。如果我们希望 pmodel 是 x 空间的支集(support),我们需要 z 的维度需要至少和 x 的维度一样大,而且 G 必须是可微分的,但是这些其实就是仅有的要求了。
-
损失函数
-
discriminator的代价
-
交叉熵[2]
交叉熵代价函数(Cross-entropy cost function)是用来衡量人工神经网络(ANN)的预测值与实际值的一种方式。交叉熵损失函数定义如下:
其中:
- x表示样本
- y表示样本x对应的标签
- a表示以样本x作为输入,模型的输出标签
- n表示样本的总数,当为二分类时n为2
-
目前为 GANs 设计的所有不同的博弈针对discriminator的 J(D) 使用了同样的代价函数。他们仅仅是generator J(G) 的代价函数不同。
discriminator的代价函数是:
其中: 表示在分布上的期望,D(x)为概率函数。
其实就是标准地训练一个sigmoid 输出的标准二分类器交叉熵代价函数。唯一的不同就是分类器在两个 minibatch 的数据上进行训练;一个来自数据集(其中的标签均是 1),另一个来自生成器(其标签均是 0)。
通过给discriminator模型定义损失函数后,将优化discriminator模型转移为优化等式,即训练discriminator模型就是为了最小化discriminator的等式。
-
Minimax
GAN框架有两个参与对象discriminator和generator ,上一节只考虑优化discriminator模型,还需要考虑优化generator模型。GAN使用了零和博弈思想为generator模型定义损失函数。在零和博弈游戏中,其所有参与人的代价总是 0,即在游戏中赢的得正数,输的得负数,所以总和为0。在零和博弈中,参加游戏双方的得分互为相反数,所以根据discriminator的损失函数,可推导出generator的损失函数为:
所以优化generator模型,一样是优化 损失函数,即最小化该损失函数。由于和两个损失函数只是互为相反数,所以可以将两个等式合并为一个优化等式。即
由于我们训练D来最大化分配正确标签给不管是来自于训练样例还是G生成的样例的概率.我们同时训练G来最小化。换句话说,D和G的训练是关于值函数V(G,D)的极小化极大的二人博弈问题:
其中:
- G表示生成模型,D表示分类模型
- x~pdata(x) 表示x取自训练数据的分布
- z~p(z) 表示z取自我们模拟数据的分布
图 3
如图 2所示a-b是模型G和D的优化过程,黑色的虚线表示训练数据的分布;绿色的实线表示模型G产生的分布;蓝色的虚线表示模型D的计算值;水平X轴表示D函数的计算值;水平z轴表示噪声值。一开始G的产生分布于真实数据分布偏离较大,且模型D对真实数据和伪造数据区分能力较强,即对真实数据D函数的计算值较大,而对伪造数据D函数的计算值较小,如图a;随着模型的训练,G数据分布于真实数据分布逐渐重合,如图d,最后D的计算值恒等为0.5。
-
训练过程
训练过程包含同时随机梯度下降 simultaneous SGD。在每一步,会采样两个 minibatch:一个来自数据集的 x 的 minibatch 和一个从隐含变量的模型先验采样的 z 的 minibatch。然后两个梯度步骤同时进行:一个更新 θ(D) 来降低 J(D),另一个更新 θ(G) 来降低 J(G)。这两个步骤都可以使用你选择的基于梯度的优化算法。
生成对抗网络的minibatch随机梯度下降训练。判别器的训练步数,k是一个超参数。在我们的试验中使用k=1,使消耗最小。
图 4
-
理论分析
GAN的设计思想采用discriminator和generator两个模型进行对抗优化,本章用两个证明来从理论上论证了对抗网络的合理性。
-
命题一:全局最优
命题:当G固定的时候,D会有唯一的最优解。真实描述如下:
证明如下:
-
首先,根据连续函数的期望计算方式,对V(G,D)进行变换:
-
对于任意的a,b ∈ R2 \ {0, 0}, 下面的式子在a/(a+b)处达到最优:
所以得证。
-
命题二:收敛性
命题:如果G和D有足够的性能,对于算法中的每一步,给定G时,判别器能够达到它的最优,并且通过更新pg来提高这个判别准则。
则pg收敛为pdata。
证明略,看不太懂。
-
CycleGAN[5]
-
概述
CycleGAN的原理可以概述为: 将一类图片转换成另一类图片 。也就是说,现在有两个样本空间X和Y,我们希望把X空间中的样本转换成Y空间中的样本。(获取一个数据集的特征,并转化成另一个数据集的特征).
图 5
-
形式化
CycleGAN模型的学习目标是训练两个映射函数:G:XàY和F:YàX,同时CycleGAN模型还包含了两个相关的discriminator对象:Dx和Dy。Dy是为了区分G函数产生的数据和Y数据;而Dx是为了区分F函数产生的数据和X数据,如图 5(a)所示。
-
对抗损失函数
如3.2小节所示介绍的对抗网络,对于一个映射函数G:XàY,和discriminator对象DY,则GAN的损失函数定义为:
其中,映射函数G是将X领域的数据转换为类似Y领域的数据,而DY就是判别真实的Y数据和G伪造的Y数据。即GAN的优化目标是:。同样的对于映射F:YàX,和discriminator对象DX,可以定义一个GAN损失函数的优化目标:.
-
循环一致损失函数
理论上GAN能够学习两个映射函数G和F,其能够分别从X或Y一个领域的数据生成到另一个领域的数据。但是由于映射函数变换可能性非常多,无法保证映射函数能够将一个领域的输入数据xi转换为其它领域的数据yi。为了减少映射函数的变换范围或可能性,CycleGAN增加了一些约束函数来限制这种变换范围过大的问题。
如图 5(b)所示,通过映射函数G和F,可以从X领域的数据样本变换为领域Y的数据样本,再变换为X领域的数据样本,从而生成一个环,即:,同理有图 5(c)的。所以原始数据样本x和循环产生的数据F(G(x))之间肯定有差异,那么可以定义一致性损失函数为:
其中式中的方括号是使用了L1规范化。
-
完整表达式
综上所述,CycleGAN的损失函数可以完整表达为:
其中控制了映射函数G和F的相对重要性。所以CycleGAN的优化目标是:
其中G和F两个映射函数的内部结构互相彼此独立,即它们能将一个数据样本映射到另一个领域的数据样本。
-
实现
CycleGAN网络的实现就是定义四个神经网络:G、F、Dx和Dy;然后优化这个最终的表达式,
-
参考文献
-
Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks.
-
来源:oschina
链接:https://my.oschina.net/u/4401288/blog/3583188