TensorFlow: getting elements of every row for specific columns

再見小時候 2021-02-08 02:57

If A is a TensorFlow variable like so:

A = tf.Variable([[1, 2], [3, 4]])

and index is another variable



        
5 answers
  •  长发绾君心
    2021-02-08 03:44

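    You can use tf.one_hot to turn index into a one-hot boolean array (one row per row of A, True only at the chosen column) and then pass it to tf.boolean_mask as a mask, which selects exactly one element from each row of A.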

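    import tensorflow as tf

    A = tf.Variable([[1, 2], [3, 4]])
    index = tf.Variable([0, 1])  # column to pick from each row of A

    # Row i of the mask is True only at column index[i]: [[True, False], [False, True]]
    one_hot_mask = tf.one_hot(index, A.shape[1], on_value=True, off_value=False, dtype=tf.bool)
    # Keep the masked elements in row-major order, one per row: [1, 4]
    output = tf.boolean_mask(A, one_hot_mask)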
    
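To see what the mask does element by element, here is a minimal pure-Python sketch of the same one-hot/boolean-mask selection on nested lists, so it runs without TensorFlow. The helpers one_hot_mask and boolean_mask_select are hypothetical names introduced only for illustration; they mimic tf.one_hot (with boolean on/off values) and tf.boolean_mask for this 2-D case and are not part of any library.

```python
def one_hot_mask(index, depth):
    """Hypothetical helper: row i is True only at column index[i].

    Mimics tf.one_hot(index, depth, on_value=True, off_value=False,
    dtype=tf.bool) for a 1-D index.
    """
    return [[col == i for col in range(depth)] for i in index]


def boolean_mask_select(matrix, mask):
    """Hypothetical helper: collect elements where mask is True, row by row.

    Mimics tf.boolean_mask for a 2-D input with a 2-D mask of the same shape.
    """
    return [value
            for row, mask_row in zip(matrix, mask)
            for value, keep in zip(row, mask_row)
            if keep]


A = [[1, 2], [3, 4]]
index = [0, 1]  # column to pick from each row

mask = one_hot_mask(index, len(A[0]))
print(mask)                          # [[True, False], [False, True]]
print(boolean_mask_select(A, mask))  # [1, 4]
```

Because each mask row has exactly one True, the flattened result always has one element per row of A, in row order.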
