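If `A` is a TensorFlow variable like so

```python
A = tf.Variable([[1, 2], [3, 4]])
```

and `index` is another variable holding one column index per row (e.g. `[0, 1]`), and you want to pick `A[i, index[i]]` from every row `i` (giving `[1, 4]` here):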
After dabbling around for quite a while, I found two functions that could be useful.
One is `tf.gather_nd()`, which might be useful if you can produce a tensor of the form `[[0, 0], [1, 1]]`, since you could then do:

```python
# Each row of index is an (i, j) pair, so this gathers [A[0, 0], A[1, 1]]
index = tf.constant([[0, 0], [1, 1]])
tf.gather_nd(A, index)
```
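The `[[0, 0], [1, 1]]` index tensor can often be built at run time from the per-row `index` vector by pairing each row number with the column chosen for that row. The sketch below assumes TensorFlow 2.x with eager execution and an int32 `index`. Because the row count is read from `tf.shape(A)` rather than a Python constant, the same ops may also work when the number of rows is only known at run time, though I have not checked that against every placeholder setup.

```python
import tensorflow as tf

# Assumes TensorFlow 2.x (eager execution); values mirror the example above.
A = tf.Variable([[1, 2], [3, 4]], dtype=tf.int32)
index = tf.Variable([0, 1], dtype=tf.int32)  # column to pick from each row

# Row numbers 0..N-1, taken from the dynamic shape of A.
rows = tf.range(tf.shape(A)[0])

# Pair each row number with its column: [[0, 0], [1, 1]] for this input.
indices = tf.stack([rows, index], axis=1)

# Gather A[i, index[i]] for every row i.
result = tf.gather_nd(A, indices)
print(result.numpy())  # [1 4]
```

Both `tf.range` and `tf.shape` produce int32 by default here, which matches the dtype of `index`, so `tf.stack` can combine them without a cast.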
If for some reason you are unable to produce a tensor of the form `[[0, 0], [1, 1]]` (I couldn't, because the number of rows in my case depended on a placeholder), then the workaround I found is to use `tf.py_func()`. Here is example code showing how this can be done:
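```python
# TensorFlow 1.x API (tf.py_func and tf.InteractiveSession)
import tensorflow as tf
import numpy as np

def index_along_every_row(array, index):
    # Runs in NumPy: pick array[i, index[i]] for every row i
    N, _ = array.shape
    return array[np.arange(N), index]

a = tf.Variable([[1, 2], [3, 4]], dtype=tf.int32)
index = tf.Variable([0, 1], dtype=tf.int32)
# Wrap the NumPy function as a graph op; py_func returns a list, so take [0]
a_slice_op = tf.py_func(index_along_every_row, [a, index], [tf.int32])[0]

session = tf.InteractiveSession()
a.initializer.run()
index.initializer.run()
a_slice = a_slice_op.eval()
```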
`a_slice` will be the NumPy array `[1, 4]`.
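In TensorFlow 2.x, `tf.py_func` and `tf.InteractiveSession` only survive under `tf.compat.v1`. The same workaround can be sketched with `tf.numpy_function`, which likewise hands NumPy arrays to a Python function. This is a port under the assumption of TF 2.x eager execution; it uses constants instead of variables so that no initialization step is needed.

```python
import numpy as np
import tensorflow as tf

def index_along_every_row(array, index):
    # Plain NumPy: pick array[i, index[i]] for every row i.
    n_rows = array.shape[0]
    return array[np.arange(n_rows), index]

# Assumes TensorFlow 2.x eager execution: no session or initializer calls.
a = tf.constant([[1, 2], [3, 4]], dtype=tf.int32)
index = tf.constant([0, 1], dtype=tf.int32)

# Tout is a single dtype here, so a single tensor (not a list) is returned.
a_slice = tf.numpy_function(index_along_every_row, [a, index], tf.int32)
print(a_slice.numpy())  # [1 4]
```

As with `tf.py_func`, the wrapped function runs outside the TensorFlow graph, so gradients do not flow through it. Where the index tensor can be built with TensorFlow ops, `tf.gather_nd` avoids that limitation.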