TensorFlow getting elements of every row for specific columns

再見小時候 2021-02-08 02:57

If A is a TensorFlow variable like this:

A = tf.Variable([[1, 2], [3, 4]])

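and index is another variable holding one column index per row, how can I get, from every row, the element at its specified column?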

5 answers
  •  囚心锁ツ
    2021-02-08 03:40

    You can pair each column index with its row index and then use tf.gather_nd:

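    import tensorflow as tf
    
    A = tf.constant([[1, 2], [3, 4]])
    indices = tf.constant([1, 0])  # column to pick from each row
    
    # prepare row indices: [0, 1]
    row_indices = tf.range(tf.shape(indices)[0])
    
    # zip row indices with column indices: [[0, 1], [1, 0]]
    full_indices = tf.stack([row_indices, indices], axis=1)
    
    # retrieve values by indices
    S = tf.gather_nd(A, full_indices)
    
    # TensorFlow 2.x (eager execution): no session needed
    print(S.numpy())  # [2 3]
    
    # TensorFlow 1.x (graph mode) instead:
    # session = tf.InteractiveSession()
    # print(session.run(S))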
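Alternatively, in TensorFlow 2.x, tf.gather with batch_dims=1 can express the same per-row selection without building the (row, column) pairs by hand: with axis=1, row i of A is paired with indices[i], so S[i] = A[i, indices[i]]. This is a minimal sketch assuming TF 2.x with eager execution; it uses the question's tf.Variable form of A and the example indices [1, 0] from the answer above, and cross-checks the result against the gather_nd approach.

```python
import tensorflow as tf

# Works with a tf.Variable as well as a tf.constant.
A = tf.Variable([[1, 2], [3, 4]])
indices = tf.constant([1, 0])  # column to pick from each row

# batch_dims=1 pairs row i of A with indices[i]; axis=1 gathers along columns,
# so S[i] = A[i, indices[i]].
S = tf.gather(A, indices, axis=1, batch_dims=1)
print(S.numpy())  # [2 3]

# Cross-check against the gather_nd approach: build [row, column] pairs explicitly.
full_indices = tf.stack([tf.range(tf.shape(indices)[0]), indices], axis=1)
S_nd = tf.gather_nd(A, full_indices)
print(bool(tf.reduce_all(tf.equal(S, S_nd))))  # True
```

Both calls return a 1-D tensor with one element per row; gather_nd is more general (arbitrary index tuples), while batch_dims keeps the "one column per row" intent explicit.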
    
