I am trying to gather slices of a tensor along its last dimension, to build a partial connection between layers. Because the output tensor's shape is [batch_size, h, w,
Yet another solution, using tf.unstack(...), tf.gather(...) and tf.stack(...):
Code:
import tensorflow as tf
import numpy as np
shape = [2, 2, 2, 10]
L = np.arange(np.prod(shape))
L = np.reshape(L, shape)
indices = [0, 2, 3, 8]
axis = -1 # last dimension
def gather_axis(params, indices, axis=0):
    # Split params along `axis`, gather the requested slices, then restack them along `axis`.
    return tf.stack(tf.unstack(tf.gather(tf.unstack(params, axis=axis), indices)), axis=axis)
print(L)
with tf.Session() as sess:
    partL = sess.run(gather_axis(L, indices, axis))
    print(partL)
Result:
L =
[[[[ 0  1  2  3  4  5  6  7  8  9]
   [10 11 12 13 14 15 16 17 18 19]]
  [[20 21 22 23 24 25 26 27 28 29]
   [30 31 32 33 34 35 36 37 38 39]]]
 [[[40 41 42 43 44 45 46 47 48 49]
   [50 51 52 53 54 55 56 57 58 59]]
  [[60 61 62 63 64 65 66 67 68 69]
   [70 71 72 73 74 75 76 77 78 79]]]]
partL =
[[[[ 0  2  3  8]
   [10 12 13 18]]
  [[20 22 23 28]
   [30 32 33 38]]]
 [[[40 42 43 48]
   [50 52 53 58]]
  [[60 62 63 68]
   [70 72 73 78]]]]
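
To check what the unstack / gather / stack chain computes without starting a session, the same three steps can be mirrored in plain NumPy. This is only a sketch: gather_axis_np is a hypothetical helper name of mine, not part of TensorFlow or NumPy, and it assumes indices is a flat list of integers as in the example above.

```python
import numpy as np

def gather_axis_np(params, indices, axis=0):
    """Illustrative NumPy mirror of the unstack -> gather -> stack idea."""
    # "unstack": bring the target axis to the front, so each slice is one row
    slices = np.moveaxis(params, axis, 0)
    # "gather": select the requested slices along the leading axis
    picked = slices[np.asarray(indices)]
    # "stack": move the leading axis back to its original position
    return np.moveaxis(picked, 0, axis)

shape = [2, 2, 2, 10]
L = np.arange(np.prod(shape)).reshape(shape)
indices = [0, 2, 3, 8]

partL = gather_axis_np(L, indices, axis=-1)
print(partL.shape)      # (2, 2, 2, 4)
print(partL[0, 0, 0])   # [0 2 3 8]
```

The result should match np.take(L, indices, axis=-1). Likewise, in TensorFlow versions where tf.gather accepts an axis argument, tf.gather(params, indices, axis=-1) should give the same slices in a single call.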