回顾InfoGAN与隐变量

佐手、 提交于 2020-02-08 02:25:40

一、代码

1. MNIST数据集的引变量c

dataset = MnistDataset()#数据集MNIST

    latent_spec = [

        (Uniform(62), False),#62类默认是false

        (Categorical(10), True),#离散码c1

        (Uniform(1, fix_std=True), True),#连续码c2

        (Uniform(1, fix_std=True), True),#连续码c3

]

隐变量是由列表构成(列表是由一系列按特定顺序排列的元素组成,用[]来表示,并用逗号分隔其中的元素})

2. 互信息的参数定义

互信息计算起始于如下两个变量:

reg_z:表示了模型开始随机生成的隐变量

fake_ref_z_dist_info:表示了经过Encoder计算后的隐变量分布信息

根据连续型和离散型的分类,两个变量分成了以下四个变量:

cont_reg_z:reg_z的连续变量部分

cont_reg_dist_info:fake_ref_z_dist_info的连续变量部分

disc_reg_z:reg_z的离散变量部分

disc_reg_dist_info:fake_ref_z_dist_info的离散变量部分

四个变量两两组队完成了后验公式P(c)logQ(c|g?([c,z])的计算:

cont_log_q_c_given_x:连续变量的后验

disc_log_q_c_given_x:离散变量的后验

同时,输入的隐变量也各自完成先验P(c)logP(c)的计算:

cont_log_q_c:连续变量的先验

disc_log_q_c:离散变量的先验

由于上面的运算全部是元素级的计算,还要把向量求出的内容汇总,得到∑P(c)logQ(c|g?([c,z])和∑P(c)logP(c):

cont_cross_ent:连续变量的交叉熵

cont_ent:连续变量的熵

disc_cross_ent:离散变量的交叉熵

disc_ent:离散变量的熵

跟据互信息公式两两相减,得到各自的互信息损失:

cont_mi_est:连续变量的互信息

disc_mi_est:离散变量的互信息

最后将两者相加就得到了最终的互信息损失。

二、隐变量c如何进行可解释表达

(http://blog.csdn.net/u011699990/article/details/71599067)

1. How can we achieve unsupervised learning of disentangled representation?

 

通常,无监督学习学到的特征是混杂在一起的,如上图所示,这些特征在数据空间中以一种复杂的无序的方式进行编码,但是如果这些特征是可分解的,那么这些特征将具有更强的可解释性,将更容易的利用这些特征进行编码。

2. 标准GAN

 

如果从表征学习的角度来看GAN模型,由于在生成器使用噪声z的时候没有加任何的限制,所以在以一种高度混合的方式使用z,z的任何一个维度都没有明显的表示一个特征,所以在数据生成过程中,无法得知什么样的噪声z可以用来生成数字1,什么样的噪声z可以用来生成数字3。

3. 网络结构图

 

给生成器输入隐含编码c和噪声z,生成假的数据,从假数据和真实数据中随机采样,输入给定D进行判断,是真还是假。Q通过与D共享卷积层,可以减少计算花销。在这里,Q是一个变分分布,在神经网络中直接最大化,Q也可以视作一个判别器,输出类别c。

 

4. 实验

 

使用MNIST数据集,使用了三个隐含编码,c1用十个离散数字进行编码,每个类别的概率都是0.1,c2,c3连续编码,是-2到2的均匀分布。

通过实验发现,c1可以作为一个分类器,分类的错误率为5%,图片a中第二行将7识别为9,但是不是意味着c10-9分别代表着生成数字的0-9,这是为了可视化效果,对数据重新排序的结果。如果在常规的GAN模型中添加c1编码,发现生成的图片与c1没有明显的关联。

事实上,离散码是时的表现,C1也可以作为一个分类器,通过将c1中的每个类别与一个数字类型进行匹配。

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!