In TensorFlow, how to use tf.gather() for the last dimension?

余生分开走 2021-01-04 02:54

I am trying to gather slices of a tensor along the last dimension to build a partial connection between layers. Because the output tensor's shape is [batch_size, h, w, …

8 Answers
  • 2021-01-04 03:22

    Implementing option 2 from @Yaroslav Bulatov's answer:

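    import tensorflow as tf
    
    # Example input; any tensor of shape [..., d] works
    L = tf.reshape(tf.range(2 * 3 * 10), [2, 3, 10])
    
    # Your indices into the last dimension
    indices = [0, 2, 3, 8]
    
    # Remember for final reshaping
    n_indices = tf.shape(indices)[0]
    
    flattened_L = tf.reshape(L, [-1])
    
    # Walk strided over the flattened array: each innermost row starts at a multiple of d
    offset = tf.expand_dims(tf.range(0, tf.reduce_prod(tf.shape(L)), tf.shape(L)[-1]), 1)
    flattened_indices = tf.reshape(tf.reshape(indices, [-1]) + offset, [-1])
    selected_rows = tf.gather(flattened_L, flattened_indices)
    
    # Final reshape to [..., n_indices]; since TF 1.0, tf.concat takes (values, axis)
    partL = tf.reshape(selected_rows, tf.concat([tf.shape(L)[:-1], [n_indices]], axis=0))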
    

    Credit to the question "How to select rows from a 3-D Tensor in TensorFlow?"
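In TensorFlow 2.x, `tf.gather` also takes an `axis` argument (negative values count from the end), so the same selection can usually be written as a single call without flattening. A minimal sketch, assuming an illustrative `[batch_size, h, w, channels]` input; the shape and values below are made up for the example:

```python
import tensorflow as tf

# Illustrative input of shape [batch_size, h, w, channels]; values are arbitrary
L = tf.reshape(tf.range(2 * 4 * 4 * 10), [2, 4, 4, 10])
indices = [0, 2, 3, 8]

# Gather the chosen channels directly along the last axis
partL = tf.gather(L, indices, axis=-1)
print(partL.shape)  # (2, 4, 4, 4)
```

Because the same indices are applied at every position, this matches the partial-connection pattern in the question; the flattened-index version above is mainly useful on older TensorFlow releases whose `tf.gather` had no `axis` parameter.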

  • 2021-01-04 03:29

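    With tf.gather_nd you can do this by pairing each row index with the column index to pick from that row (here indices_for_dim1 holds one column index per row of matrix):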

    # Stack row numbers and column indices into [n_rows, 2] coordinate pairs
    cat_idx = tf.stack([tf.range(tf.shape(matrix)[0]), indices_for_dim1], axis=1)
    result = tf.gather_nd(matrix, cat_idx)
    
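If each row needs its own index (a one-per-row selection like the `gather_nd` call above), `tf.gather` with `batch_dims=1` expresses it without building coordinate pairs. A minimal sketch, assuming TensorFlow 2.x and an illustrative 3×3 matrix:

```python
import tensorflow as tf

matrix = tf.constant([[1, 2, 3],
                      [4, 5, 6],
                      [7, 8, 9]])
indices_for_dim1 = tf.constant([1, 0, 2])  # one column index per row

# batch_dims=1 pairs row b of `matrix` with entry b of the indices,
# so the result is [matrix[0, 1], matrix[1, 0], matrix[2, 2]]
result = tf.gather(matrix, indices_for_dim1, batch_dims=1)
print(result.numpy())  # [2 4 9]
```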

    Also, as reported by user Nova in a thread referenced in @Yaroslav Bulatov's answer:

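    import tensorflow as tf
    
    x = tf.constant([[1, 2, 3],
                     [4, 5, 6],
                     [7, 8, 9]])
    idx = tf.constant([1, 0, 2])
    # Element (r, idx[r]) sits at position r * n_cols + idx[r] in the flattened tensor
    idx_flattened = tf.range(0, x.shape[0]) * x.shape[1] + idx
    y = tf.gather(tf.reshape(x, [-1]),  # flatten input
                  idx_flattened)  # use flattened indices
    
    # TensorFlow 2.x runs eagerly, so no tf.Session is needed
    print(y.numpy())  # [2 4 9]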
    

    The gist: flatten the tensor and use strided 1-D addressing with tf.gather(...).
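The same flatten-and-offset trick extends to the question's `[batch_size, h, w, d]` case, where every position picks one element of the last dimension. A minimal sketch; `pick_last_dim` is a hypothetical helper name, and it assumes `idx` has shape `params.shape[:-1]`:

```python
import tensorflow as tf

def pick_last_dim(params, idx):
    """Hypothetical helper: return params[..., idx[...]] (one element per position).

    params has shape [..., d]; idx is an integer tensor of shape params.shape[:-1].
    """
    d = tf.shape(params)[-1]
    flat_params = tf.reshape(params, [-1])
    flat_idx = tf.cast(tf.reshape(idx, [-1]), tf.int32)
    # Row r of the flattened [-1, d] view starts at offset r * d
    row_starts = tf.range(tf.size(flat_idx)) * d
    picked = tf.gather(flat_params, row_starts + flat_idx)
    # Restore the leading dimensions
    return tf.reshape(picked, tf.shape(idx))

params = tf.reshape(tf.range(24), [2, 3, 4])  # params[b, h, k] = 12*b + 4*h + k
idx = tf.constant([[0, 1, 2], [3, 0, 1]])
print(pick_last_dim(params, idx).numpy())
```

On TensorFlow 2.x, `tf.gather(params, idx, batch_dims=idx.shape.rank)` should give the same result; the explicit version mirrors the strided addressing described in the answer.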
