In TensorFlow, how do I use tf.gather() on the last dimension?

余生分开走 2021-01-04 02:54

I am trying to gather slices of a tensor along the last dimension, to build a partial connection between layers. Because the output tensor's shape is [batch_size, h, w,

8 Answers
  •  执笔经年
    2021-01-04 03:16

    Yet another solution, using tf.unstack(...), tf.gather(...) and tf.stack(...).

    Code:

    import tensorflow as tf
    import numpy as np
    
    shape = [2, 2, 2, 10] 
    L = np.arange(np.prod(shape))
    L = np.reshape(L, shape)
    
    indices = [0, 2, 3, 8]
    axis = -1 # last dimension
    
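    def gather_axis(params, indices, axis=0):
        # Unstack along `axis`, gather the requested slices from the resulting list,
        # then stack the kept slices back along the same axis.
        return tf.stack(tf.unstack(tf.gather(tf.unstack(params, axis=axis), indices)), axis=axis)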
    
    print(L)
    # tf.Session is the TensorFlow 1.x graph-mode API.
    with tf.Session() as sess:
        partL = sess.run(gather_axis(L, indices, axis))
        print(partL)
    

    Result:

    L = 
    [[[[ 0  1  2  3  4  5  6  7  8  9]
       [10 11 12 13 14 15 16 17 18 19]]
    
      [[20 21 22 23 24 25 26 27 28 29]
       [30 31 32 33 34 35 36 37 38 39]]]
    
    
     [[[40 41 42 43 44 45 46 47 48 49]
       [50 51 52 53 54 55 56 57 58 59]]
    
      [[60 61 62 63 64 65 66 67 68 69]
       [70 71 72 73 74 75 76 77 78 79]]]]
    
    partL = 
    [[[[ 0  2  3  8]
       [10 12 13 18]]
    
      [[20 22 23 28]
       [30 32 33 38]]]
    
    
     [[[40 42 43 48]
       [50 52 53 58]]
    
      [[60 62 63 68]
       [70 72 73 78]]]]
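
The unstack, gather, stack steps in the one-liner above can be traced one at a time. The following is only a minimal NumPy sketch of the same pattern, not the answer's TensorFlow code: `gather_axis_np` is a hypothetical helper name introduced here for illustration, and `np.take` stands in for the TensorFlow ops.

```python
import numpy as np

def gather_axis_np(params, indices, axis=0):
    """Hypothetical NumPy mirror of the unstack -> gather -> stack pattern."""
    params = np.asarray(params)
    # "unstack": one slice per position along `axis`
    slices = [np.take(params, i, axis=axis) for i in range(params.shape[axis])]
    # "gather": keep only the requested slices, in the requested order
    picked = [slices[i] for i in indices]
    # "stack": reassemble the kept slices along the same axis
    return np.stack(picked, axis=axis)

shape = [2, 2, 2, 10]
L = np.arange(np.prod(shape)).reshape(shape)
indices = [0, 2, 3, 8]

partL = gather_axis_np(L, indices, axis=-1)
print(partL.shape)  # (2, 2, 2, 4)

# Cross-check against selecting along the last axis directly
assert np.array_equal(partL, np.take(L, indices, axis=-1))
```

The cross-check at the end shows that the three steps amount to a plain selection along one axis. In TensorFlow versions where tf.gather accepts an `axis` argument, `tf.gather(params, indices, axis=-1)` should express the same selection without the intermediate unstack and stack.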
    
