How to do weighted random sample of categories in python


Given a list of tuples where each tuple consists of a probability and an item, I'd like to sample an item according to its probability. For example, given the list [(.3, 'a'), …

  •  囚心锁ツ
    2021-01-31 18:25

    There are hacks you can do if, for example, your probabilities fit nicely into percentages, etc.

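    For example, if you're fine with percentages, expanding the distribution into a list of 100 elements and choosing uniformly from it will work (at the cost of a high memory overhead); that hack is shown at the end of this answer.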

    But the "real" way to do it with arbitrary float probabilities is to construct the cumulative distribution and sample from it. This is equivalent to subdividing the unit interval [0,1] into 3 line segments labelled 'a', 'b', and 'c', then picking a random point on the unit interval and seeing which line segment it is in.

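    #!/usr/bin/python3
    import random
    from collections import Counter  # used by the doctest below
    
    def randomCategory(probDict):
        """
            Return a key of probDict, chosen with probability equal to its value.
            (The outputs below are examples; real runs vary.)
    
            >>> dist = {'a':.1, 'b':.2, 'c':.3, 'd':.4}
    
            >>> [randomCategory(dist) for _ in range(5)]
            ['c', 'c', 'a', 'd', 'c']
    
            >>> Counter(randomCategory(dist) for _ in range(10**5))
            Counter({'d': 40127, 'c': 29975, 'b': 19873, 'a': 10025})
        """
        r = random.random() # uniform random point in [0,1)
        total = 0           # cumulative probability so far, in [0,1]
        for value, prob in probDict.items():
            total += prob
            # The first segment whose right endpoint exceeds r is the one containing r.
            if total > r:
                return value
        # Reached only if the probabilities sum to no more than r, i.e. less than 1.
        raise ValueError('distribution not normalized: {probs}'.format(probs=probDict))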
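The loop above rescans the distribution on every call. When drawing many samples from the same distribution, one option is to precompute the cumulative sums once and locate each random point with binary search. The sketch below is a hypothetical helper (make_cdf_sampler is not from the original answer) built on itertools.accumulate and bisect.bisect_right; it also shows random.choices, which since Python 3.6 accepts a weights argument and does the same job.

```python
import bisect
import itertools
import random
from collections import Counter


def make_cdf_sampler(probDict):
    """Hypothetical helper: return a zero-argument sampler over probDict's keys.

    The cumulative distribution is built once; each call then locates a
    random point with binary search, O(log n) instead of a linear scan.
    """
    values = list(probDict)
    cumulative = list(itertools.accumulate(probDict[v] for v in values))
    total = cumulative[-1]  # scaling by the total tolerates weights not summing to 1
    last = len(cumulative) - 1

    def sample():
        r = random.random() * total  # uniform point in [0, total)
        # bisect_right finds the first segment whose right endpoint exceeds r,
        # the same strict test randomCategory uses; hi=last keeps the index
        # in range even if rounding makes r equal to total.
        return values[bisect.bisect_right(cumulative, r, 0, last)]

    return sample


if __name__ == "__main__":
    sampler = make_cdf_sampler({'a': .1, 'b': .2, 'c': .3, 'd': .4})
    print(Counter(sampler() for _ in range(10**5)))

    # Python 3.6+ offers the same technique via random.choices(weights=...).
    items, probs = ['a', 'b', 'c', 'd'], [.1, .2, .3, .4]
    print(Counter(random.choices(items, weights=probs, k=10**5)))
```

The setup cost is paid once, so this pays off only when the same distribution is sampled repeatedly; for a one-off draw the plain loop is just as good.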
    

    One has to be careful of methods that return values even when their probability is 0. Fortunately, this method does not: because the comparison is the strict total > r, a zero-width segment can never contain r. Just in case, though, one could insert if prob == 0: continue.
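To see why the strict comparison matters, here is the same cumulative scan with the random point r passed in explicitly, so the boundary cases can be checked by hand (pick and its strict flag are hypothetical names for illustration). Since random.random() can return exactly 0.0, a >= test would select a leading zero-probability item.

```python
def pick(pairs, r, strict=True):
    """Hypothetical helper: randomCategory's cumulative scan, but with the
    random point r passed in so boundary cases can be checked by hand."""
    total = 0
    for value, prob in pairs:
        total += prob
        hit = (total > r) if strict else (total >= r)
        if hit:
            return value
    raise ValueError('r lies outside the distribution')


pairs = [('zero', 0.0), ('a', 0.5), ('b', 0.5)]

# random.random() can return exactly 0.0, so r = 0.0 is a real case.
print(pick(pairs, 0.0))                # prints a: the zero-width segment is skipped
print(pick(pairs, 0.0, strict=False))  # prints zero: a >= test would select it
print(pick(pairs, 0.5))                # prints b: 'a' covers [0, 0.5), 'b' covers [0.5, 1)
```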


    For the record, here's the hackish way to do it:

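    import random
    
    def makeSampler(probDict):
        """
            >>> sampler = makeSampler({'a':0.3, 'b':0.4, 'c':0.3})
            >>> sampler()
            'a'
            >>> sampler()
            'c'
        """
        # Repeat each value once per whole percent of probability. round() is
        # needed because list * float fails, and int() would truncate 0.29*100
        # (28.999...) down to 28.
        oneHundredElements = [val
                              for val, prob in probDict.items()
                              for _ in range(round(prob * 100))]
        def sampler():
            # A uniform choice over the list picks val with probability ~prob.
            return random.choice(oneHundredElements)
        return sampler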
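The resolution cost of the percentage hack can be computed directly. This hypothetical helper (effective_probabilities, not part of the original answer) assumes each probability is rounded to the nearest whole number of slots out of 100, and reports the distribution the expanded list would actually sample from.

```python
from fractions import Fraction


def effective_probabilities(probDict, slots=100):
    """Hypothetical helper: the probabilities the list-expansion hack actually
    realizes when each probability is rounded to a whole number of slots."""
    counts = {val: round(prob * slots) for val, prob in probDict.items()}
    n = sum(counts.values())  # may differ from `slots` after rounding
    if n == 0:
        raise ValueError('every probability rounded to zero slots')
    return {val: Fraction(c, n) for val, c in counts.items()}


# 1/3 and 2/3 don't fit into whole percentages: they become 33 and 67 slots.
print(effective_probabilities({'a': 1/3, 'b': 2/3}))
# A probability below half a percent is rounded away entirely.
print(effective_probabilities({'rare': 0.004, 'common': 0.996}))
```

Raising slots improves the resolution, but the list grows proportionally, which is the memory trade-off mentioned at the start of the answer.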
    

    However, if you don't have resolution issues, this is probably the fastest way possible: each sample is a single random.choice call, with no scan over the distribution. =)
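When resolution issues do matter but constant-time sampling is still wanted, one well-known option that this answer does not use is Walker's alias method. Below is a sketch of Vose's variant (make_alias_sampler is a hypothetical name): setup is O(n), and each sample needs one uniform column index plus one random.random() comparison, with no rounding to a fixed grid.

```python
import random


def make_alias_sampler(probDict):
    """Sketch of Walker's alias method (Vose's variant): O(n) setup, then O(1)
    per sample with no rounding to a fixed resolution."""
    values = list(probDict)
    n = len(values)
    total = sum(probDict.values())
    # Scale so the average column height is exactly 1.
    scaled = [probDict[v] * n / total for v in values]
    prob = [0.0] * n
    alias = [0] * n
    small = [i for i, p in enumerate(scaled) if p < 1.0]
    large = [i for i, p in enumerate(scaled) if p >= 1.0]
    while small and large:
        small_i, large_i = small.pop(), large.pop()
        prob[small_i] = scaled[small_i]  # column keeps its own value with this chance
        alias[small_i] = large_i         # ...and fills the rest with large_i's value
        scaled[large_i] -= 1.0 - scaled[small_i]
        (small if scaled[large_i] < 1.0 else large).append(large_i)
    # Any leftover columns are full (height 1 up to floating-point error).
    for i in small + large:
        prob[i] = 1.0

    def sample():
        i = random.randrange(n)  # pick a column uniformly
        return values[i] if random.random() < prob[i] else values[alias[i]]

    return sample
```

Column i returns its own value with chance prob[i] and otherwise its alias; the setup pairs each under-full column with an over-full one so that every column carries exactly 1/n of the total probability.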
