How to use a tensor for indexing another tensor in TensorFlow

清歌不尽 2021-01-26 07:06

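I have a data tensor of dimensions [B X N X 3], and I have an indices tensor of dimensions [B X M]. I wish to extract a [B X M X 3] tensor from the data tensor, where for each batch b the M rows of data[b] are selected by indices[b].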

1 Answer
  • 2021-01-26 07:45

    You can use tf.gather_nd() with code like this:

    import tensorflow as tf
    
    # B = 3
    # N = 4
    # M = 2
    # [B x N x 3]
    data = tf.constant([
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]],
        [[100, 101, 102], [103, 104, 105], [106, 107, 108], [109, 110, 111]],
        [[200, 201, 202], [203, 204, 205], [206, 207, 208], [209, 210, 211]],
        ])
    
    # [B x M]
    indices = tf.constant([
        [0, 2],
        [1, 3],
        [3, 2],
        ])
    
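    indices_shape = tf.shape(indices)
    
    # [B x M] batch index for every entry: [[0, 0], [1, 1], [2, 2]]
    indices_help = tf.tile(tf.reshape(tf.range(indices_shape[0]), [indices_shape[0], 1]), [1, indices_shape[1]])
    # [B x M x 2] (batch, row) coordinates into data
    indices_ext = tf.concat([tf.expand_dims(indices_help, 2), tf.expand_dims(indices, 2)], axis=2)
    # [B x M x 3]: new_data[b, m, :] = data[b, indices[b, m], :]
    new_data = tf.gather_nd(data, indices_ext)
    
    # TensorFlow 2.x executes eagerly, so no Session is needed.
    # On TensorFlow 1.x, evaluate these tensors with sess.run() inside a tf.Session() instead.
    print('data')
    print(data.numpy())
    print('\nindices')
    print(indices.numpy())
    print('\nnew_data')
    print(new_data.numpy())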
    

    new_data will be:

    [[[  0   1   2]
      [  6   7   8]]
    
     [[103 104 105]
      [109 110 111]]
    
     [[209 210 211]
      [206 207 208]]]
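    To check what tf.gather_nd does with indices_ext without installing TensorFlow, here is a minimal NumPy sketch of the same coordinate trick. NumPy is an assumption here and not part of the original answer: the sketch builds the same [B x M x 2] (batch, row) pairs and uses NumPy advanced indexing in place of tf.gather_nd.

```python
import numpy as np

# Same example values as the answer above: B = 3, N = 4, M = 2, data is [B x N x 3]
data = np.array([
    [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]],
    [[100, 101, 102], [103, 104, 105], [106, 107, 108], [109, 110, 111]],
    [[200, 201, 202], [203, 204, 205], [206, 207, 208], [209, 210, 211]],
])

# [B x M] row indices into the N axis of each batch
indices = np.array([
    [0, 2],
    [1, 3],
    [3, 2],
])

B, M = indices.shape

# Batch index for every (b, m) position: [[0, 0], [1, 1], [2, 2]]
batch_idx = np.tile(np.arange(B).reshape(B, 1), (1, M))

# [B x M x 2] coordinates, the NumPy counterpart of indices_ext
coords = np.stack([batch_idx, indices], axis=2)

# gather_nd-style lookup: each innermost pair (b, n) selects the row data[b, n, :]
new_data = data[coords[..., 0], coords[..., 1]]

print(coords.tolist())  # [[[0, 0], [0, 2]], [[1, 1], [1, 3]], [[2, 3], [2, 2]]]
print(new_data.shape)   # (3, 2, 3)
print(new_data)
```

    Each innermost (b, n) pair selects one length-3 row, so the result has shape [B x M x 3] and matches the new_data printed above. In TensorFlow 1.14+ and 2.x, tf.gather(data, indices, batch_dims=1) should produce the same result without building the coordinates by hand. This is an alternative to verify on your version, not part of the original answer.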
    