What does tf.nn.embedding_lookup function do?

深忆病人 2020-12-02 04:18
tf.nn.embedding_lookup(params, ids, partition_strategy='mod', name=None)

I cannot understand the purpose of this function. Is it like a lookup table?

8 Answers
  • 2020-12-02 04:48

    When the params tensor has more than two dimensions, the ids only index its first (top) dimension. Maybe it's obvious to most people, but I had to run the following code to understand that:

    import tensorflow as tf

    # params is a 3-D tensor of shape (3, 4, 2); the ids index only its first dimension
    embeddings = tf.constant([[[1, 1], [2, 2], [3, 3], [4, 4]],
                              [[11, 11], [12, 12], [13, 13], [14, 14]],
                              [[21, 21], [22, 22], [23, 23], [24, 24]]])
    ids = tf.constant([0, 2, 1])
    embed = tf.nn.embedding_lookup(embeddings, ids, partition_strategy='div')

    with tf.Session() as session:
        result = session.run(embed)
        print(result)
    

    I also tried the 'div' strategy; with a single (unpartitioned) tensor it makes no difference.

    Here is the output:

    [[[ 1  1]
      [ 2  2]
      [ 3  3]
      [ 4  4]]
    
     [[21 21]
      [22 22]
      [23 23]
      [24 24]]
    
     [[11 11]
      [12 12]
      [13 13]
      [14 14]]]
    
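    The partition_strategy only starts to matter when params is a list of several tensors (a sharded embedding table) rather than a single tensor. Here is a minimal sketch of the difference, assuming a made-up 4-row table split across two partitions:

    import tensorflow as tf

    # Hypothetical 4-row, 2-column embedding table split into two partitions
    p0 = tf.constant([[1., 1.], [2., 2.]])
    p1 = tf.constant([[3., 3.], [4., 4.]])
    ids = tf.constant([0, 1, 2, 3])

    # 'mod': global id i lives in partition i % 2, at row i // 2
    mod_lookup = tf.nn.embedding_lookup([p0, p1], ids, partition_strategy='mod')
    # 'div': ids are split contiguously, so partition 0 holds ids 0-1 and partition 1 holds ids 2-3
    div_lookup = tf.nn.embedding_lookup([p0, p1], ids, partition_strategy='div')

    with tf.Session() as session:
        print(session.run(mod_lookup))  # [[1. 1.] [3. 3.] [2. 2.] [4. 4.]]
        print(session.run(div_lookup))  # [[1. 1.] [2. 2.] [3. 3.] [4. 4.]]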
  • 2020-12-02 04:49

    Here's an image depicting the process of embedding lookup.

    Image: Embedding lookup process

    Concisely, it gets the corresponding rows of an embedding layer, specified by a list of IDs, and provides them as a tensor. This is achieved through the following process (a runnable sketch follows the list).

    1. Define a placeholder for the IDs: lookup_ids = tf.placeholder(tf.int32, shape=[None])
    2. Define an embedding layer: embeddings = tf.Variable(tf.random_uniform([100, 10]))
    3. Define the TensorFlow operation: embed_lookup = tf.nn.embedding_lookup(embeddings, lookup_ids)
    4. Get the results by running: lookup = session.run(embed_lookup, feed_dict={lookup_ids: [95, 4, 14]})
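
    Putting those steps together, a minimal runnable sketch (TF1-style graph and session; the 100x10 table size and the looked-up ids are just made-up values):

    import tensorflow as tf

    # 1. Placeholder for the integer ids to look up
    lookup_ids = tf.placeholder(tf.int32, shape=[None])

    # 2. Embedding layer: a 100 x 10 table of vectors
    embeddings = tf.Variable(tf.random_uniform([100, 10], -1.0, 1.0))

    # 3. The lookup op selects the rows of `embeddings` given by `lookup_ids`
    embed_lookup = tf.nn.embedding_lookup(embeddings, lookup_ids)

    # 4. Run the lookup
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        lookup = session.run(embed_lookup, feed_dict={lookup_ids: [95, 4, 14]})
        print(lookup.shape)  # (3, 10): one 10-dimensional vector per requested id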