Sampling without replacement from a given non-uniform distribution in TensorFlow

遥遥无期 2021-01-12 04:42

I'm looking for something similar to numpy.random.choice(range(3), replace=False, size=2, p=[0.1, 0.2, 0.7])
in TensorFlow.

The closest Op

2 Answers
  • 2021-01-12 05:32

    You could just use tf.py_func to wrap numpy.random.choice and make it available as a TensorFlow op:

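    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x API (placeholders, sessions, tf.py_func)

    # Placeholders mirroring the arguments of numpy.random.choice(a, size, replace, p)
    a = tf.placeholder(tf.float32)
    size = tf.placeholder(tf.int32)
    replace = tf.placeholder(tf.bool)
    p = tf.placeholder(tf.float32)

    # Wrap the NumPy function as an op; the output dtype matches the dtype of `a`
    y = tf.py_func(np.random.choice, [a, size, replace, p], tf.float32)

    with tf.Session() as sess:
        print(sess.run(y, {a: range(3), size: 2, replace: False, p: [0.1, 0.2, 0.7]}))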
    

    You can set the NumPy seed as usual (these calls must run while the session is still open):

    np.random.seed(1)
    print(sess.run(y, {a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]}))
    print(sess.run(y, {a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]}))
    print(sess.run(y, {a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]}))
    np.random.seed(1)
    print(sess.run(y, {a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]}))
    print(sess.run(y, {a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]}))
    print(sess.run(y, {a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]}))
    np.random.seed(1)
    print(sess.run(y, {a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]}))
    

    would print:

    [ 2.  0.]
    [ 2.  1.]
    [ 0.  1.]
    [ 2.  0.]
    [ 2.  1.]
    [ 0.  1.]
    [ 2.  0.]
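
Reseeding with np.random.seed changes NumPy's global state, which affects every other caller of np.random. A minimal sketch of an alternative, assuming a private generator is acceptable: the helper name make_sampler is hypothetical and not part of the original answer. The function it returns takes the same four positional arguments as np.random.choice, so it could presumably be passed to tf.py_func in its place, though that is not exercised here.

```python
import numpy as np

def make_sampler(seed):
    """Return a choice function backed by its own RandomState (hypothetical helper)."""
    rng = np.random.RandomState(seed)

    def sample(a, size, replace, p):
        # Same positional arguments as np.random.choice, but with private state
        return rng.choice(a, size=size, replace=replace, p=p)

    return sample

values = np.arange(3, dtype=np.float32)
probs = [0.1, 0.2, 0.7]

# Two samplers built from the same seed produce the same sequence of draws
sampler_a = make_sampler(1)
sampler_b = make_sampler(1)
run_a = [sampler_a(values, 2, False, probs) for _ in range(3)]
run_b = [sampler_b(values, 2, False, probs) for _ in range(3)]
```

Each sampler keeps its own state, so reproducing a run means building a new sampler with the same seed, not resetting the global seed.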
    
  • 2021-01-12 05:38

    Yes, there is. See here and here for some background information. The solution, based on the Gumbel-max trick, is:

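    # Gumbel(0, 1) noise, one value per category: -log(-log(U)) with U ~ Uniform(0, 1)
    z = -tf.log(-tf.log(tf.random_uniform(tf.shape(p), 0, 1)))
    # Perturb the log-probabilities and keep the `size` largest; `p` and `size` are as in the question
    _, indices = tf.nn.top_k(tf.log(p) + z, size)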
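
The same idea can be checked outside TensorFlow. The following is a NumPy sketch of the two lines above, not the answer's own code: the function name sample_without_replacement, the seed, and the number of trials are assumptions chosen for illustration. It adds Gumbel noise to the log-probabilities, takes the indices of the largest perturbed values, and then measures how often each index comes out first.

```python
import numpy as np

def sample_without_replacement(p, size, rng):
    """Draw `size` distinct indices via Gumbel noise plus top-k (illustrative sketch)."""
    p = np.asarray(p, dtype=np.float64)
    # Keep U strictly inside (0, 1) so both logarithms stay finite
    u = rng.uniform(low=np.finfo(np.float64).tiny, high=1.0, size=p.shape)
    z = -np.log(-np.log(u))           # Gumbel(0, 1) noise
    keys = np.log(p) + z              # perturbed log-probabilities
    return np.argsort(-keys)[:size]   # indices of the `size` largest keys

rng = np.random.default_rng(0)
p = [0.1, 0.2, 0.7]
draws = np.array([sample_without_replacement(p, 2, rng) for _ in range(20000)])

# Empirical frequency of each index appearing in the first position
first_freq = np.bincount(draws[:, 0], minlength=3) / len(draws)
```

With enough trials, first_freq should come out close to p, and the two indices in each row never coincide because they are positions in a single sorted array.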
    