机器学习算法:EM算法

杀马特。学长 韩版系。学妹 提交于 2019-11-27 12:09:16

EM算法

适用问题:概率模型参数估计
模型特点:含隐变量的概率模型
学习策略:极大似然估计、极大后验概率估计
学习的损失函数:对数似然损失
学习算法:迭代算法

EM算法
给定的训练样本是,样例间独立,我们想找到每个样例隐含的类别z,能使得p(x,z)最大。p(x,z)的最大似然估计如下:

  第一步是对极大似然取对数,第二步是对每个样例的每个可能类别z求联合分布概率和。但是直接求一般比较困难,因为有隐藏变量z存在,但是一般确定了z后,求解就容易了。
  EM是一种解决存在隐含变量优化问题的有效方法。竟然不能直接最大化,我们可以不断地建立的下界(E步),然后优化下界(M步)。这句话比较抽象,看下面的。
  对于每一个样例i,让表示该样例隐含变量z的某种分布,满足的条件是。(如果z是连续性的,那么是概率密度函数,需要将求和符号换做积分符号)。比如要将班上学生聚类,假设隐藏变量z是身高,那么就是连续的高斯分布。如果按照隐藏变量是男女,那么就是伯努利分布了。

可以由前面阐述的内容得到下面的公式:

import numpy as np
#Numpy是Python的一个科学计算的库,提供了矩阵运算的功能
import math
pro_A, pro_B, por_C = 0.5, 0.5, 0.5
data=[1,1,0,1,0,0,1,0,1,1]
def pmf(i, pro_A, pro_B, pro_C):

迭代计算参数的估计值,直到收敛为止。

   pro_1 = pro_A * math.pow(pro_B, data[i]) * math.pow((1-pro_B), 1-data[i])  
   pro_2 = pro_A * math.pow(pro_C, data[i]) * math.pow((1-pro_C), 1-data[i])
    return pro_1 / (pro_1 + pro_2)

class EM:

初始化,self为实例,prob为参数值

    def __init__(self, prob):
        self.pro_A ,self.pro_B ,self.pro_C = prob
    def pmf(self, i):
        pro_1 = self.pro_A * math.pow(self.pro_B, data[i]) * math.pow((1 - self.pro_B), 1 - data[i])
        pro_2 = (1 - self.pro_A) * math.pow(self.pro_C, data[i]) * math.pow((1 - self.pro_C), 1 - data[i])
        return pro_1*1.0 / (pro_1 + pro_2)

#fit() :用于从训练数据生成学习模型参数,为了数据归一化,求得训练集X的均值,方差,最大值,最小值啊这些训练集X固有的属性。可以理解为一个训练过程
def fit(self,data):
count = len(data)
print(‘init prob:{},{},{}’.format(self.pro_A,self.pro_B,self.pro_C))
for d in range(count):
#把一个函数改写为一个 generator 就获得了迭代能力
_ = yield
tmp_pmf = [self.pmf(k) for k in range(count)]
#计算模型参数的新估计值
pro_A = 1/count * sum(tmp_pmf)
pro_B = sum([tmp_pmf[k]*data[k] for k in range(count)]) / sum([ tmp_pmf[k] for k in range(count)])
pro_C = sum([(1-tmp_pmf[k]) * data[k] for k in range(count)]) / sum([(1-tmp_pmf[k]) for k in range(count)])
print(’{}/{} pro_A:{:3f},pro_B:{:3f},pro_C:{:3f}’.format(d+1,count,pro_A,pro_B,pro_C))
self.pro_A = pro_A
self.pro_B = pro_B
self.pro_C = pro_C

em = EM(prob=[0.5, 0.5, 0.5])
f = em.fit(data)
next(f)
f.send(1)
f.send(2)

在这里插入图片描述

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!