TensorFlow getting elements of every row for specific columns

再見小時候 2021-02-08 02:57

If A is a TensorFlow variable like so:

A = tf.Variable([[1, 2], [3, 4]])

and index is another variable

  •  攒了一身酷
    2021-02-08 03:45

    After dabbling around for quite a while, I found two functions that could be useful.

    One is tf.gather_nd(), which is useful if you can produce a tensor of (row, column) index pairs of the form [[0, 0], [1, 1]]. Then you can do:

    index = tf.constant([[0, 0], [1, 1]])

    tf.gather_nd(A, index)
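    For reference, the [[0, 0], [1, 1]] pairs can also be built from index inside the graph, because tf.shape() reports the row count at run time even when the static shape is unknown. The sketch below assumes TensorFlow 2.x: a tf.function whose input_signature has None dimensions stands in for a TF1 placeholder, select_per_row is a hypothetical helper name, and constants are used in place of the variables for brevity. It also shows tf.gather with batch_dims=1, which selects the same elements without building the pairs.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager mode, tf.function)

# Hypothetical helper: returns A[i, index[i]] for every row i.
# The None dimensions play the role of a TF1 placeholder with an unknown row count.
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, None], dtype=tf.int32),
    tf.TensorSpec(shape=[None], dtype=tf.int32),
])
def select_per_row(A, index):
    # tf.shape gives the dynamic row count, available even when the static shape is None
    rows = tf.range(tf.shape(A)[0])
    # Pair each row number with its column: [[0, index[0]], [1, index[1]], ...]
    pairs = tf.stack([rows, index], axis=1)
    return tf.gather_nd(A, pairs)

A = tf.constant([[1, 2], [3, 4]])
index = tf.constant([0, 1])
print(select_per_row(A, index).numpy())  # [1 4]

# The same traced function accepts a different number of rows.
B = tf.constant([[10, 11, 12], [20, 21, 22], [30, 31, 32]])
print(select_per_row(B, tf.constant([2, 0, 1])).numpy())  # [12 20 31]

# Alternative: treat the row axis as a batch dimension, so no index pairs are needed.
print(tf.gather(A, index, batch_dims=1).numpy())  # [1 4]
```

    Since select_per_row is traced with None dimensions, one graph serves inputs of any row count.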

    If for some reason you can't produce a tensor of the form [[0, 0], [1, 1]] (I couldn't, because the number of rows in my case depended on a placeholder), the workaround I found is to use tf.py_func(). Here is example code showing how this can be done:

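    import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.py_func, tf.InteractiveSession)
    import numpy as np
    
    def index_along_every_row(array, index):
        # Called with NumPy arrays: picks array[i, index[i]] for every row i
        N, _ = array.shape
        return array[np.arange(N), index]
    
    a = tf.Variable([[1, 2], [3, 4]], dtype=tf.int32)
    index = tf.Variable([0, 1], dtype=tf.int32)
    # Wrap the NumPy function as a graph op; py_func returns a list of outputs, so take [0]
    a_slice_op = tf.py_func(index_along_every_row, [a, index], [tf.int32])[0]
    session = tf.InteractiveSession()
    
    # Variables must be initialized before the op can be evaluated
    a.initializer.run()
    index.initializer.run()
    a_slice = a_slice_op.eval()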
    

    a_slice will be the NumPy array [1, 4].
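    On TensorFlow 2.x, tf.py_func and tf.InteractiveSession are only available under tf.compat.v1. A minimal sketch of the same workaround in eager mode, assuming TF 2.x and swapping in tf.numpy_function, which, like py_func, hands NumPy arrays to the wrapped Python function:

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution by default)

def index_along_every_row(array, index):
    # Receives NumPy arrays; advanced indexing picks array[i, index[i]] for each row i
    N, _ = array.shape
    return array[np.arange(N), index]

a = tf.Variable([[1, 2], [3, 4]], dtype=tf.int32)
index = tf.Variable([0, 1], dtype=tf.int32)

# With a single (non-list) Tout, tf.numpy_function returns a single tensor
a_slice = tf.numpy_function(index_along_every_row, [a, index], tf.int32)
print(a_slice.numpy())  # [1 4]
```

    The wrapped Python code runs in the Python process and is not serialized with the graph (for example in a SavedModel), so the pure-TensorFlow tf.gather_nd route is generally the more portable choice.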
