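If `A` is a TensorFlow variable like so:

    A = tf.Variable([[1, 2], [3, 4]])

and `index` is another variable, we can do the same thing (pick one element from each row of `A`, at the position given by the matching row of `index`) using this combination of `map_fn` and `gather_nd`.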
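    import numpy as np
    import tensorflow as tf

    def get_element(a, indices):
        """
        For each row i, return the element of a[i] at position indices[i].
        """
        # Output dtype follows the input, so int and float tensors both work
        return tf.map_fn(lambda x: tf.gather_nd(x[0], x[1]),
                         (a, indices),
                         dtype=a.dtype.base_dtype)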
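To make the per-row behaviour concrete, here is a plain NumPy loop that mimics what `map_fn` does with this lambda: unstack `a` and `indices` along dimension 0, apply the lookup to each (row, index) pair, then stack the results. It is only a sketch for intuition, not TensorFlow code; `get_element_loop` is a hypothetical name, and `indices` is assumed to have shape `[rows, 1]`.

```python
import numpy as np

def get_element_loop(a, indices):
    """Hypothetical loop equivalent of get_element: a[i][tuple(indices[i])] for each row i."""
    results = []
    # map_fn unstacks both inputs along dimension 0 and pairs them up
    for row, idx in zip(a, indices):
        # A length-1 index vector on a rank-1 row addresses a single element,
        # which is what gather_nd(row, idx) returns (a scalar)
        results.append(row[tuple(idx)])
    # map_fn stacks the per-row results along a new first dimension
    return np.stack(results)

A = np.array([[1, 2], [3, 4]])
index = np.array([[1], [0]])
print(get_element_loop(A, index))  # [2 3]
```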
Here's an example usage.
    A = tf.constant(np.array([[1, 2, 3],
                              [4, 5, 6],
                              [7, 8, 9]], dtype=np.float32))
    idx = tf.constant(np.array([[2], [1], [0]]))
    elems = get_element(A, idx)

    # TensorFlow 1.x graph-mode API (tf.Session)
    with tf.Session() as sess:
        e = sess.run(elems)
        print(e)  # [3. 5. 7.]
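To sanity-check what the session prints, the same selection can be written in plain NumPy with fancy indexing, pairing each row number with the column stored in `idx`. This is a NumPy cross-check only, assuming `idx` holds exactly one column index per row:

```python
import numpy as np

A = np.array([[1, 2, 3],
              [4, 5, 6],
              [7, 8, 9]], dtype=np.float32)
idx = np.array([[2], [1], [0]])

# Row i contributes A[i, idx[i, 0]], i.e. positions (0, 2), (1, 1), (2, 0)
expected = A[np.arange(A.shape[0]), idx[:, 0]]
print(expected)  # [3. 5. 7.]
```

Fancy indexing does the whole selection in one vectorised step, whereas `map_fn` runs the lambda once per row.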
I don't know if this will be much slower than other answers.
It has the advantage that you don't need to specify the number of rows of `A` in advance, as long as `a` and `indices` have the same number of rows at runtime.
Note that the output of the above is rank 1 (shape `[3]` in the example). If you'd prefer it to be rank 2 (shape `[3, 1]`), replace `gather_nd` with `gather`.
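The rank difference comes from how each op treats a length-1 index vector applied to a single row: `gather_nd` reads it as one full coordinate and returns a scalar, while `gather` treats it as a list of positions and returns a length-1 vector, so the stacked result gains a trailing dimension. A NumPy analogue of the two resulting shapes (an illustration of the shapes only, not TensorFlow calls):

```python
import numpy as np

A = np.array([[1, 2, 3],
              [4, 5, 6],
              [7, 8, 9]], dtype=np.float32)
idx = np.array([[2], [1], [0]])

# gather_nd-like: the index vector is one full coordinate -> scalar per row
rank1 = np.stack([row[tuple(i)] for row, i in zip(A, idx)])
# gather-like: the index vector is a list of positions -> length-1 vector per row
rank2 = np.stack([row[i] for row, i in zip(A, idx)])

print(rank1.shape, rank2.shape)  # (3,) (3, 1)
```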