TensorFlow dictionary lookup with a string tensor

名媛妹妹 2020-12-29 20:35

Is there any way to perform a dictionary lookup based on a string tensor in TensorFlow?

In plain Python, I'd do something like

value = dictionary[ke         


        
4 Answers
  • 2020-12-29 21:06

    TensorFlow is a dataflow framework whose values are tensors; it has no general-purpose map or dictionary type like Python's. However, depending on what you need, when you're using the Python wrapper you can keep a dictionary in the driver process, which runs in Python, and use it to interact with the TensorFlow graph execution. For example, you could execute one step of a TensorFlow graph within a session, return a string value to the Python driver, use it as a key into a dictionary in the driver, and use the retrieved value to decide which computation to request from the session next. This is probably not a good solution if the speed of these dictionary lookups is performance critical, because every lookup is a round trip between TensorFlow and Python.
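
    The round trip described above can be sketched as follows. This is only an illustration under stated assumptions: it uses TF 2.x eager execution instead of an explicit session, and the dictionary contents and the string computation are made up for the example.

```python
import tensorflow as tf

# Hypothetical dictionary kept in the Python driver, outside TensorFlow.
dictionary = {"cat": 0.5, "dog": 2.0}

# Step 1: TensorFlow computes a string value (here, a trivial join).
key_tensor = tf.strings.join(["c", "at"])

# Step 2: bring the value back into Python and use it as a dict key.
# A scalar string tensor comes back as bytes, so decode it first.
key = key_tensor.numpy().decode("utf-8")
scale = dictionary[key]

# Step 3: use the looked-up value in the next TensorFlow computation.
result = tf.constant([1.0, 2.0]) * scale
print(result.numpy())
```

    Each lookup leaves TensorFlow and returns to Python, which is why this pattern suits occasional control decisions better than per-element lookups.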

  • 2020-12-29 21:15

    You might find tensorflow.contrib.lookup helpful: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lookup/lookup_ops.py

    https://www.tensorflow.org/api_docs/python/tf/contrib/lookup/HashTable

    In particular, you can do:

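    # TF 1.x only: tf.contrib was removed in TF 2.x.
    # keys, values: 1-D tensors of equal length; input_tensor: the keys to look up.
    table = tf.contrib.lookup.HashTable(
      tf.contrib.lookup.KeyValueTensorInitializer(keys, values), -1  # -1 is returned for missing keys
    )
    out = table.lookup(input_tensor)
    with tf.Session():
        table.init.run()  # initialize the table before the first lookup
        print(out.eval())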
    
  • 2020-12-29 21:19

    If you want to run this in TF 2.x, where eager execution is enabled by default, here is a quick code snippet.

    import tensorflow as tf
    
    # build a lookup table
    table = tf.lookup.StaticHashTable(
        initializer=tf.lookup.KeyValueTensorInitializer(
            keys=tf.constant([0, 1, 2, 3]),
            values=tf.constant([10, 11, 12, 13]),
        ),
        default_value=tf.constant(-1),
        name="class_weight"
    )
    
    # now let us do a lookup
    input_tensor = tf.constant([0, 0, 1, 1, 2, 2, 3, 3])
    out = table.lookup(input_tensor)
    print(out)
    

    Output:

    tf.Tensor([10 10 11 11 12 12 13 13], shape=(8,), dtype=int32)
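
    Since the question asks about string tensors, here is a minimal sketch of the same table with string keys. The dictionary contents and the query are made up for illustration, and int64 is chosen for the values on the assumption that any string-to-integer mapping will do.

```python
import tensorflow as tf

# Illustrative dictionary; the keys and values are made up for this sketch.
dictionary = {"a": 1, "b": 2, "c": 3}

# Build a string -> int64 lookup table from the Python dict.
table = tf.lookup.StaticHashTable(
    initializer=tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(list(dictionary.keys())),
        values=tf.constant(list(dictionary.values()), dtype=tf.int64),
    ),
    # Returned for keys that are not in the table; must match the value dtype.
    default_value=tf.constant(-1, dtype=tf.int64),
)

# Look up a string tensor; "missing" is not a key, so it maps to the default.
query = tf.constant(["a", "c", "missing"])
out = table.lookup(query)
print(out.numpy())
```

    The table is immutable once initialized, so this fits dictionaries that are known up front.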
    
  • 2020-12-29 21:21

    tf.gather can help, but it only picks values out of a list by position. You can convert the dictionary into a key list and a value list, translate the queried keys into positions, and then apply tf.gather. Example:

    import tensorflow as tf  # TF 1.x API (tf.placeholder, tf.InteractiveSession)

    # Your dict
    dict_ = {'a': 1.12, 'b': 5.86, 'c': 68.}
    # concrete query
    query_list = ['a', 'c']
    
    # unpack key and value lists
    key, value = list(zip(*dict_.items()))
    # map each queried key to its position in the key list -> [0, 2]
    # (keeps the order of the query and any repeated keys)
    query_idx = [key.index(s) for s in query_list]
    
    # query as tensor
    query = tf.placeholder(tf.int32, shape=[None])
    # convert value list to tensor
    vl_tf = tf.constant(value)
    # get value
    my_vl = tf.gather(vl_tf, query)
    
    # session run
    sess = tf.InteractiveSession()
    print(sess.run(my_vl, feed_dict={query: query_idx}))  # values for 'a' and 'c'
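
    For TF 2.x, where placeholders and sessions are not needed, the same idea might look like the sketch below. It reuses the dictionary from the example above; the helper name index_of is introduced only for illustration.

```python
import tensorflow as tf

# Same dictionary and query as in the example above.
dict_ = {"a": 1.12, "b": 5.86, "c": 68.0}
query_list = ["a", "c"]

# Split the dict into a key list (Python side) and a value tensor (TF side).
keys = list(dict_.keys())
values = tf.constant(list(dict_.values()))

# index_of is a hypothetical helper: key -> position in the value tensor.
index_of = {k: i for i, k in enumerate(keys)}
indices = tf.constant([index_of[k] for k in query_list], dtype=tf.int32)

# Gather the values at those positions; runs eagerly, no session required.
my_vl = tf.gather(values, indices)
print(my_vl.numpy())
```

    The key-to-position mapping still happens in Python, so this only works when the queried keys are known outside the graph; for keys that arrive as string tensors, a lookup table is the better fit.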
    