I have a tensor of shape `(16, 4096, 3)` and another tensor of indices of shape `(16, 32768, 3)`. I am trying to collect the values along `dim`
For the 2D case, there is a way to do it:
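```python
import tensorflow as tf

# a.shape == (16, 10); idx.shape == (16, 1), int32 to match tf.range's dtype
idx = tf.stack([tf.range(tf.shape(idx)[0]), idx[:, 0]], axis=-1)  # (row, col) pairs
b = tf.gather_nd(a, idx)  # b.shape == (16,)
```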
However, for the N-D case this approach can get very complicated.
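A minimal sketch of one way to generalise the 2D trick to any rank, assuming the goal is `torch.gather`-style semantics, i.e. `out[b, i, c] = a[b, idx[b, i, c], c]` when gathering along axis 1 (the question is cut off, so the axis is an assumption). `gather_along_axis` is a hypothetical helper name: it builds a full coordinate grid with `tf.meshgrid`, swaps in the index tensor for the gathered axis, and hands the stacked coordinates to `tf.gather_nd`, exactly as the 2D version pairs row numbers with column indices.

```python
import tensorflow as tf


def gather_along_axis(params, indices, axis):
    """Hypothetical helper: torch.gather-style gather along `axis`.

    Assumes params and indices have the same rank and that indices is no
    larger than params in every dimension other than `axis`.
    """
    indices = tf.cast(indices, tf.int32)
    rank = indices.shape.rank
    axis = axis % rank
    # One int32 coordinate grid per dimension, each shaped like `indices`.
    grids = tf.meshgrid(
        *[tf.range(tf.shape(indices)[d]) for d in range(rank)], indexing="ij"
    )
    # Replace the coordinate along `axis` with the user-supplied indices.
    grids[axis] = indices
    # Shape: indices.shape + (rank,), i.e. a full coordinate per output element.
    coords = tf.stack(grids, axis=-1)
    return tf.gather_nd(params, coords)


# Example with the shapes from the question, gathering along axis 1.
a = tf.random.normal((16, 4096, 3))
idx = tf.random.uniform((16, 32768, 3), maxval=4096, dtype=tf.int32)
out = gather_along_axis(a, idx, axis=1)
print(out.shape)  # (16, 32768, 3)
```

The coordinate tensor is `rank` times larger than the index tensor, which is fine at these sizes. Newer TensorFlow releases also include `tf.experimental.numpy.take_along_axis`, which, if it is available in your version, should give the same result without building the grid by hand.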