How does TensorFlow indexing work?

[愿得一人] 2020-12-03 03:40

I'm having trouble understanding a basic concept in TensorFlow. How does indexing work for tensor read/write operations? In order to make this specific, how can the follo

1 Answer
  • 2020-12-03 04:07

    There's GitHub issue #206 about supporting this nicely; in the meantime you have to resort to verbose workarounds.

    The first example can be done with tf.select (replaced by tf.where in TensorFlow 1.0), which combines two same-shaped tensors by taking each element from one or the other according to a boolean mask:

    import tensorflow as tf  # TensorFlow 1.x graph-mode API

    tf.reset_default_graph()
    row_indices = tf.constant([1, 1, 2])
    col_indices = tf.constant([0, 2, 3])
    x = tf.zeros((3, 4))
    sess = tf.InteractiveSession()
    
    # get list of ((row1, col1), (row2, col2), ..)
    coords = tf.transpose(tf.stack([row_indices, col_indices]))  # tf.pack before TF 1.0
    
    # get tensor with 1's at positions (row1, col1),...
    binary_mask = tf.sparse_to_dense(coords, x.get_shape(), 1)
    
    # convert 1/0 to True/False
    binary_mask = tf.cast(binary_mask, tf.bool)
    
    twos = 2*tf.ones(x.get_shape())
    
    # make new x out of old values or 2, depending on mask
    x = tf.where(binary_mask, twos, x)  # tf.select before TF 1.0
    
    print(x.eval())
    

    gives

    [[ 0.  0.  0.  0.]
     [ 2.  0.  2.  0.]
     [ 0.  0.  0.  2.]]
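    To make the mask-and-select idea concrete without TensorFlow, here is a rough pure-Python sketch. It assumes a 2-D tensor is modelled as a list of row lists; `dense_mask` and `select` are hypothetical helpers that mimic what sparse_to_dense + cast and select compute, not TensorFlow calls.

```python
# Sketch only: nested Python lists stand in for 2-D tensors (assumption).

def dense_mask(coords, shape):
    """Hypothetical helper mimicking sparse_to_dense(coords, shape, 1) cast to bool."""
    rows, cols = shape
    mask = [[False] * cols for _ in range(rows)]
    for r, c in coords:
        mask[r][c] = True
    return mask


def select(mask, a, b):
    """Hypothetical helper: elementwise a where mask is True, else b."""
    return [[av if m else bv for m, av, bv in zip(mrow, arow, brow)]
            for mrow, arow, brow in zip(mask, a, b)]


row_indices = [1, 1, 2]
col_indices = [0, 2, 3]
x = [[0.0] * 4 for _ in range(3)]

# Same pairing as the transpose of the stacked [row_indices, col_indices]
coords = list(zip(row_indices, col_indices))  # [(1, 0), (1, 2), (2, 3)]
twos = [[2.0] * 4 for _ in range(3)]
x = select(dense_mask(coords, (3, 4)), twos, x)

for row in x:
    print(row)
```

    The mask is built once from the coordinate pairs; the select step never writes in place, it builds a new tensor, which is why the TensorFlow version reassigns x.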
    

    The second one could be done with scatter_update, except that scatter_update only supports linear indices (it indexes along the first dimension) and only works on variables. So you could create a temporary variable and use reshaping, like this (to avoid variables you could use dynamic_stitch; see the end):

    # get linear indices in row-major order: row * num_cols + col
    linear_indices = row_indices*x.get_shape().as_list()[1]+col_indices
    
    # turn 'x' into 1d variable since "scatter_update" supports linear indexing only
    x_flat = tf.Variable(tf.reshape(x, [-1]))
    
    # no automatic promotion, so make updates float32 to match x
    updates = tf.constant([5, 4, 3], dtype=tf.float32)
    
    sess.run(tf.global_variables_initializer())  # tf.initialize_all_variables in older TF
    sess.run(tf.scatter_update(x_flat, linear_indices, updates))
    
    # convert back into original shape
    x = tf.reshape(x_flat, x.get_shape())
    
    print(x.eval())
    

    gives

    [[ 0.  0.  0.  0.]
     [ 5.  0.  4.  0.]
     [ 0.  0.  0.  3.]]
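    The only arithmetic in this trick is the row-major conversion from (row, col) to a flat index, row * num_cols + col. A pure-Python sketch of the flatten, scatter, reshape-back sequence, assuming lists stand in for tensors; `flatten`, `reshape` and `scatter_flat` are hypothetical helpers, not TensorFlow calls:

```python
# Sketch only: lists stand in for tensors; row-major (C) order as in tf.reshape.

def flatten(m):
    return [v for row in m for v in row]


def reshape(flat, rows, cols):
    return [flat[r * cols:(r + 1) * cols] for r in range(rows)]


def scatter_flat(flat, indices, updates):
    """Hypothetical helper mimicking scatter_update on a 1-D variable."""
    out = list(flat)
    for i, u in zip(indices, updates):
        out[i] = u
    return out


rows, cols = 3, 4
row_indices = [1, 1, 2]
col_indices = [0, 2, 3]
x = [[0.0] * cols for _ in range(rows)]

# Element (r, c) of a rows x cols matrix sits at r * cols + c after flattening.
linear_indices = [r * cols + c for r, c in zip(row_indices, col_indices)]
print(linear_indices)  # [4, 6, 11]

x = reshape(scatter_flat(flatten(x), linear_indices, [5.0, 4.0, 3.0]), rows, cols)
for row in x:
    print(row)
```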
    

    Finally, the third example is already supported by gather_nd; you write

    print(tf.gather_nd(x, coords).eval())
    

    to get

    [ 5.  4.  3.]
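    gather_nd treats each row of coords as an index into x: a full (row, col) pair picks one element, while a shorter index picks a whole slice. A pure-Python sketch of those semantics, assuming nested lists stand in for tensors (`gather_nd_lists` is a hypothetical helper):

```python
# Sketch only: nested lists stand in for tensors.

def gather_nd_lists(params, indices):
    """Hypothetical helper: follow each index tuple as far as it goes."""
    out = []
    for idx in indices:
        value = params
        for i in idx:
            value = value[i]
        out.append(value)
    return out


x = [[0.0, 0.0, 0.0, 0.0],
     [5.0, 0.0, 4.0, 0.0],
     [0.0, 0.0, 0.0, 3.0]]
coords = [(1, 0), (1, 2), (2, 3)]

print(gather_nd_lists(x, coords))  # [5.0, 4.0, 3.0]
print(gather_nd_lists(x, [(1,)]))  # [[5.0, 0.0, 4.0, 0.0]]
```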
    

    Edit, May 6

    The update x[cols,rows]=newvals can also be done without Variables (which occupy memory between session run calls), either by using select with a sparse_to_dense that takes a vector of sparse values, or by relying on dynamic_stitch:

    import tensorflow as tf  # TensorFlow 1.x graph-mode API

    sess = tf.InteractiveSession()
    x = tf.zeros((3, 4))
    row_indices = tf.constant([1, 1, 2])
    col_indices = tf.constant([0, 2, 3])
    
    # no automatic promotion, so specify float type
    replacement_vals = tf.constant([5, 4, 3], dtype=tf.float32)
    
    # convert to linear indexing in row-major form
    linear_indices = row_indices*x.get_shape().as_list()[1]+col_indices
    x_flat = tf.reshape(x, [-1])
    
    # dynamic_stitch merges the lists so that output[indices[m][i]] = data[m][i];
    # where an index appears in more than one list, the later list wins,
    # so the replacement values override the unchanged copy of x_flat
    unchanged_indices = tf.range(tf.size(x_flat))
    changed_indices = linear_indices
    x_flat = tf.dynamic_stitch([unchanged_indices, changed_indices],
                               [x_flat, replacement_vals])
    x = tf.reshape(x_flat, x.get_shape())
    print(x.eval())
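    dynamic_stitch's merge rule can be modelled in a few lines of pure Python: the output length is the largest index plus one, and each (index, value) pair is written in order, so values from a later list overwrite earlier ones at the same position. A sketch under that reading of the documented semantics, assuming 1-D lists stand in for tensors (`stitch_lists` is a hypothetical helper):

```python
# Sketch only: 1-D Python lists stand in for tensors.

def stitch_lists(indices, data):
    """Hypothetical helper: merged[indices[m][i]] = data[m][i]; later pairs win."""
    size = 1 + max(i for idx in indices for i in idx)
    merged = [None] * size
    for idx, vals in zip(indices, data):
        for i, v in zip(idx, vals):
            merged[i] = v
    return merged


rows, cols = 3, 4
x_flat = [0.0] * (rows * cols)
linear_indices = [4, 6, 11]  # (1, 0), (1, 2), (2, 3) in row-major order
replacement_vals = [5.0, 4.0, 3.0]

unchanged_indices = list(range(len(x_flat)))
x_flat = stitch_lists([unchanged_indices, linear_indices],
                      [x_flat, replacement_vals])
x = [x_flat[r * cols:(r + 1) * cols] for r in range(rows)]

for row in x:
    print(row)
```

    Listing every position in unchanged_indices is what keeps the output the same size as x; the replacements only need to come second.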
    