Is there any way to perform a dictionary lookup based on a String tensor in TensorFlow?
In plain Python, I'd do something like
value = dictionary[ke
TensorFlow is a data flow language with no support for data structures other than tensors; there is no map or dictionary type. However, depending on what you need, when you're using the Python wrapper you can maintain a dictionary in the driver process, which runs in Python, and use it to interact with the TensorFlow graph execution. For example, you could execute one step of a TensorFlow graph within a session, return a string value to the Python driver, use it as a key into the dictionary, and use the retrieved value to decide which computation to request from the session next. This is probably not a good solution if the lookups are performance-critical, because each one requires a round trip between the TensorFlow runtime and the Python process.
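A minimal sketch of that driver-side pattern, assuming TensorFlow 2.x eager mode rather than the TF 1.x sessions described above; here `.numpy()` plays the role of returning a value from `session.run()`. The dictionary contents and the `tf.strings.lower` step are hypothetical stand-ins for whatever your graph actually computes.

```python
import tensorflow as tf

# Plain Python dictionary kept in the driver process (hypothetical example data).
weights = {"cat": 0.5, "dog": 2.0}

# One step of TensorFlow computation that yields a string tensor.
key_tensor = tf.strings.lower(tf.constant("DOG"))

# Bring the result back into Python: .numpy() returns bytes for string tensors.
key = key_tensor.numpy().decode("utf-8")

# Ordinary dictionary lookup in the driver, outside TensorFlow.
weight = weights[key]

# Use the retrieved value to drive the next TensorFlow computation.
result = tf.constant([1.0, 2.0, 3.0]) * weight
print(result.numpy())
```

Every lookup leaves the TensorFlow runtime, which is exactly the round-trip cost mentioned above.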
You might find tensorflow.contrib.lookup (TensorFlow 1.x) helpful:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lookup/lookup_ops.py
https://www.tensorflow.org/api_docs/python/tf/contrib/lookup/HashTable
In particular, you can do:
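import tensorflow as tf  # TensorFlow 1.x; tf.contrib does not exist in 2.x
keys = tf.constant(["a", "b", "c"])
values = tf.constant([1, 2, 3], dtype=tf.int64)
input_tensor = tf.constant(["a", "c", "z"])
table = tf.contrib.lookup.HashTable(
    tf.contrib.lookup.KeyValueTensorInitializer(keys, values), -1  # -1 = default for missing keys
)
out = table.lookup(input_tensor)
with tf.Session():  # entering the session makes it the default for .run()/.eval()
    table.init.run()
    print(out.eval())  # [ 1  3 -1]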
If you want to run this in TF 2.x, where eager execution is enabled by default, here is a quick snippet using tf.lookup.StaticHashTable:
import tensorflow as tf
# build a lookup table
table = tf.lookup.StaticHashTable(
    initializer=tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant([0, 1, 2, 3]),
        values=tf.constant([10, 11, 12, 13]),
    ),
    default_value=tf.constant(-1),
    name="class_weight"
)
# now let us do a lookup
input_tensor = tf.constant([0, 0, 1, 1, 2, 2, 3, 3])
out = table.lookup(input_tensor)
print(out)
Output:
tf.Tensor([10 10 11 11 12 12 13 13], shape=(8,), dtype=int32)
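Since the question asks about String tensors, here is a minimal sketch of the same StaticHashTable approach with string keys, assuming TF 2.x. The dictionary contents are hypothetical example data; missing keys fall back to the default value.

```python
import tensorflow as tf

# Hypothetical string-keyed dictionary, mirroring dictionary[key] in plain Python.
dictionary = {"a": 1.12, "b": 5.86, "c": 68.0}

table = tf.lookup.StaticHashTable(
    initializer=tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(list(dictionary.keys())),  # dtype tf.string
        values=tf.constant(list(dictionary.values()), dtype=tf.float32),
    ),
    default_value=-1.0,  # returned for keys that are not in the table
)

query = tf.constant(["a", "c", "z"])
out = table.lookup(query)
print(out)
```

Because the lookup is a TensorFlow op, it also works inside a `tf.function` without a round trip to Python.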
tf.gather can help, but it only looks up values by integer position, like indexing a list. You can convert the dictionary into key and value lists, map each queried key to its index, and then apply tf.gather. Example:
import tensorflow as tf  # TensorFlow 1.x (tf.placeholder and sessions)
# Your dict
dict_ = {'a': 1.12, 'b': 5.86, 'c': 68.}
# concrete query
query_list = ['a', 'c']
# unpack key and value lists
key, value = list(zip(*dict_.items()))
# map each queried key to its position in the key list -> [0, 2]
# (preserves query order; raises ValueError for keys not in the dict)
query_idx = [key.index(s) for s in query_list]
# query as tensor
query = tf.placeholder(tf.int32, shape=[None])
# convert value list to tensor
vl_tf = tf.constant(value)
# get value
my_vl = tf.gather(vl_tf, query)
# session run
sess = tf.InteractiveSession()
print(sess.run(my_vl, feed_dict={query: query_idx}))
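The answer above maps keys to indices in Python. A minimal sketch of doing that mapping inside TensorFlow as well, assuming TF 2.x: a broadcast string comparison finds each query's position in the key tensor, and tf.gather fetches the values. This swaps in a brute-force comparison, so it costs O(queries × keys), and it assumes every queried key exists (a missing key would silently map to index 0).

```python
import tensorflow as tf

dict_ = {'a': 1.12, 'b': 5.86, 'c': 68.0}
keys = tf.constant(list(dict_.keys()))
values = tf.constant(list(dict_.values()), dtype=tf.float32)

query = tf.constant(['c', 'a'])

# Compare every query string with every key: shape [num_queries, num_keys].
matches = tf.equal(tf.expand_dims(query, 1), tf.expand_dims(keys, 0))

# Position of the (assumed unique) matching key for each query.
idx = tf.argmax(tf.cast(matches, tf.int32), axis=1)

# Fetch the values by position, as in the tf.gather answer.
result = tf.gather(values, idx)
print(result.numpy())
```

For large dictionaries, the hash-table approach in the other answers scales better.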