TensorFlow getting elements of every row for specific columns

再見小時候 2021-02-08 02:57

If A is a TensorFlow variable like so:

    A = tf.Variable([[1, 2], [3, 4]])

and index is another variable

5 answers
  •  孤街浪徒
    2021-02-08 03:58

    We can do the same using a combination of tf.map_fn and tf.gather_nd.

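    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x API (tf.Session is used below)

    def get_element(a, indices):
        """
        For each row i, outputs the element of a[i] selected by indices[i]
        """
        # map_fn iterates over the rows of a and indices in lockstep;
        # x[0] is one row of a, x[1] is that row's index vector (e.g. [2])
        return tf.map_fn(lambda x: tf.gather_nd(x[0], x[1]),
                         (a, indices),
                         dtype=a.dtype)  # output dtype matches a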

    Here's an example usage.

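    A = tf.constant(np.array([[1, 2, 3],
                              [4, 5, 6],
                              [7, 8, 9]], dtype=np.float32))

    # one column index per row, shape [3, 1]
    idx = tf.constant(np.array([[2], [1], [0]]))
    elems = get_element(A, idx)

    with tf.Session() as sess:  # TF 1.x; under TF 2.x use elems.numpy()
        e = sess.run(elems)

    print(e)  # [3. 5. 7.]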
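
On TensorFlow 2.x, eager execution is the default and tf.Session is no longer available at the top level, so the example above needs a small port. The sketch below assumes a TF 2.x release recent enough for tf.map_fn to accept fn_output_signature (the older dtype argument still works there but is deprecated); otherwise it is the same map_fn + gather_nd approach.

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution


def get_element(a, indices):
    """For each row i, return the element of a[i] selected by indices[i]."""
    # Each x is a (row_of_a, row_of_indices) pair; gather_nd with a
    # length-1 index vector on a rank-1 row returns a scalar.
    return tf.map_fn(lambda x: tf.gather_nd(x[0], x[1]),
                     (a, indices),
                     fn_output_signature=a.dtype)


A = tf.constant(np.array([[1, 2, 3],
                          [4, 5, 6],
                          [7, 8, 9]], dtype=np.float32))
idx = tf.constant(np.array([[2], [1], [0]]))

elems = get_element(A, idx)
print(elems.numpy())  # [3. 5. 7.]
```

No session is needed: elems is an eager tensor, and .numpy() returns its value directly.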
    

    I haven't checked whether this is much slower than the other answers; tf.map_fn applies the function row by row rather than as a single vectorized op.

    It has the advantage that you don't need to specify the number of rows of A in advance, as long as a and indices have the same number of rows at runtime.
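
If speed matters, one alternative (a swapped-in technique, not the approach in this answer) is to skip tf.map_fn and build explicit (row, column) index pairs for a single tf.gather_nd call. This is a minimal sketch assuming TensorFlow 2.x eager mode and a rank-1 tensor of column indices; the name get_element_vectorized is hypothetical. It still reads the row count from tf.shape at run time, so the number of rows does not have to be known in advance.

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution


def get_element_vectorized(a, cols):
    """Pick a[i, cols[i]] for every row i without a per-row loop.

    Hypothetical helper; `cols` is a rank-1 tensor with one column
    index per row of `a`.
    """
    # Row count comes from the tensor at run time, not from a fixed value.
    n_rows = tf.shape(a)[0]
    rows = tf.range(n_rows)                  # [0, 1, ..., n_rows - 1], int32
    cols = tf.cast(cols, rows.dtype)         # match dtypes before stacking
    pairs = tf.stack([rows, cols], axis=1)   # shape [n_rows, 2]: (row, col)
    return tf.gather_nd(a, pairs)            # shape [n_rows]


A = tf.constant([[1., 2., 3.],
                 [4., 5., 6.],
                 [7., 8., 9.]])
out = get_element_vectorized(A, tf.constant([2, 1, 0]))
print(out.numpy())  # [3. 5. 7.]
```

On TF 2.x, tf.gather(A, cols, batch_dims=1) should give the same per-row selection in one call; no timings are claimed here for either version.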

    Note that the output of the above is rank 1 (shape [n]). If you'd prefer it to be rank 2 (shape [n, 1]), replace gather_nd with gather.
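
The rank difference comes from how the two ops treat a length-1 index vector on a rank-1 row: gather_nd consumes it as a full coordinate and returns a scalar, while gather returns a length-1 vector. The sketch below shows the gather variant, assuming TensorFlow 2.x eager mode; get_element_2d is a hypothetical name.

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution


def get_element_2d(a, indices):
    """Like get_element, but keeps a trailing axis: output shape [n, 1]."""
    # tf.gather on a rank-1 row with a length-1 index vector returns a
    # length-1 vector (not a scalar), so map_fn stacks these into [n, 1].
    return tf.map_fn(lambda x: tf.gather(x[0], x[1]),
                     (a, indices),
                     fn_output_signature=a.dtype)


A = tf.constant([[1., 2., 3.],
                 [4., 5., 6.],
                 [7., 8., 9.]])
idx = tf.constant([[2], [1], [0]])

out = get_element_2d(A, idx)
print(out.shape)             # (3, 1)
print(out.numpy().ravel())   # [3. 5. 7.]
```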
