In TensorFlow, how do I use tf.gather() on the last dimension?

余生分开走 2021-01-04 02:54

I am trying to gather slices of a tensor along the last dimension to build a partial connection between layers. Because the output tensor's shape is [batch_size, h, w,

  • 2021-01-04 03:08

    As of TensorFlow 1.3, tf.gather has an axis parameter, so the workarounds in the other answers are no longer necessary.

    https://www.tensorflow.org/versions/r1.3/api_docs/python/tf/gather
    https://github.com/tensorflow/tensorflow/issues/11223
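A minimal sketch of the axis parameter applied to the question's case, assuming TensorFlow 2.x with eager execution. The tensor x and the channel list are made-up illustration values; axis=-1 selects along the last dimension and leaves every leading dimension untouched.

```python
import tensorflow as tf

# A [batch, h, w, depth] tensor with illustrative shape [2, 2, 2, 10];
# x[a, b, c, :] holds consecutive integers starting at ((a*2 + b)*2 + c) * 10.
x = tf.reshape(tf.range(2 * 2 * 2 * 10), [2, 2, 2, 10])

# Channels to keep along the last dimension (illustrative choice).
channels = [0, 2, 3, 8]

# axis=-1 gathers along the last dimension; the result has shape [2, 2, 2, 4].
picked = tf.gather(x, channels, axis=-1)

print(picked[0, 0, 0].numpy())  # [0 2 3 8]
print(picked[1, 1, 1].numpy())  # [70 72 73 78]
```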

  • 2021-01-04 03:11

    You can try this approach (it covers most cases in NLP, at least).

    Suppose the parameter tensor has shape [batch_size, depth] and the indices are [i, j, k, n, m], a list of length batch_size holding one index per row. Then gather_nd can help:

    import tensorflow as tf

    parameters = tf.constant([
                              [11, 12, 13],
                              [21, 22, 23],
                              [31, 32, 33],
                              [41, 42, 43]])
    targets = tf.constant([2, 1, 0, 1])  # one column index per row
    batch_nums = tf.range(0, limit=parameters.get_shape().as_list()[0])  # row numbers 0..3
    indices = tf.stack((batch_nums, targets), axis=1)  # [row, column] pairs: [[0, 2], [1, 1], [2, 0], [3, 1]]
    items = tf.gather_nd(parameters, indices)
    # which is what we want: [13, 22, 31, 42]
    

    This snippet first selects the row (the first dimension) with batch_nums, then fetches the item within that row using the target index.
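In TensorFlow 2.x, tf.gather also accepts a batch_dims argument, which expresses the same row-by-row lookup without building the [row, column] pairs by hand. A sketch assuming TF 2.x eager mode, reusing the toy values above:

```python
import tensorflow as tf

parameters = tf.constant([[11, 12, 13],
                          [21, 22, 23],
                          [31, 32, 33],
                          [41, 42, 43]])
targets = tf.constant([2, 1, 0, 1])

# batch_dims=1 treats dimension 0 as a batch dimension: row i of `parameters`
# is indexed only by row i of the index tensor. targets[:, None] gives every row
# a single index (shape [4, 1]); squeeze then drops that length-1 axis.
items = tf.squeeze(
    tf.gather(parameters, targets[:, None], axis=1, batch_dims=1), axis=1)

print(items.numpy())  # [13 22 31 42]
```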

  • 2021-01-04 03:12

    There is a tracking issue for supporting this use case: https://github.com/tensorflow/tensorflow/issues/206

    For now you can:

    1. Transpose your tensor so that the dimension to gather is first, gather, then transpose back (transpose is expensive).

    2. Reshape your tensor into 1-D (reshape is cheap), turn your column indices into a list of individual element indices in the flattened (linear) order, gather, then reshape back.

    3. Use gather_nd. You will still need to turn your column indices into a list of individual element indices.
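Options 1 and 2 can be sketched as follows, assuming TensorFlow 2.x eager mode and a tensor whose rank is known statically. The function names are hypothetical; option 3 builds the same per-element indices as option 2, just in [row, column] form for gather_nd.

```python
import tensorflow as tf

def gather_last_by_transpose(x, cols):
    # Option 1: move the last axis to the front, gather on axis 0, move it back.
    rank = len(x.shape)  # assumes the rank is known statically
    to_front = [rank - 1] + list(range(rank - 1))
    to_back = list(range(1, rank)) + [0]
    return tf.transpose(tf.gather(tf.transpose(x, to_front), cols), to_back)

def gather_last_by_reshape(x, cols):
    # Option 2: treat x as [rows, depth] in row-major order, so element
    # (row, col) sits at linear position row * depth + col of the flat tensor.
    shape = tf.shape(x)
    depth = shape[-1]
    flat = tf.reshape(x, [-1])
    rows = tf.range(tf.size(x) // depth)
    col_idx = tf.constant(cols, dtype=tf.int32)
    linear = rows[:, None] * depth + col_idx[None, :]  # [rows, len(cols)]
    picked = tf.gather(flat, linear)
    # Restore the leading dimensions; the last one now has len(cols) entries.
    return tf.reshape(picked, tf.concat([shape[:-1], tf.shape(col_idx)], axis=0))

x = tf.reshape(tf.range(2 * 2 * 2 * 10), [2, 2, 2, 10])
print(gather_last_by_transpose(x, [0, 2, 3, 8])[0, 0, 0].numpy())  # [0 2 3 8]
print(gather_last_by_reshape(x, [0, 2, 3, 8])[1, 1, 1].numpy())  # [70 72 73 78]
```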
  • 2021-01-04 03:16

    Yet another solution, using tf.unstack(...), tf.gather(...) and tf.stack(...):

    Code:

    import tensorflow as tf
    import numpy as np
    
    shape = [2, 2, 2, 10] 
    L = np.arange(np.prod(shape))
    L = np.reshape(L, shape)
    
    indices = [0, 2, 3, 8]
    axis = -1 # last dimension
    
    def gather_axis(params, indices, axis=0):
        # unstack along `axis`, gather the wanted slices, then stack them back along `axis`
        return tf.stack(tf.unstack(tf.gather(tf.unstack(params, axis=axis), indices)), axis=axis)
    
    print(L)
    with tf.Session() as sess:
        partL = sess.run(gather_axis(L, indices, axis))
        print(partL)
    

    Result:

    L = 
    [[[[ 0  1  2  3  4  5  6  7  8  9]
       [10 11 12 13 14 15 16 17 18 19]]
    
      [[20 21 22 23 24 25 26 27 28 29]
       [30 31 32 33 34 35 36 37 38 39]]]
    
    
     [[[40 41 42 43 44 45 46 47 48 49]
       [50 51 52 53 54 55 56 57 58 59]]
    
      [[60 61 62 63 64 65 66 67 68 69]
       [70 71 72 73 74 75 76 77 78 79]]]]
    
    partL = 
    [[[[ 0  2  3  8]
       [10 12 13 18]]
    
      [[20 22 23 28]
       [30 32 33 38]]]
    
    
     [[[40 42 43 48]
       [50 52 53 58]]
    
      [[60 62 63 68]
       [70 72 73 78]]]]
    
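TensorFlow 2.x has no tf.Session, but with eager execution the same helper can be called directly. A sketch of that port (assumes TF 2.x), which also checks it against tf.gather with axis=-1. Note that tf.unstack needs the size of the split dimension to be known statically and creates one tensor per slice, so the built-in axis argument is usually the simpler choice.

```python
import tensorflow as tf

def gather_axis(params, indices, axis=0):
    # Split params into a list of slices along `axis`; tf.gather packs that list
    # into a tensor with the slices on a new axis 0 and picks the wanted ones;
    # the picked slices are then restacked along `axis`.
    slices = tf.unstack(params, axis=axis)
    picked = tf.unstack(tf.gather(slices, indices))
    return tf.stack(picked, axis=axis)

L = tf.reshape(tf.range(2 * 2 * 2 * 10), [2, 2, 2, 10])
indices = [0, 2, 3, 8]

via_unstack = gather_axis(L, indices, axis=-1)
via_axis = tf.gather(L, indices, axis=-1)
print(bool(tf.reduce_all(tf.equal(via_unstack, via_axis))))  # True
```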
  • 2021-01-04 03:20

    A correct version of @Andrei's answer would read:

    # pair each row number with the column index chosen for that row
    cat_idx = tf.stack([tf.range(0, tf.shape(matrix)[0]), indices_for_dim1], axis=1)
    result = tf.gather_nd(matrix, cat_idx)
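Using tf.shape rather than the static shape matters when the batch size is unknown at graph-construction time. A sketch assuming TensorFlow 2.x: the function name pick_per_row is hypothetical, while matrix and indices_for_dim1 are the names from the snippet above. The input signature leaves the batch dimension as None, so the static shape could not supply the row count here.

```python
import tensorflow as tf

# Batch size and depth are left unknown (None) when the function is traced,
# so the row numbers must come from the dynamic tf.shape(matrix).
@tf.function(input_signature=[tf.TensorSpec([None, None], tf.int32),
                              tf.TensorSpec([None], tf.int32)])
def pick_per_row(matrix, indices_for_dim1):
    cat_idx = tf.stack([tf.range(0, tf.shape(matrix)[0]), indices_for_dim1], axis=1)
    return tf.gather_nd(matrix, cat_idx)

m = tf.constant([[11, 12, 13], [21, 22, 23]])
print(pick_per_row(m, tf.constant([2, 0])).numpy())  # [13 21]
```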
    
  • 2021-01-04 03:21

    Read the static shape with the get_shape() method (newer TensorFlow releases also provide an equivalent .shape property). The code below runs under TensorFlow 1.x with Python 2.7 or 3:

    import tensorflow as tf

    x = tf.constant([[1, 2, 3],
                     [4, 5, 6],
                     [7, 8, 9]])
    idx = tf.constant([1, 0, 2])  # column to take from each row
    num_rows, num_cols = x.get_shape().as_list()
    # x[i, idx[i]] sits at position i * num_cols + idx[i] of the flattened tensor
    idx_flattened = tf.range(0, num_rows) * num_cols + idx
    y = tf.gather(tf.reshape(x, [-1]),  # flatten input
                  idx_flattened)  # use flattened indices

    with tf.Session():
        print(y.eval())  # [2 4 9]
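The same flattening trick extends to higher-rank tensors, for example choosing one channel per pixel of a [batch, h, w, depth] tensor. A sketch assuming TensorFlow 2.x eager mode; pick_from_last_dim is a hypothetical name, and tf.shape is used so that the leading dimensions need not be known statically.

```python
import tensorflow as tf

def pick_from_last_dim(x, idx):
    # x has shape [..., depth]; idx has x's shape minus the last dimension.
    # Returns out[...] = x[..., idx[...]], one element per leading position.
    depth = tf.shape(x)[-1]
    flat_x = tf.reshape(x, [-1])          # flatten input
    flat_idx = tf.reshape(idx, [-1])      # one index per leading position
    rows = tf.range(tf.size(flat_idx))    # leading positions in row-major order
    y = tf.gather(flat_x, rows * depth + flat_idx)  # use flattened indices
    return tf.reshape(y, tf.shape(idx))

x = tf.constant([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])
print(pick_from_last_dim(x, tf.constant([1, 0, 2])).numpy())  # [2 4 9]
```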
    