I am trying to gather slices of a tensor along the last dimension to build a partial connection between layers. Because the output tensor's shape is [batch_size, h, w,
Implementing approach 2 from @Yaroslav Bulatov's answer:
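import tensorflow as tf
# L: the input tensor; slices are taken along its last axis
# Indices to keep along the last dimension
indices = [0, 2, 3, 8]
# Number of kept slices, needed for the final reshape
n_indices = tf.shape(indices)[0]
flattened_L = tf.reshape(L, [-1])
# Start offset of every last-axis "row" in the flattened tensor: 0, d, 2d, ... with d = tf.shape(L)[-1]
offset = tf.expand_dims(tf.range(0, tf.reduce_prod(tf.shape(L)), tf.shape(L)[-1]), 1)
# Broadcast-add offsets [rows, 1] and indices [n_indices] -> [rows, n_indices], then flatten
flattened_indices = tf.reshape(tf.reshape(indices, [-1]) + offset, [-1])
selected_rows = tf.gather(flattened_L, flattened_indices)
# Final reshape: keep the leading dims, replace the last one with n_indices
partL = tf.reshape(selected_rows, tf.concat([tf.shape(L)[:-1], [n_indices]], axis=0))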
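To see what the offsets and the broadcast add actually compute, here is the same index arithmetic in plain Python on a row-major flat list. This is only an illustration, not TensorFlow code; the helper name `gather_last_dim` is hypothetical, and it assumes the data is already flattened in row-major order (the order `tf.reshape(L, [-1])` produces).

```python
from math import prod


def gather_last_dim(flat, shape, indices):
    """Hypothetical helper mimicking the strided tf.gather trick on a flat list.

    flat:    the tensor's elements in row-major (C) order
    shape:   the tensor's shape, e.g. [2, 10]
    indices: positions to keep along the last dimension
    Returns the gathered flat list and the new shape.
    """
    last = shape[-1]
    # One offset per last-axis "row": 0, last, 2*last, ... (like tf.range(0, prod, last))
    offsets = range(0, prod(shape), last)
    # Broadcast-add offsets and indices, offset-major, then flatten
    flat_idx = [off + i for off in offsets for i in indices]
    gathered = [flat[k] for k in flat_idx]
    # Leading dims unchanged, last dim replaced by the number of kept indices
    new_shape = list(shape[:-1]) + [len(indices)]
    return gathered, new_shape


values, new_shape = gather_last_dim(list(range(20)), [2, 10], [0, 2, 3, 8])
print(values)     # [0, 2, 3, 8, 10, 12, 13, 18]
print(new_shape)  # [2, 4]
```

Because the flat list here is `0..19`, each gathered value equals its flat position, which makes the strided pattern easy to read off.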
Credit to "How to select rows from a 3-D Tensor in TensorFlow?"
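With tf.gather_nd you can now pick one element per row, given one column index per row:
import tensorflow as tf
# x: 2-D tensor [rows, cols]; indices_for_dim1: int32 tensor [rows], one column index per row
cat_idx = tf.stack([tf.range(0, tf.shape(x)[0]), indices_for_dim1], axis=1)  # [rows, 2] (row, col) pairs
result = tf.gather_nd(x, cat_idx)  # result[i] == x[i, indices_for_dim1[i]]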
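tf.gather_nd reads the last axis of its index tensor as one coordinate, so picking one element per row of a 2-D input needs an index tensor of shape [N, 2] holding (row, col) pairs. The plain-Python model below is only an illustration of those semantics for the 2-D case (the helper names are hypothetical). It also shows why concatenating the row and column indices along axis 0 does not work: that yields a single flat list of length 2N, which gather_nd would treat as one 2N-element coordinate, too long for a 2-D input.

```python
def stack_row_col(col_idx):
    # Like tf.stack([tf.range(n), col_idx], axis=1): one [row, col] pair per row
    return [[row, col] for row, col in enumerate(col_idx)]


def gather_nd_2d(matrix, index_pairs):
    # Like tf.gather_nd(matrix, index_pairs) for a 2-D matrix and [N, 2] indices
    return [matrix[r][c] for r, c in index_pairs]


def concat_row_col(col_idx):
    # Like tf.concat([tf.range(n), col_idx], axis=0): a flat list of length 2n
    return list(range(len(col_idx))) + list(col_idx)


x = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
pairs = stack_row_col([1, 0, 2])
print(pairs)                     # [[0, 1], [1, 0], [2, 2]]
print(gather_nd_2d(x, pairs))    # [2, 4, 9]
print(concat_row_col([1, 0, 2]))  # [0, 1, 2, 1, 0, 2]
```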
Also, as reported by user Nova in a thread referenced in @Yaroslav Bulatov's answer:
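import tensorflow as tf
x = tf.constant([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])
idx = tf.constant([1, 0, 2])  # column to pick from each row
# Start of row i in the flattened tensor is i * num_cols; add that row's column index
idx_flattened = tf.range(0, tf.shape(x)[0]) * tf.shape(x)[1] + idx
y = tf.gather(tf.reshape(x, [-1]),  # flatten input
              idx_flattened)        # use flattened indices
# TF 2.x (eager execution); in TF 1.x, evaluate y inside a tf.Session instead
print(y.numpy())  # [2 4 9]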
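The same one-element-per-row computation in plain Python, as an illustration of the index arithmetic only (not TensorFlow code; the helper names are hypothetical). Row i begins at position i * num_cols of the row-major flattened matrix, so adding the row's column index gives the flat position to read.

```python
def flat_indices_one_per_row(num_cols, col_idx):
    # Row i starts at i * num_cols in the flattened matrix; add its column index
    return [row * num_cols + col for row, col in enumerate(col_idx)]


def gather_one_per_row(matrix, col_idx):
    num_cols = len(matrix[0])
    # Row-major flatten, like tf.reshape(x, [-1])
    flat = [v for row in matrix for v in row]
    return [flat[k] for k in flat_indices_one_per_row(num_cols, col_idx)]


x = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(flat_indices_one_per_row(3, [1, 0, 2]))  # [1, 3, 8]
print(gather_one_per_row(x, [1, 0, 2]))        # [2, 4, 9]
```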
The gist: flatten the tensor and use strided 1-D addressing with tf.gather(...).
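The strided addressing is the row-major (C-order) linear index, which is the element order tf.reshape(..., [-1]) uses. For a tensor of shape (d_0, ..., d_{n-1}), merging all leading axes into a single row index r gives the simple form used by both snippets:

```latex
\operatorname{flat}(i_0, \dots, i_{n-1}) = \sum_{k=0}^{n-1} i_k \prod_{m=k+1}^{n-1} d_m

% Merge the leading axes into one row index r, with R = d_0 d_1 \cdots d_{n-2} rows:
\operatorname{flat}(r, j) = r \, d_{n-1} + j, \qquad r = 0, \dots, R-1, \quad j = 0, \dots, d_{n-1} - 1
```

The row starts r d_{n-1} are exactly the offsets produced by tf.range(0, tf.reduce_prod(tf.shape(L)), tf.shape(L)[-1]). In Nova's example d_{n-1} = 3 and idx = [1, 0, 2], giving flat positions 0·3+1 = 1, 1·3+0 = 3 and 2·3+2 = 8, i.e. the values 2, 4 and 9.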