tensorflow中embedding_lookup, tf.gather以及tf.nn.embedding_lookup_sparse的理解

霸气de小男生 提交于 2019-12-07 17:07:32

s###1. tf.nn.embedding_lookup()
函数签名如下:

embedding_lookup(
    params,
    ids,
    partition_strategy='mod',
    name=None,
    validate_indices=True,
    max_norm=None
)

参数说明:params参数是一些tensor组成的列表或者单个的tensor。ids一个整型的tensor,每个元素将代表要在params中取的每个元素的第0维的逻辑index,这个逻辑index是由partition_strategy来指定的。 paratition_strategy用来设定ids的切分方式。目前有两种方式分别是div和mod。其中mod的切分方式是如果对[1,2,3,4,5,6,7]进行切分则结果为[1,4,7],[2,5],[3,6]。如果是div的切分方式则是[1,2,3]、[4,5]、[6,7]。这两种切分方式在无法均匀切分的情况下都是将前(max_id+1)%len(params)个切分多分配一个元素。即在上面的例子中要求第一个切分是3个元素,其他的都是两个元素。params中有多少个tensor就进行多少切分。每个切分中的数字即是该元素的在ids中对应的index值,即对于mod切分来说,如果ids中的一个值为1,则对于params的第0个元素的第0维取值为0,ids中一个值为4,则对应params中的第0个元素的第0维取值为1。
一个示例代码如下:

def test_embedding_lookup():
    a = np.arange(8).reshape(2,4)
    b = np.arange(8,12).reshape(1,4)
    c = np.arange(12, 20).reshape(2,4)
    print(a)
    print(b)
    print(c)

    a = tf.Variable(a)
    b = tf.Variable(b)
    c = tf.Variable(c)

    t = tf.nn.embedding_lookup([a,b,c], ids=[0,1,2,3])
    # 此处如果ids=[0,1,2,3]不会报错,因为此时并没有发现b比c要少一行,程序能够正常的执行,但是如果出现参数4了,因为
    # 程序的partition要求在无法进行均匀切分时,前面的(max_id+1)%len(params)个param的切分可以多一个。在此例子中
    # 正确的id应该是params中的第一元素的id为[0,3], 第二元素的id应该为[1,4], 第三个元素的id应该为[2]。所以正确的param
    # 应该是(a,c,b)或者(c,a,b),总之b应该放在最后面
    # 本例的运算结果为:
    '''
    [[ 0  1  2  3]
    [ 8  9 10 11]
    [12 13 14 15]
    [ 4  5  6  7]]
    '''
    #但是本例中的[a,b,c]顺序其实是错误的

    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)
    m = sess.run(t)
    print(m)

函数的功能:该函数的作用就是将params中的tensor按照切分方式的逻辑顺序在第0维合并成一个大的逻辑tensor,然后根据这个ids中的数字在这个大的逻辑tensor中取出对于的子tensor。ids中的每个元素对应大的逻辑tensor的第0维取该值后获得到的子tensor。所以最后结果的shape为shape(ids) + shape(params)[1:].

2. tf.gather()

函数签名如下:

gather(
    params,
    indices,
    validate_indices=None,
    name=None
)

参数说明:params是一个tensor, indices是个值为int的tensor用来指定要从params取得元素的第0维的index。该函数可看成是tf.nn.embedding_lookup()的特殊形式,所以功能与其类是,即将其看成是embedding_lookup函数的params参数内只有一个tensor时的情形。返回的结果类型如下:

 # Scalar indices
    output[:, ..., :] = params[indices, :, ... :]

    # Vector indices
    output[i, :, ..., :] = params[indices[i], :, ... :]

    # Higher rank indices
    output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :]

3. tf.nn.embedding_lookup_sparse()

该函数的函数签名如下:

embedding_lookup_sparse(
    params,
    sp_ids,
    sp_weights,
    partition_strategy='mod',
    name=None,
    combiner=None,
    max_norm=None
)

参数说明:params参数与embedding_lookup函数一致。根据自己写的测试代码来看,该函数的实现的功能和api上的描述并不一样。

代码如下:

def test_lookup_sparse():
    a = np.arange(8).reshape(2, 4)
    b = np.arange(8, 16).reshape(2, 4)
    c = np.arange(12, 20).reshape(2, 4)

    print(a)
    print(b)
    print(c)

    a = tf.Variable(a, dtype=tf.float32)
    b = tf.Variable(b, dtype=tf.float32)
    c = tf.Variable(c, dtype=tf.float32)

    idx = tf.SparseTensor(indices=[[0,0], [0,2], [1,0], [1, 1]], values=[1,2,2,0], dense_shape=(2,3))
    result = tf.nn.embedding_lookup_sparse((a,c,b), idx, None, combiner="sum")

    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)
    r = sess.run(result)
    print(r)
    '''
    根据程序的测试结果来看,这里的params的结合方式并不是成为一个逻辑大tensor,而是直接变成一个大的tensor,在该tensor的在第0维扩张
    '''
    '''
    a,b,c 为:[[0 1 2 3]
 [4 5 6 7]]

[[ 8  9 10 11]
 [12 13 14 15]]

[[12 13 14 15]
 [16 17 18 19]]

 在实现中好像将它们结合成一个大的tensor了,而不是使用partition,即实现的结果是
 [[[0 1 2 3]
 [4 5 6 7]]
[[ 8  9 10 11]
 [12 13 14 15]]
[[12 13 14 15]
 [16 17 18 19]]
 ]
 最后的结果为:
 [[[ 20.  22.  24.  26.]
  [ 28.  30.  32.  34.]]

 [[  8.  10.  12.  14.]
  [ 16.  18.  20.  22.]]]

    '''

上面的测试代码有问题,代码的实现并没有问题,在使用了[a,c,b]的列表时功能就正确了,而上面的测试代码使用的是(a,c,b)的元组所以有问题。
本博客的代码地址

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!