I am trying to gather slices of a tensor along its last dimension to build a partial connection between layers. Because the output tensor's shape is [batch_size, h, w,
As of TensorFlow 1.3, tf.gather has an axis parameter, so the various workarounds here are no longer necessary.
https://www.tensorflow.org/versions/r1.3/api_docs/python/tf/gather https://github.com/tensorflow/tensorflow/issues/11223
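A minimal sketch of that axis argument, assuming TensorFlow 2.x with eager execution (under 1.x the result would be evaluated in a session instead). The tensor shape and index values are made up for illustration:

```python
import numpy as np
import tensorflow as tf

# Hypothetical 3-D tensor of shape [2, 3, 10] holding the values 0..59.
params = tf.constant(np.arange(2 * 3 * 10).reshape([2, 3, 10]))
indices = tf.constant([0, 2, 3, 8])  # positions to keep along the last axis

# axis=-1 gathers along the last dimension; the other dimensions are untouched.
picked = tf.gather(params, indices, axis=-1)
print(picked.numpy().shape)  # (2, 3, 4)
```

Each innermost row keeps only the four selected entries, so no transposing or flattening is needed.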
You can try it this way. For instance (in most NLP cases, at least), the parameter tensor has shape [batch_size, depth] and the indices are [i, j, k, n, m], a list whose length is batch_size. In that case gather_nd can help.
parameters = tf.constant([
[11, 12, 13],
[21, 22, 23],
[31, 32, 33],
[41, 42, 43]])
targets = tf.constant([2, 1, 0, 1])
batch_nums = tf.range(0, limit=parameters.get_shape().as_list()[0])
indices = tf.stack((batch_nums, targets), axis=1) # the axis is the dimension number
items = tf.gather_nd(parameters, indices)
# which is what we want: [13, 22, 31, 42]
This snippet first selects the row along the first dimension using batch_nums, then fetches the item in that row at the position given by targets.
There's a tracking bug to support this use-case here: https://github.com/tensorflow/tensorflow/issues/206
For now you can:
transpose your tensor so that the dimension to gather is first (transpose is expensive)
reshape your tensor to 1-D (reshape is cheap), turn your column indices into a list of individual element indices in linear indexing, gather, then reshape back
use gather_nd; you will still need to turn your column indices into a list of individual element indices
Yet another solution uses tf.unstack(...), tf.gather(...) and tf.stack(...).
Code:
import tensorflow as tf
import numpy as np
shape = [2, 2, 2, 10]
L = np.arange(np.prod(shape))
L = np.reshape(L, shape)
indices = [0, 2, 3, 8]
axis = -1 # last dimension
def gather_axis(params, indices, axis=0):
    # Split along `axis`, gather the wanted slices, then re-stack along the same axis.
    return tf.stack(tf.unstack(tf.gather(tf.unstack(params, axis=axis), indices)), axis=axis)
print(L)
with tf.Session() as sess:
    partL = sess.run(gather_axis(L, indices, axis))
    print(partL)
Result:
L =
[[[[ 0  1  2  3  4  5  6  7  8  9]
   [10 11 12 13 14 15 16 17 18 19]]
  [[20 21 22 23 24 25 26 27 28 29]
   [30 31 32 33 34 35 36 37 38 39]]]
 [[[40 41 42 43 44 45 46 47 48 49]
   [50 51 52 53 54 55 56 57 58 59]]
  [[60 61 62 63 64 65 66 67 68 69]
   [70 71 72 73 74 75 76 77 78 79]]]]
partL =
[[[[ 0  2  3  8]
   [10 12 13 18]]
  [[20 22 23 28]
   [30 32 33 38]]]
 [[[40 42 43 48]
   [50 52 53 58]]
  [[60 62 63 68]
   [70 72 73 78]]]]
A correct version of @Andrei's answer would read
cat_idx = tf.stack([tf.range(0, tf.shape(matrix)[0]), indices_for_dim1], axis=1)
result = tf.gather_nd(matrix, cat_idx)
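For a self-contained check, the two lines above can be run roughly as follows. This is only a sketch: it assumes TensorFlow 2.x with eager execution, and the values of `matrix` and `indices_for_dim1` are made up for illustration.

```python
import tensorflow as tf

# Hypothetical inputs: a [batch, depth] matrix and one column index per row.
matrix = tf.constant([[11, 12, 13],
                      [21, 22, 23],
                      [31, 32, 33],
                      [41, 42, 43]])
indices_for_dim1 = tf.constant([2, 1, 0, 1])

# tf.shape reads the batch size at run time, so this also works when the
# static batch dimension is unknown.
row_idx = tf.range(0, tf.shape(matrix)[0])
cat_idx = tf.stack([row_idx, indices_for_dim1], axis=1)  # shape [batch, 2]
result = tf.gather_nd(matrix, cat_idx)
print(result.numpy())  # [13 22 31 42]
```

Stacking along axis=1 produces one (row, column) pair per batch element, which is the index format gather_nd expects for picking single elements.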
In the TensorFlow version used here, a Tensor has no shape attribute, only a get_shape() method. The code below runs under Python 2.7.
import tensorflow as tf
import numpy as np
x = tf.constant([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
idx = tf.constant([1, 0, 2])
idx_flattened = tf.range(0, x.get_shape()[0]) * x.get_shape()[1] + idx
y = tf.gather(tf.reshape(x, [-1]), # flatten input
idx_flattened) # use flattened indices
with tf.Session(''):
    print(y.eval())  # [2 4 9]
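The linear-index arithmetic behind this answer can be checked without a TensorFlow session. The sketch below uses NumPy in place of TensorFlow purely to show the computation, with the same values as above:

```python
import numpy as np

x = np.array([[1, 2, 3],
              [4, 5, 6],
              [7, 8, 9]])
idx = np.array([1, 0, 2])  # one column index per row

n_rows, n_cols = x.shape
# Element (r, c) of a row-major 2-D array sits at position r * n_cols + c
# once the array is flattened.
idx_flattened = np.arange(n_rows) * n_cols + idx
y = x.reshape(-1)[idx_flattened]
print(idx_flattened)  # [1 3 8]
print(y)              # [2 4 9]
```

The row offset r * n_cols is what tf.range(0, rows) * cols contributes in the TensorFlow version; adding idx then selects the column within each row.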