TensorFlow - numpy-like tensor indexing

猫巷女王i 2020-11-30 03:24

In numpy, we can do this:

x = np.random.random((10,10))
a = np.random.randint(0,10,5)
b = np.random.randint(0,10,5)
x[         


        
3 answers
  • 2020-11-30 03:40

    You can actually do that now with tf.gather_nd. Let's say you have a matrix m like the following:

    | 1 2 3 4 |
    | 5 6 7 8 |
    

    And you want to build a 4x2 matrix r from elements of m, like this:

    | 3 6 |
    | 2 7 |
    | 5 3 |
    | 1 1 |
    

    Each element of r corresponds to a row and a column of m, so you can write matrices rows and cols holding those indices (zero-based, since we are programming, not doing math!):

           | 0 1 |         | 2 1 |
    rows = | 0 1 |  cols = | 1 2 |
           | 1 0 |         | 0 2 |
           | 0 0 |         | 0 0 |
    

    You can stack rows and cols into a 3-dimensional tensor (shape 4x2x2) whose innermost pairs are [row, col] coordinates into m:

    | | 0 2 | | 1 1 | |
    | | 0 1 | | 1 2 | |
    | | 1 0 | | 0 2 | |
    | | 0 0 | | 0 0 | |
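    A quick NumPy check of this layout (a sketch; NumPy is used only for illustration): stacking rows and cols along a new last axis gives a (4, 2, 2) array of [row, col] pairs, which is the index format tf.gather_nd consumes. NumPy's own integer-array indexing, m[rows, cols], produces r directly from the unstacked index arrays. The names indices and r_from_pairs are illustrative.

```python
import numpy as np

m = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
rows = np.array([[0, 1], [0, 1], [1, 0], [0, 0]])
cols = np.array([[2, 1], [1, 2], [0, 2], [0, 0]])

# Pair each row index with its column index along a new last axis.
indices = np.stack((rows, cols), -1)
print(indices.shape)             # (4, 2, 2)
print(indices[2].tolist())       # [[1, 0], [0, 2]]

# NumPy fancy indexing: r[i, j] == m[rows[i, j], cols[i, j]].
r = m[rows, cols]
print(r.tolist())                # [[3, 6], [2, 7], [5, 3], [1, 1]]

# Looking up the stacked [row, col] pairs one by one gives the same matrix.
r_from_pairs = np.array([[m[i, j] for i, j in row] for row in indices])
print((r_from_pairs == r).all())  # True
```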
    

    This way, you can get from m to r through rows and cols as follows:

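    import numpy as np
    import tensorflow as tf
    
    # Graph-mode (TF1-style) code. On TensorFlow 2.x, placeholders and sessions
    # live under tf.compat.v1 and need eager execution disabled.
    tf1 = tf.compat.v1
    tf1.disable_eager_execution()
    
    m = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
    rows = np.array([[0, 1], [0, 1], [1, 0], [0, 0]])
    cols = np.array([[2, 1], [1, 2], [0, 2], [0, 0]])
    
    x = tf1.placeholder('float32', (None, None))
    idx1 = tf1.placeholder('int32', (None, None))
    idx2 = tf1.placeholder('int32', (None, None))
    # Stack row and column indices along a new last axis -> shape (4, 2, 2);
    # each innermost [row, col] pair selects one element of x.
    result = tf.gather_nd(x, tf.stack((idx1, idx2), -1))
    
    with tf1.Session() as sess:
        r = sess.run(result, feed_dict={
            x: m,
            idx1: rows,
            idx2: cols,
        })
    print(r)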
    

    Output:

    [[ 3.  6.]
     [ 2.  7.]
     [ 5.  3.]
     [ 1.  1.]]
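    In TensorFlow 2.x, where eager execution is the default, the placeholders and session are unnecessary. The sketch below (assuming TF 2.x is installed) passes the NumPy arrays straight to tf.gather_nd. The second half shows the 1-D case, where two equal-length integer vectors (generated here the same way as a and b in the question) select one element each, as NumPy's x[a, b] does; picked is an illustrative name.

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution on by default)

m = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32)
rows = np.array([[0, 1], [0, 1], [1, 0], [0, 0]])
cols = np.array([[2, 1], [1, 2], [0, 2], [0, 0]])

# Each innermost [row, col] pair of the index tensor selects one element of m.
r = tf.gather_nd(m, tf.stack((rows, cols), axis=-1))
print(r.numpy())

# 1-D case: one element per (a[k], b[k]) pair, like NumPy's x[a, b].
x = np.random.random((10, 10))
a = np.random.randint(0, 10, 5)
b = np.random.randint(0, 10, 5)
picked = tf.gather_nd(x, tf.stack((a, b), axis=-1))
print(np.allclose(picked.numpy(), x[a, b]))  # True
```

    The index tensor always has one trailing axis of length 2 here because m is 2-D; for an N-D tensor each innermost entry would be an N-element coordinate.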
    
  • 2020-11-30 03:45

    LDGN's comment is correct: this is not possible at the moment, and it is a requested feature. If you follow issue #206 on GitHub, you'll be notified if/when it becomes available. Many people would like this feature.

  • 2020-11-30 03:52

    As of TensorFlow 0.11, basic indexing has been implemented. More advanced indexing (such as boolean indexing) is still missing, but is apparently planned for future versions.

    Progress on advanced indexing can be tracked at https://github.com/tensorflow/tensorflow/issues/4638
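    As a rough illustration of the distinction in TensorFlow 2.x (an assumption; the answer refers to 0.11): basic indexing with integers, slices and steps works directly on tensors, while NumPy-style boolean selection is usually written with tf.boolean_mask rather than bracket syntax. The names t and big are illustrative.

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution

t = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]])

# Basic indexing: integers, slices and steps, as in NumPy.
print(t[1].numpy().tolist())        # [5, 6, 7, 8]
print(t[:, ::2].numpy().tolist())   # [[1, 3], [5, 7]]

# Boolean selection (NumPy: t[t > 4]) via tf.boolean_mask; the result is 1-D.
big = tf.boolean_mask(t, t > 4)
print(big.numpy().tolist())         # [5, 6, 7, 8]
```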
