TensorFlow's site gives this example:
tf.random.categorical(tf.log([[10., 10.]]), 5)
produces a tensor that "has shape [1, 5], where each value is either 0 or 1 with equal probability".
As you note, tf.random.categorical takes two parameters:

- logits, a 2D float tensor with shape [batch_size, num_classes]
- num_samples, an integer scalar

The output is a 2D integer tensor with shape [batch_size, num_samples].
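To illustrate the shapes involved, here is a minimal sketch (using the same TF 1.x graph style as the example further below; the all-zero logits are just an arbitrary placeholder, giving uniform distributions):

import tensorflow as tf

with tf.Graph().as_default():
    # batch_size=4 distributions over num_classes=7 (all-zero logits -> uniform)
    logits = tf.zeros([4, 7])
    samples = tf.random.categorical(logits, 10)  # num_samples=10
    print(samples.shape)   # (4, 10), i.e. [batch_size, num_samples]
    print(samples.dtype)   # <dtype: 'int64'>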
Each "row" of the logits
tensor (logits[0, :]
, logits[1, :]
, ...) represents the event probabilities of a different categorical distribution. The function does not expect actual probability values, though, but unnormalized log-probabilities; so the actual probabilities would be softmax(logits[0, :])
, softmax(logits[1, :])
, etc. The benefit of this is that you can give basically any real values as input (e.g. the output of a neural network) and they will be valid. Also, it's trivial to use specific probability values, or proportions, using logarithms. For example, both [log(0.1), log(0.3), log(0.6)]
and [log(1), log(3), log(6)]
represent the same probability, where the second class is three times as likely as the first one but only half as likely as the third one.
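You can verify that equivalence quickly with NumPy (a sketch; softmax here is the standard numerically stable implementation, not part of the API being discussed):

import numpy as np

def softmax(x):
    # Subtract the max for numerical stability; this does not change the result.
    e = np.exp(x - np.max(x))
    return e / e.sum()

# Both logit vectors normalize to the same distribution.
print(softmax(np.log([0.1, 0.3, 0.6])))  # [0.1 0.3 0.6]
print(softmax(np.log([1., 3., 6.])))     # [0.1 0.3 0.6]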
For each row of (unnormalized log-)probabilities, you get num_samples samples from the corresponding distribution. Each sample is an integer between 0 and num_classes - 1, drawn according to the given probabilities. So the result is a 2D tensor with shape [batch_size, num_samples] containing the sampled integers for each distribution.
EDIT: A small example of the function.
import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    tf.random.set_random_seed(123)
    logits = tf.log([[1., 1., 1., 1.],
                     [0., 1., 2., 3.]])
    num_samples = 30
    cat = tf.random.categorical(logits, num_samples)
    print(sess.run(cat))
    # [[3 3 1 1 0 3 3 0 2 3 1 3 3 3 1 1 0 2 2 0 3 1 3 0 1 1 0 1 3 3]
    #  [2 2 3 3 2 3 3 3 2 2 3 3 2 2 2 1 3 3 3 2 3 2 2 1 3 3 3 3 3 2]]
In this case, the result is an array with two rows and 30 columns. The values in the first row are sampled from a categorical distribution where every class ([0, 1, 2, 3]) has the same probability. In the second row, class 3 is the most likely one, and class 0 has no chance of being sampled at all, since its unnormalized probability is 0 (tf.log(0.) is -inf).
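To sanity-check those proportions empirically, you can count the sampled classes over many draws (a sketch, again in TF 1.x style; the exact frequencies will vary from run to run):

import numpy as np
import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    # Second distribution from above: probabilities proportional to [0, 1, 2, 3].
    logits = tf.log([[0., 1., 2., 3.]])
    samples = sess.run(tf.random.categorical(logits, 60000))
    print(np.bincount(samples[0], minlength=4) / 60000)
    # ~[0. 0.167 0.333 0.5], matching the proportions 0 : 1/6 : 2/6 : 3/6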