multinomial pmf in python scipy/numpy

醉梦人生 2021-02-19 13:00

Is there a built-in function in scipy/numpy for computing the PMF of a multinomial distribution? I'm not sure whether scipy.stats.binom generalizes in the correct way, e.g.

# Atte         


        
1 answer
  •  天命终不由人
    2021-02-19 13:49

    There's no built-in function that I know of, and the binomial probabilities do not generalize: the multinomial is normalised over a different set of outcomes, namely the count vectors whose entries sum to n, and a product of independent binomials does not enforce that constraint. However, it's fairly straightforward to implement yourself, for example:

    import math
    
    class Multinomial(object):
      def __init__(self, params):
        # params: the category probabilities p_1, ..., p_k (should sum to 1)
        self._params = params
    
      def pmf(self, counts):
        if len(counts) != len(self._params):
          raise ValueError("Dimensionality of count vector is incorrect")
    
        # prod_i p_i ** x_i
        prob = 1.
        for p, c in zip(self._params, counts):
          prob *= p ** c
    
        # times the multinomial coefficient n! / (x_1! ... x_k!)
        return prob * math.exp(self._log_multinomial_coeff(counts))
    
      def log_pmf(self, counts):
        if len(counts) != len(self._params):
          raise ValueError("Dimensionality of count vector is incorrect")
    
        # sum_i x_i * log(p_i); terms with x_i == 0 contribute 0 (avoids log(0))
        prob = 0.
        for p, c in zip(self._params, counts):
          if c > 0:
            prob += c * math.log(p)
    
        return prob + self._log_multinomial_coeff(counts)
    
      def _log_multinomial_coeff(self, counts):
        # log(n! / (x_1! ... x_k!)) with n = sum(counts)
        return self._log_factorial(sum(counts)) - sum(self._log_factorial(c)
                                                        for c in counts)
    
      def _log_factorial(self, num):
        if num < 0 or round(num) != num:
          raise ValueError("Can only compute the factorial of non-negative ints")
        # log(num!) = log(1) + ... + log(num); the empty sum gives log(0!) = 0
        return sum(math.log(n) for n in range(1, int(num) + 1))
    
    m = Multinomial([0.1, 0.1, 0.8])
    print(m.pmf([4, 4, 2]))  # approximately 2.016e-05
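
    The point about independent binomials can be checked numerically. This is a sketch in plain Python (3.8+ is assumed, for math.comb); the helper names multinomial_pmf, binomial_product and count_vectors are hypothetical and chosen only for illustration. It evaluates both formulas at the same counts, then sums each one over every count vector with total n = 10. For reference, the multinomial value is 10!/(4!·4!·2!) × 0.1^4 × 0.1^4 × 0.8^2 = 3150 × 6.4e-9 = 2.016e-5, the same value the example prints.

```python
import math
from itertools import product

# Hypothetical helper names, for illustration only (not a scipy/numpy API).

def multinomial_pmf(counts, probs):
    """Multinomial pmf: n! / (x_1! ... x_k!) * prod_i p_i ** x_i."""
    n = sum(counts)
    coeff = math.factorial(n)
    for c in counts:
        coeff //= math.factorial(c)  # exact integer division at each step
    result = float(coeff)
    for p, c in zip(probs, counts):
        result *= p ** c
    return result


def binomial_product(counts, probs):
    """Product of independent Binomial(n, p_i) pmfs; NOT a multinomial."""
    n = sum(counts)
    result = 1.0
    for p, c in zip(probs, counts):
        result *= math.comb(n, c) * p ** c * (1 - p) ** (n - c)
    return result


def count_vectors(n, k):
    """Yield every length-k tuple of non-negative ints that sums to n."""
    for head in product(range(n + 1), repeat=k - 1):
        rest = n - sum(head)
        if rest >= 0:
            yield head + (rest,)


probs = [0.1, 0.1, 0.8]
n = 10
print(multinomial_pmf([4, 4, 2], probs))   # approximately 2.016e-05
print(binomial_product([4, 4, 2], probs))  # a different, much smaller value

# Total probability over all count vectors with sum n:
total_multi = sum(multinomial_pmf(x, probs) for x in count_vectors(n, 3))
total_binom = sum(binomial_product(x, probs) for x in count_vectors(n, 3))
print(total_multi)  # 1 up to rounding
print(total_binom)  # well below 1
```

    The binomial total falls well short of 1: it equals the probability that three independent Binomial(10, p_i) variables happen to add up to exactly 10, because each binomial spreads its mass over outcomes that ignore the sum constraint.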
    

    My implementation of the multinomial coefficient is somewhat naive (it adds up log(k) for every k up to n, so it costs O(n)), but it works in log space to prevent overflow. Also note that n is not needed as a parameter: it is given by the sum of the counts, and the same probability vector works for any n. Finally, since the product of probabilities will quickly underflow for moderate n or large dimensionality, you're better off working in log space (log_pmf is provided here too).
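
    One way to make the coefficient less naive is the log-gamma function, since log(k!) = lgamma(k + 1); Python's math.lgamma evaluates this directly, so the coefficient costs O(k) rather than O(n). The sketch below is a standalone function (the name multinomial_log_pmf is hypothetical, not a library API) that stays in log space throughout. It shows why that matters: for counts [1000, 1000, 1000] with p = 1/3 each, the product p ** x underflows to 0.0 and exp of the log coefficient (roughly 3288) overflows, so the class's pmf raises OverflowError, even though the log-probability is an ordinary value of about -8.2. Newer SciPy releases also provide scipy.stats.multinomial with pmf and logpmf methods, which can serve as a cross-check where SciPy is available.

```python
import math

# Hypothetical function name; a sketch, not a library API.
def multinomial_log_pmf(counts, probs):
    """log P(X = counts) for a multinomial with category probabilities probs.

    Uses log(k!) == math.lgamma(k + 1), so the coefficient costs O(k)
    instead of O(n) and n! itself is never formed.
    """
    if len(counts) != len(probs):
        raise ValueError("Dimensionality of count vector is incorrect")
    if any(c < 0 or c != int(c) for c in counts):
        raise ValueError("Counts must be non-negative integers")
    n = sum(counts)
    log_coeff = math.lgamma(n + 1) - sum(math.lgamma(c + 1) for c in counts)
    log_terms = 0.0
    for p, c in zip(probs, counts):
        if c == 0:
            continue  # 0 * log(p) contributes nothing, even when p == 0
        if p == 0:
            return -math.inf  # positive count in a zero-probability category
        log_terms += c * math.log(p)
    return log_coeff + log_terms


# Same example as above, recovered from log space:
print(math.exp(multinomial_log_pmf([4, 4, 2], [0.1, 0.1, 0.8])))  # ~2.016e-05

# Large n: the naive product underflows, but the log-pmf is unremarkable.
print((1 / 3) ** 3000)  # 0.0
big = multinomial_log_pmf([1000, 1000, 1000], [1 / 3, 1 / 3, 1 / 3])
print(big)  # roughly -8.2
```

    For NumPy arrays, scipy.special.gammaln plays the same role as math.lgamma and would let the same formula run vectorised over many count vectors.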
