In TensorFlow, how to use tf.gather() for the last dimension?

余生分开走 2021-01-04 02:54

I am trying to gather slices of a tensor along the last dimension to get a partial connection between layers. Because the output tensor's shape is [batch_size, h, w,
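A minimal sketch of gathering along the last dimension, assuming TensorFlow 2.x, where `tf.gather` takes an `axis` argument that also accepts negative values. The tensor `x`, its shape `[2, 3, 3, 4]`, and the channel list `channel_idx` are hypothetical stand-ins, since the question's actual shape is cut off:

```python
import tensorflow as tf

# Hypothetical activations shaped [batch_size, h, w, channels] = [2, 3, 3, 4].
x = tf.reshape(tf.range(2 * 3 * 3 * 4), [2, 3, 3, 4])

# Keep only channels 0, 2 and 3; axis=-1 gathers along the last dimension.
channel_idx = tf.constant([0, 2, 3])
subset = tf.gather(x, channel_idx, axis=-1)

print(subset.shape)  # (2, 3, 3, 3)
```

Because the same indices are applied at every batch, row and column position, this picks a fixed subset of channels; gathering different subsets this way could feed different parts of a partially connected layer.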

8 answers
  •  挽巷 (original poster)
    2021-01-04 03:20

    A corrected version of @Andrei's answer would read:

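    import tensorflow as tf
    # Pair row i with column indices_for_dim1[i] (int32, like tf.range), then read matrix[i, indices_for_dim1[i]].
    cat_idx = tf.stack([tf.range(0, tf.shape(matrix)[0]), indices_for_dim1], axis=1)
    result = tf.gather_nd(matrix, cat_idx)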
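To make the stack-and-gather_nd approach concrete, here is a self-contained sketch assuming TensorFlow 2.x eager execution, with a small made-up `matrix` and `indices_for_dim1`. The row count comes from `tf.shape(matrix)[0]`, so `matrix` must be the tensor being indexed, and `indices_for_dim1` must be int32 to match `tf.range`. Row `i` of the result is `matrix[i, indices_for_dim1[i]]`. For comparison it also uses `tf.gather(..., batch_dims=1)`, an alternative that is not part of the original answer but, in TF 2.x, performs the same per-row selection without building the index pairs by hand:

```python
import tensorflow as tf

# Hypothetical 3x4 matrix and one column index per row.
matrix = tf.constant([[10, 11, 12, 13],
                      [20, 21, 22, 23],
                      [30, 31, 32, 33]])
indices_for_dim1 = tf.constant([2, 0, 3], dtype=tf.int32)

# Build [[0, 2], [1, 0], [2, 3]]: each pair is (row index, column index).
cat_idx = tf.stack([tf.range(0, tf.shape(matrix)[0]), indices_for_dim1], axis=1)
result = tf.gather_nd(matrix, cat_idx)
print(result.numpy())  # [12 20 33]

# Alternative: treat dim 0 as a batch dimension and gather along dim 1.
result_bd = tf.gather(matrix, indices_for_dim1, batch_dims=1)
print(result_bd.numpy())  # [12 20 33]
```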
    
