Can anyone give a tiny example to explain the parameters of tf.random.categorical?

别那么骄傲 2020-12-31 12:21

TensorFlow's site gives this example

tf.random.categorical(tf.log([[10., 10.]]), 5)

produces a tensor that "has shape [1, 5], where each …

1 answer
  • 2020-12-31 12:53

    As you note, tf.random.categorical takes two parameters:

    • logits, a 2D float tensor with shape [batch_size, num_classes]
    • num_samples, an integer scalar.

    The output is a 2D integer tensor with shape [batch_size, num_samples].
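A minimal sketch of the shape contract, assuming TensorFlow 2.x with eager execution (where tf.math.log replaces the tf.log alias used in the question). The first call is the docs example from the question; the second uses three made-up distributions over four classes. The printed dtype assumes the default output dtype (int64).

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

# Docs example from the question: batch_size = 1, num_classes = 2, equal weights.
docs_logits = tf.math.log([[10., 10.]])
docs_samples = tf.random.categorical(docs_logits, num_samples=5)
print(docs_samples.shape)  # (1, 5)
print(docs_samples.dtype)  # <dtype: 'int64'>

# Three distributions (batch_size = 3) over four classes (num_classes = 4).
batch_logits = tf.math.log([[1., 1., 1., 1.],
                            [1., 2., 3., 4.],
                            [5., 1., 1., 1.]])
batch_samples = tf.random.categorical(batch_logits, num_samples=4)
print(batch_samples.shape)  # (3, 4): one row of 4 class indices per distribution
```

Each entry of docs_samples is 0 or 1, and each entry of batch_samples is in the range 0..3.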

    Each "row" of the logits tensor (logits[0, :], logits[1, :], ...) represents the event probabilities of a different categorical distribution. The function does not expect actual probability values, though, but unnormalized log-probabilities, so the actual probabilities are softmax(logits[0, :]), softmax(logits[1, :]), etc. The benefit is that you can pass essentially any real values as input (e.g. the output of a neural network) and they will be valid. It is also trivial to use specific probability values, or proportions, by taking logarithms. For example, both [log(0.1), log(0.3), log(0.6)] and [log(1), log(3), log(6)] represent the same distribution, in which the second class is three times as likely as the first but only half as likely as the third.
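The softmax argument can be checked with plain Python. This is a sketch only: the softmax helper below is a hypothetical stand-in written for illustration, not a TensorFlow function. It shows that log-probabilities and log-proportions give the same distribution, and that adding the same constant to every logit changes nothing.

```python
import math

def softmax(logits):
    """Hypothetical helper: turn unnormalized log-probabilities into probabilities."""
    m = max(logits)                           # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

a = softmax([math.log(0.1), math.log(0.3), math.log(0.6)])  # log-probabilities
b = softmax([math.log(1), math.log(3), math.log(6)])        # log-proportions
c = softmax([math.log(1) + 5.0, math.log(3) + 5.0, math.log(6) + 5.0])  # shifted logits

print([round(p, 6) for p in a])  # [0.1, 0.3, 0.6]
print([round(p, 6) for p in b])  # [0.1, 0.3, 0.6]
print([round(p, 6) for p in c])  # [0.1, 0.3, 0.6]
```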

    For each row of (unnormalized log-)probabilities, you get num_samples samples from that row's distribution. Each sample is an integer between 0 and num_classes - 1, drawn according to the given probabilities. So the result is a 2D tensor of shape [batch_size, num_samples] holding the sampled integers for each distribution.
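To make the sampling semantics concrete, here is a pure-Python reference sketch. The name sample_categorical is hypothetical, and this is not how TensorFlow implements the op; it only mirrors the described behaviour: softmax each row, then draw num_samples class indices from it.

```python
import math
import random

def sample_categorical(logits, num_samples, rng=random):
    """Hypothetical pure-Python stand-in for the semantics of tf.random.categorical.

    logits: list of batch_size rows, each a list of num_classes unnormalized
    log-probabilities. Returns batch_size rows of num_samples class indices.
    """
    out = []
    for row in logits:
        m = max(row)
        weights = [math.exp(x - m) for x in row]  # proportional to softmax(row)
        # Draw num_samples indices in 0..num_classes-1 with those weights.
        out.append(rng.choices(range(len(row)), weights=weights, k=num_samples))
    return out

rng = random.Random(0)
logits = [[0.0, 0.0, 0.0],                          # uniform over classes 0..2
          [math.log(1), math.log(3), math.log(6)]]  # probabilities 0.1, 0.3, 0.6
samples = sample_categorical(logits, num_samples=5, rng=rng)
print(len(samples), len(samples[0]))  # 2 5, i.e. [batch_size, num_samples]
```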

    EDIT: A small example of the function.

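    import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.log)
    
    with tf.Graph().as_default(), tf.Session() as sess:
        tf.random.set_random_seed(123)  # graph-level seed so the run is repeatable
        # Row 0: equal weights, so classes 0-3 are equally likely.
        # Row 1: weights 0, 1, 2, 3; log(0) = -inf, so class 0 is never drawn.
        logits = tf.log([[1., 1., 1., 1.],
                         [0., 1., 2., 3.]])
        num_samples = 30
        cat = tf.random.categorical(logits, num_samples)  # int64, shape [2, 30]
        print(sess.run(cat))
        # [[3 3 1 1 0 3 3 0 2 3 1 3 3 3 1 1 0 2 2 0 3 1 3 0 1 1 0 1 3 3]
        #  [2 2 3 3 2 3 3 3 2 2 3 3 2 2 2 1 3 3 3 2 3 2 2 1 3 3 3 3 3 2]]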
    

    In this case, the result is an array with two rows and 30 columns. The values in the first row are sampled from a categorical distribution where every class ([0, 1, 2, 3]) has the same probability. In the second row, the probabilities are proportional to 0, 1, 2 and 3 (i.e. 0, 1/6, 2/6 and 3/6), so class 3 is the most likely one and class 0, whose logit is log(0) = -inf, has zero probability of being sampled.
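The claim about the second row can be checked empirically without TensorFlow. This is a stdlib sketch under the assumption that tf.random.categorical samples from softmax(logits) as described above; it uses random.choices as a stand-in sampler and a made-up sample count of 60,000.

```python
import math
import random
from collections import Counter

# Second row of the example: log([0, 1, 2, 3]) -> [-inf, 0, log 2, log 3].
logits = [float("-inf"), math.log(1.0), math.log(2.0), math.log(3.0)]
m = max(logits)
weights = [math.exp(x - m) for x in logits]  # exp(-inf) == 0.0, so class 0 gets weight 0
total = sum(weights)
probs = [w / total for w in weights]
print([round(p, 4) for p in probs])  # [0.0, 0.1667, 0.3333, 0.5]

rng = random.Random(123)
n = 60_000
counts = Counter(rng.choices(range(4), weights=weights, k=n))
freqs = [counts[c] / n for c in range(4)]  # empirical frequencies, close to probs
print(counts[0])  # 0: a zero-weight class is never drawn
```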
