s###1. tf.nn.embedding_lookup()
函数签名如下:
embedding_lookup(
params,
ids,
partition_strategy='mod',
name=None,
validate_indices=True,
max_norm=None
)
参数说明:params参数是一些tensor组成的列表或者单个的tensor。ids一个整型的tensor,每个元素将代表要在params中取的每个元素的第0维的逻辑index,这个逻辑index是由partition_strategy来指定的。 paratition_strategy用来设定ids的切分方式。目前有两种方式分别是div和mod。其中mod的切分方式是如果对[1,2,3,4,5,6,7]进行切分则结果为[1,4,7],[2,5],[3,6]。如果是div的切分方式则是[1,2,3]、[4,5]、[6,7]。这两种切分方式在无法均匀切分的情况下都是将前(max_id+1)%len(params)个切分多分配一个元素。即在上面的例子中要求第一个切分是3个元素,其他的都是两个元素。params中有多少个tensor就进行多少切分。每个切分中的数字即是该元素的在ids中对应的index值,即对于mod切分来说,如果ids中的一个值为1,则对于params的第0个元素的第0维取值为0,ids中一个值为4,则对应params中的第0个元素的第0维取值为1。
一个示例代码如下:
def test_embedding_lookup():
a = np.arange(8).reshape(2,4)
b = np.arange(8,12).reshape(1,4)
c = np.arange(12, 20).reshape(2,4)
print(a)
print(b)
print(c)
a = tf.Variable(a)
b = tf.Variable(b)
c = tf.Variable(c)
t = tf.nn.embedding_lookup([a,b,c], ids=[0,1,2,3])
# 此处如果ids=[0,1,2,3]不会报错,因为此时并没有发现b比c要少一行,程序能够正常的执行,但是如果出现参数4了,因为
# 程序的partition要求在无法进行均匀切分时,前面的(max_id+1)%len(params)个param的切分可以多一个。在此例子中
# 正确的id应该是params中的第一元素的id为[0,3], 第二元素的id应该为[1,4], 第三个元素的id应该为[2]。所以正确的param
# 应该是(a,c,b)或者(c,a,b),总之b应该放在最后面
# 本例的运算结果为:
'''
[[ 0 1 2 3]
[ 8 9 10 11]
[12 13 14 15]
[ 4 5 6 7]]
'''
#但是本例中的[a,b,c]顺序其实是错误的
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
m = sess.run(t)
print(m)
函数的功能:该函数的作用就是将params中的tensor按照切分方式的逻辑顺序在第0维合并成一个大的逻辑tensor,然后根据这个ids中的数字在这个大的逻辑tensor中取出对于的子tensor。ids中的每个元素对应大的逻辑tensor的第0维取该值后获得到的子tensor。所以最后结果的shape为shape(ids) + shape(params)[1:].
2. tf.gather()
函数签名如下:
gather(
params,
indices,
validate_indices=None,
name=None
)
参数说明:params是一个tensor, indices是个值为int的tensor用来指定要从params取得元素的第0维的index。该函数可看成是tf.nn.embedding_lookup()的特殊形式,所以功能与其类是,即将其看成是embedding_lookup函数的params参数内只有一个tensor时的情形。返回的结果类型如下:
# Scalar indices
output[:, ..., :] = params[indices, :, ... :]
# Vector indices
output[i, :, ..., :] = params[indices[i], :, ... :]
# Higher rank indices
output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :]
3. tf.nn.embedding_lookup_sparse()
该函数的函数签名如下:
embedding_lookup_sparse(
params,
sp_ids,
sp_weights,
partition_strategy='mod',
name=None,
combiner=None,
max_norm=None
)
参数说明:params参数与embedding_lookup函数一致。根据自己写的测试代码来看,该函数的实现的功能和api上的描述并不一样。
代码如下:
def test_lookup_sparse():
a = np.arange(8).reshape(2, 4)
b = np.arange(8, 16).reshape(2, 4)
c = np.arange(12, 20).reshape(2, 4)
print(a)
print(b)
print(c)
a = tf.Variable(a, dtype=tf.float32)
b = tf.Variable(b, dtype=tf.float32)
c = tf.Variable(c, dtype=tf.float32)
idx = tf.SparseTensor(indices=[[0,0], [0,2], [1,0], [1, 1]], values=[1,2,2,0], dense_shape=(2,3))
result = tf.nn.embedding_lookup_sparse((a,c,b), idx, None, combiner="sum")
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
r = sess.run(result)
print(r)
'''
根据程序的测试结果来看,这里的params的结合方式并不是成为一个逻辑大tensor,而是直接变成一个大的tensor,在该tensor的在第0维扩张
'''
'''
a,b,c 为:[[0 1 2 3]
[4 5 6 7]]
[[ 8 9 10 11]
[12 13 14 15]]
[[12 13 14 15]
[16 17 18 19]]
在实现中好像将它们结合成一个大的tensor了,而不是使用partition,即实现的结果是
[[[0 1 2 3]
[4 5 6 7]]
[[ 8 9 10 11]
[12 13 14 15]]
[[12 13 14 15]
[16 17 18 19]]
]
最后的结果为:
[[[ 20. 22. 24. 26.]
[ 28. 30. 32. 34.]]
[[ 8. 10. 12. 14.]
[ 16. 18. 20. 22.]]]
'''
上面的测试代码有问题,代码的实现并没有问题,在使用了[a,c,b]的列表时功能就正确了,而上面的测试代码使用的是(a,c,b)的元组所以有问题。
本博客的代码地址
来源:CSDN
作者:xiholix
链接:https://blog.csdn.net/huhu0769/article/details/71169346