How to index a list with a TensorFlow tensor?

日久生厌 2021-01-08 00:09

Assume a list of non-concatenable objects that needs to be accessed via a lookup table. The list index would then be a tensor object, but indexing a Python list with a tensor is not possible.

2 Answers
  • 2021-01-08 00:29

    TensorFlow actually supports hash tables. See the documentation for more details.

    Here is what you could do:

    import tensorflow as tf
    # TF 2.x API; TF 1.x used tf.contrib.lookup.HashTable
    table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(tf_look_up, list), -1)
    

    Then retrieve the desired element with:

    target = table.lookup(index)
    

    Note that -1 is the default value returned when the key is not found. Depending on the dtypes of your tensors, you may have to pass key_dtype and value_dtype to the KeyValueTensorInitializer constructor.
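    A minimal self-contained sketch of the hash-table approach, assuming TensorFlow 2.x running eagerly. The keys and values below are hypothetical stand-ins for tf_look_up and list. The values must be convertible to a single tensor of one dtype, which is why integers are used here.

```python
import tensorflow as tf

# Hypothetical example data: integer keys (the lookup table) mapped to values.
keys = tf.constant([10, 20, 30], dtype=tf.int64)
values = tf.constant([100, 200, 300], dtype=tf.int64)

# StaticHashTable is the TF 2.x name; -1 is returned for keys that are not found.
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys, values),
    default_value=-1,
)

index = tf.constant(20, dtype=tf.int64)
target = table.lookup(index)  # scalar int64 tensor holding 200

# lookup also accepts a batch of keys; missing keys map to the default value.
batch = table.lookup(tf.constant([30, 99], dtype=tf.int64))
```

    The key dtype of the query must match the table's key dtype (int64 here). Otherwise lookup raises an error instead of silently converting.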

  • 2021-01-08 00:44

    tf.gather is designed for this purpose.

    Simply run tf.gather(list, tf_look_up[index]) to get the element you want.
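    A minimal sketch of the tf.gather approach, assuming TensorFlow 2.x running eagerly. The names values and tf_look_up are hypothetical example data. tf.gather converts its first argument to a tensor, so this assumes the list elements share one shape and dtype.

```python
import tensorflow as tf

# Hypothetical data: the elements to index into (all rows share one shape/dtype)...
values = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
# ...and a lookup table mapping a logical index to a position in `values`.
tf_look_up = tf.constant([2, 0, 1])

index = 1
# tf_look_up[1] is 0, so this picks row 0 of `values`.
row = tf.gather(values, tf_look_up[index])

# Passing the whole lookup tensor gathers every mapped row at once.
reordered = tf.gather(values, tf_look_up)
```

    Here row evaluates to [1.0, 2.0]. Because the index is itself a tensor, the same call also works inside a tf.function, where a plain Python list cannot be indexed by a tensor.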
