tensorflow equivalent of torch.gather

花落未央 2021-01-16 05:44

I have a tensor of shape (16, 4096, 3). I have another tensor of indices of shape (16, 32768, 3). I am trying to collect the values along dim
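For reference, `torch.gather(input, dim, index)` returns a tensor shaped like `index`. Each output element is copied from the `input` element at the same coordinates, except that the coordinate along `dim` is replaced by the corresponding value in `index`. The sketch below spells that rule out for rank-3 nested Python lists. `gather_reference` is a hypothetical helper and is not part of PyTorch or TensorFlow.

```python
def gather_reference(inp, dim, index):
    """Hypothetical helper: torch.gather semantics for rank-3 nested lists.

    out[i][j][k] = inp[c0][c1][c2], where (c0, c1, c2) is (i, j, k)
    with the `dim`-th coordinate replaced by index[i][j][k].
    """
    out = []
    for i, plane in enumerate(index):
        out_plane = []
        for j, row in enumerate(plane):
            out_row = []
            for k, v in enumerate(row):
                coord = [i, j, k]
                coord[dim] = v  # only the gathered axis takes its value from `index`
                a, b, c = coord
                out_row.append(inp[a][b][c])
            out_plane.append(out_row)
        out.append(out_plane)
    return out


inp = [[[1, 2, 3], [4, 5, 6]]]           # shape (1, 2, 3)
print(gather_reference(inp, 2, [[[2, 0], [1, 1]]]))  # shape (1, 2, 2)
```

A slow reference like this can be used as ground truth when checking the TensorFlow versions in the answers on small inputs.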

3 answers
  • 2021-01-16 06:08

    For gathering along the last axis, the general N-D case can be reduced to 2D with a reshape trick: flatten all leading axes into rows, apply the 2D code from @LiShaoyuan's answer, then reshape the result back.

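            import tensorflow as tf

            # last-axis gathering only - use the 2D-reshape trick for Torch-style N-D gathering
            # (assumes id_tensor.shape[:-1] == param.shape[:-1])
            def torch_gather(param, id_tensor):

                # 2D gather based on @LiShaoyuan's answer, extended to k indices per row:
                # out[i, j] = target[i, ids[i, j]]
                def gather2d(target, ids):
                    rows = tf.range(tf.shape(ids)[0])
                    rows = tf.tile(rows[:, None], [1, tf.shape(ids)[1]])  # row number for every index
                    idx = tf.stack([rows, ids], axis=-1)                  # (n_rows, k, 2)
                    return tf.gather_nd(target, idx)                      # (n_rows, k)

                target = tf.reshape(param, (-1, param.shape[-1]))  # reshape to 2D
                target_shape = id_tensor.shape

                # also 2D: one row of k indices per leading position (int32 to match tf.range)
                ids = tf.cast(tf.reshape(id_tensor, (-1, id_tensor.shape[-1])), tf.int32)
                result = gather2d(target, ids)
                return tf.reshape(result, target_shape)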
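The reshape trick itself does not depend on TensorFlow. The pure-Python sketch below assumes rank-3 nested lists and an index tensor whose two leading sizes match those of `param`. It flattens the leading axes into rows (row `r = i * d1 + j`), performs a per-row 2D gather, and then restores the leading axes. `last_axis_gather_via_2d` is a hypothetical helper used only for illustration.

```python
def last_axis_gather_via_2d(param, ids):
    """Hypothetical helper: last-axis gather on rank-3 nested lists
    via the 2D-reshape trick. Assumes ids has param's two leading sizes."""
    d0, d1 = len(param), len(param[0])
    # Step 1: flatten the leading axes, so row r corresponds to (i, j) with r = i * d1 + j.
    rows = [param[i][j] for i in range(d0) for j in range(d1)]
    id_rows = [ids[i][j] for i in range(d0) for j in range(d1)]
    # Step 2: 2D gather, out2d[r][k] = rows[r][id_rows[r][k]].
    out2d = [[rows[r][c] for c in id_rows[r]] for r in range(len(rows))]
    # Step 3: restore the leading axes (the "reshape back" step).
    return [out2d[i * d1:(i + 1) * d1] for i in range(d0)]


param = [[[10, 11, 12], [13, 14, 15]], [[16, 17, 18], [19, 20, 21]]]
ids = [[[2, 0], [1, 1]], [[0, 0], [2, 1]]]
print(last_axis_gather_via_2d(param, ids))
```

Each flattened row keeps its own list of `k` indices. That per-row structure is why the 2D step must pair every index with its row number, rather than treating every index as its own row.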
    
  • 2021-01-16 06:19

    This "should" be a general solution using tf.gather_nd (I've only tested for rank 2 and 3 tensors along the last axis):

    import tensorflow as tf

    def torch_gather(x, indices, gather_axis):
        # if pytorch gather indices are
        # [[[0, 10, 20], [0, 10, 20], [0, 10, 20]],
        #  [[0, 10, 20], [0, 10, 20], [0, 10, 20]]]
        # tf.gather_nd needs them to be
        # [[0,0,0], [0,0,10], [0,0,20], [0,1,0], [0,1,10], [0,1,20], [0,2,0], [0,2,10], [0,2,20],
        #  [1,0,0], [1,0,10], [1,0,20], [1,1,0], [1,1,10], [1,1,20], [1,2,0], [1,2,10], [1,2,20]]

        # coordinates of every element of `indices`, in row-major order (tf.where returns int64)
        all_indices = tf.where(tf.fill(indices.shape, True))
        # flatten the pytorch-style indices; cast to int64 so tf.stack sees matching dtypes
        gather_locations = tf.cast(tf.reshape(indices, [indices.shape.num_elements()]), tf.int64)

        # splice in our pytorch style index at the correct axis
        gather_indices = []
        for axis in range(len(indices.shape)):
            if axis == gather_axis:
                gather_indices.append(gather_locations)
            else:
                gather_indices.append(all_indices[:, axis])

        gather_indices = tf.stack(gather_indices, axis=-1)
        gathered = tf.gather_nd(x, gather_indices)
        reshaped = tf.reshape(gathered, indices.shape)
        return reshaped
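The following pure-Python sketch shows what `tf.where(tf.fill(indices.shape, True))` and the splice loop produce. It enumerates every coordinate of the index tensor in row-major order and overwrites the gather axis with the Torch-style index. `build_gather_nd_indices` is a hypothetical helper, not a TensorFlow API, and it takes the indices already flattened, as `gather_locations` is.

```python
import itertools


def build_gather_nd_indices(flat_indices, shape, gather_axis):
    """Hypothetical helper mirroring the answer above in pure Python.

    flat_indices: Torch-style indices flattened in row-major order.
    shape: shape of the (unflattened) index tensor.
    """
    # Row-major enumeration of all coordinates, the same order tf.where uses.
    coords = itertools.product(*(range(n) for n in shape))
    out = []
    for coord, loc in zip(coords, flat_indices):
        coord = list(coord)
        coord[gather_axis] = loc  # splice the Torch index in at the gather axis
        out.append(coord)
    return out


# Reproduces the example from the comments: shape (2, 3, 3), gathering along axis 2.
nd = build_gather_nd_indices([0, 10, 20] * 6, (2, 3, 3), 2)
print(nd[:4])
```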
    
  • 2021-01-16 06:20

    For the 2D case, there is a way to do it:

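    import tensorflow as tf

    # a.shape == (16, 10)
    # idx.shape == (16, 1), dtype int32 (must match tf.range's dtype for tf.stack)
    idx = tf.stack([tf.range(tf.shape(idx)[0]), idx[:, 0]], axis=-1)  # (row, column) pairs
    b = tf.gather_nd(a, idx)  # b.shape == (16,)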
    

    However, for the N-D case this approach may become quite complex.
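One way to avoid building coordinate lists by hand uses a technique that is different from the answers above. It assumes TensorFlow 2.x, where `tf.gather` accepts a `batch_dims` argument. The idea is to transpose the gather axis to the end, let every other axis act as a batch axis, gather, and transpose back. The sketch below assumes that `x` and `indices` have the same rank and the same size on every axis except the gather axis. `torch_gather_batch_dims` is a hypothetical name.

```python
import tensorflow as tf


def torch_gather_batch_dims(x, indices, axis):
    """Hypothetical helper: torch.gather(x, axis, indices) via tf.gather(batch_dims=...).

    Assumes x and indices have equal rank and equal sizes on all axes but `axis`.
    """
    rank = len(x.shape)
    axis = axis % rank
    # Move the gather axis last so every other axis becomes a batch axis.
    perm = [d for d in range(rank) if d != axis] + [axis]
    x_t = tf.transpose(x, perm)
    idx_t = tf.transpose(indices, perm)
    # out_t[b..., k] = x_t[b..., idx_t[b..., k]]
    out_t = tf.gather(x_t, idx_t, batch_dims=rank - 1)
    # Undo the permutation: inverse[d] is where original axis d ended up.
    inverse = [perm.index(d) for d in range(rank)]
    return tf.transpose(out_t, inverse)


x = tf.constant([[1, 2], [3, 4]])
print(torch_gather_batch_dims(x, tf.constant([[0, 0], [1, 0]]), 1).numpy())
```

With `batch_dims`, `tf.gather` performs the per-row lookup itself. This avoids the need to stack `int32` and `int64` coordinates together, at the cost of two transposes.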
