一、代码
1. MNIST数据集的引变量c
dataset = MnistDataset()#数据集MNIST
latent_spec = [
(Uniform(62), False),#62类默认是false
(Categorical(10), True),#离散码c1
(Uniform(1, fix_std=True), True),#连续码c2
(Uniform(1, fix_std=True), True),#连续码c3
]
隐变量是由列表构成(列表是由一系列按特定顺序排列的元素组成,用[]来表示,并用逗号分隔其中的元素})
2. 互信息的参数定义
互信息计算起始于如下两个变量:
reg_z:表示了模型开始随机生成的隐变量
fake_ref_z_dist_info:表示了经过Encoder计算后的隐变量分布信息
根据连续型和离散型的分类,两个变量分成了以下四个变量:
cont_reg_z:reg_z的连续变量部分
cont_reg_dist_info:fake_ref_z_dist_info的连续变量部分
disc_reg_z:reg_z的离散变量部分
disc_reg_dist_info:fake_ref_z_dist_info的离散变量部分
四个变量两两组队完成了后验公式P(c)logQ(c|g?([c,z])的计算:
cont_log_q_c_given_x:连续变量的后验
disc_log_q_c_given_x:离散变量的后验
同时,输入的隐变量也各自完成先验P(c)logP(c)的计算:
cont_log_q_c:连续变量的先验
disc_log_q_c:离散变量的先验
由于上面的运算全部是元素级的计算,还要把向量求出的内容汇总,得到∑P(c)logQ(c|g?([c,z])和∑P(c)logP(c):
cont_cross_ent:连续变量的交叉熵
cont_ent:连续变量的熵
disc_cross_ent:离散变量的交叉熵
disc_ent:离散变量的熵
跟据互信息公式两两相减,得到各自的互信息损失:
cont_mi_est:连续变量的互信息
disc_mi_est:离散变量的互信息
最后将两者相加就得到了最终的互信息损失。
二、隐变量c如何进行可解释表达
(http://blog.csdn.net/u011699990/article/details/71599067)
1. How can we achieve unsupervised learning of disentangled representation?
通常,无监督学习学到的特征是混杂在一起的,如上图所示,这些特征在数据空间中以一种复杂的无序的方式进行编码,但是如果这些特征是可分解的,那么这些特征将具有更强的可解释性,将更容易的利用这些特征进行编码。
2. 标准GAN
如果从表征学习的角度来看GAN模型,由于在生成器使用噪声z的时候没有加任何的限制,所以在以一种高度混合的方式使用z,z的任何一个维度都没有明显的表示一个特征,所以在数据生成过程中,无法得知什么样的噪声z可以用来生成数字1,什么样的噪声z可以用来生成数字3。
3. 网络结构图
给生成器输入隐含编码c和噪声z,生成假的数据,从假数据和真实数据中随机采样,输入给定D进行判断,是真还是假。Q通过与D共享卷积层,可以减少计算花销。在这里,Q是一个变分分布,在神经网络中直接最大化,Q也可以视作一个判别器,输出类别c。
4. 实验
使用MNIST数据集,使用了三个隐含编码,c1用十个离散数字进行编码,每个类别的概率都是0.1,c2,c3连续编码,是-2到2的均匀分布。
通过实验发现,c1可以作为一个分类器,分类的错误率为5%,图片a中第二行将7识别为9,但是不是意味着c1的0-9分别代表着生成数字的0-9,这是为了可视化效果,对数据重新排序的结果。如果在常规的GAN模型中添加c1编码,发现生成的图片与c1没有明显的关联。
事实上,离散码是时的表现,C1也可以作为一个分类器,通过将c1中的每个类别与一个数字类型进行匹配。
来源:CSDN
作者:蹦跶的小羊羔
链接:https://blog.csdn.net/yql_617540298/article/details/104212797