TensorFlow equivalent of torch.gather

花落未央 2021-01-16 05:44

I have a tensor of shape (16, 4096, 3). I have another tensor of indices of shape (16, 32768, 3). I am trying to collect the values along dim
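The question is cut off before the dimension is named. Assuming it is dim 1 (the only axis for which these two shapes are valid `torch.gather` arguments), the PyTorch call would be `torch.gather(x, 1, idx)`, i.e. `out[i, j, k] = x[i, idx[i, j, k], k]`. A minimal sketch of one possible TensorFlow equivalent moves the gathered axis to the end and uses `tf.gather` with `batch_dims=2`. The helper name `torch_gather_dim1` is hypothetical.

```python
import tensorflow as tf


def torch_gather_dim1(x, idx):
    """Hypothetical helper: sketch of torch.gather(x, 1, idx) for 3-D tensors.

    out[i, j, k] = x[i, idx[i, j, k], k]
    x: (B, N, C) tensor, idx: (B, M, C) integer tensor -> (B, M, C) tensor.
    """
    # Move the gathered axis last: x_t is (B, C, N), idx_t is (B, C, M).
    x_t = tf.transpose(x, [0, 2, 1])
    idx_t = tf.transpose(idx, [0, 2, 1])
    # The first two axes are batch axes: out_t[i, k, j] = x_t[i, k, idx_t[i, k, j]].
    out_t = tf.gather(x_t, idx_t, batch_dims=2)
    # Restore the (B, M, C) layout.
    return tf.transpose(out_t, [0, 2, 1])


# Shapes from the question (assuming the gather is along dim 1).
x = tf.random.normal((16, 4096, 3))
idx = tf.random.uniform((16, 32768, 3), maxval=4096, dtype=tf.int32)
out = torch_gather_dim1(x, idx)
print(out.shape)  # (16, 32768, 3)
```

The transposes only reorder axes so that `tf.gather`'s batch semantics line up with `torch.gather`'s element-wise indexing.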

3 Answers
  •  鱼传尺愫
    2021-01-16 06:20

    For the 2D case, there is a way to do it:

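    import tensorflow as tf

    # a.shape: (16, 10); idx.shape: (16, 1), integer dtype
    # Pair each row number with its column index -> shape (16, 2)
    idx = tf.stack([tf.range(tf.shape(idx)[0]), tf.cast(idx[:, 0], tf.int32)], axis=-1)
    # b[i] = a[i, idx[i, 0]] -> shape (16,); torch.gather(a, 1, idx) would keep shape (16, 1)
    b = tf.gather_nd(a, idx)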
    

    However, for the N-D case, this method can become quite complex.
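For the 3-D shapes in the question, the same `tf.gather_nd` idea can still be written out: build a full `(i, idx[i, j, k], k)` coordinate for every output element and look them all up at once. This is a minimal sketch that assumes the gather is along dim 1 (the question is truncated); the helper name `torch_gather_dim1_nd` is hypothetical.

```python
import tensorflow as tf


def torch_gather_dim1_nd(x, idx):
    """Hypothetical helper: torch.gather(x, 1, idx) for 3-D tensors via tf.gather_nd.

    out[i, j, k] = x[i, idx[i, j, k], k]
    x: (B, N, C) tensor, idx: (B, M, C) integer tensor -> (B, M, C) tensor.
    """
    idx = tf.cast(idx, tf.int32)
    shape = tf.shape(idx)  # (B, M, C)
    # Grids holding the batch index i and the channel index k, each shaped like idx.
    ii, _, kk = tf.meshgrid(tf.range(shape[0]), tf.range(shape[1]),
                            tf.range(shape[2]), indexing="ij")
    # coords[i, j, k] = (i, idx[i, j, k], k), shape (B, M, C, 3).
    coords = tf.stack([ii, idx, kk], axis=-1)
    return tf.gather_nd(x, coords)


x = tf.random.normal((16, 4096, 3))
idx = tf.random.uniform((16, 32768, 3), maxval=4096, dtype=tf.int32)
out = torch_gather_dim1_nd(x, idx)
print(out.shape)  # (16, 32768, 3)
```

Note that this materializes a `(B, M, C, 3)` coordinate tensor, so memory use grows with the number of indices.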
