Question
Given indices with shape [batch_size, sequence_len], updates with shape [batch_size, sequence_len, sampled_size], and to_shape = [batch_size, sequence_len, vocab_size], where vocab_size >> sampled_size, I'd like to use tf.scatter_nd to map updates into a huge tensor of shape to_shape, such that to_shape[bs, indices[bs, sz]] = updates[bs, sz]. That is, I'd like to map updates into to_shape row by row. Please note that sequence_len and sampled_size are scalar tensors, while the others are fixed. I tried the following:
new_tensor = tf.scatter_nd(tf.expand_dims(indices, axis=2), updates, to_shape)
But I got an error:
ValueError: The inner 2 dimension of output.shape=[?,?,?] must match the inner 1 dimension of updates.shape=[80,50,?]: Shapes must be equal rank, but are 2 and 1 for .... with input shapes: [80, 50, 1], [80, 50,?], [3]
Could you please tell me how to use scatter_nd properly? Thanks in advance!
Answer 1:
So assuming you have:
- A tensor updates with shape [batch_size, sequence_len, sampled_size].
- A tensor indices with shape [batch_size, sequence_len, sampled_size].
Then you do:
import tensorflow as tf
# Create updates and indices...
# Create additional indices
i1, i2 = tf.meshgrid(tf.range(batch_size),
                     tf.range(sequence_len), indexing="ij")
i1 = tf.tile(i1[:, :, tf.newaxis], [1, 1, sampled_size])
i2 = tf.tile(i2[:, :, tf.newaxis], [1, 1, sampled_size])
# Create final indices
idx = tf.stack([i1, i2, indices], axis=-1)
# Output shape
to_shape = [batch_size, sequence_len, vocab_size]
# Get scattered tensor
output = tf.scatter_nd(idx, updates, to_shape)
tf.scatter_nd takes an indices tensor, an updates tensor and an output shape. updates is the original tensor, and the shape is just the desired output shape, so [batch_size, sequence_len, vocab_size]. Now, indices is more complicated. Since your output has 3 dimensions (rank 3), for each element in updates you need 3 indices to determine where in the output it will be placed. So the shape of the indices parameter should be the same as updates, with an additional dimension of size 3. In this case, we want the first two dimensions to stay the same, but we still have to specify all 3 indices. So we use tf.meshgrid to generate the batch and sequence indices that we need, and we tile them along the third dimension (the first and second index are the same for every element in the last dimension of updates). Finally, we stack these indices with the given mapping indices and we have our full 3-component indices.
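To sanity-check the index construction, here is a minimal NumPy sketch of the same scheme (all sizes and values below are made up for illustration; np.meshgrid with indexing="ij" behaves like tf.meshgrid, and fancy-index assignment stands in for tf.scatter_nd as long as each row's indices are distinct):

```python
import numpy as np

batch_size, sequence_len, sampled_size, vocab_size = 2, 3, 2, 5

rng = np.random.default_rng(0)
updates = rng.standard_normal((batch_size, sequence_len, sampled_size))
# Vocab indices, constructed so each (batch, step) row has distinct entries
base = np.arange(batch_size * sequence_len).reshape(batch_size, sequence_len, 1)
indices = (base + np.arange(sampled_size)) % vocab_size

# Build the two leading coordinates, as in the tf.meshgrid/tf.tile step
i1, i2 = np.meshgrid(np.arange(batch_size), np.arange(sequence_len), indexing="ij")
i1 = np.tile(i1[:, :, np.newaxis], (1, 1, sampled_size))
i2 = np.tile(i2[:, :, np.newaxis], (1, 1, sampled_size))

# One (batch, step, vocab) triple per update element, shape [B, S, K, 3]
idx = np.stack([i1, i2, indices], axis=-1)

# Emulate tf.scatter_nd: zeros everywhere, updates placed at idx
output = np.zeros((batch_size, sequence_len, vocab_size))
output[idx[..., 0], idx[..., 1], idx[..., 2]] = updates

# Gathering back at the scattered positions recovers updates
gathered = np.take_along_axis(output, indices, axis=2)
```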
Answer 2:
I think you might be looking for this.
def permute_batched_tensor(batched_x, batched_perm_ids):
    # Second index: the target row for each element, tiled over the feature dim
    indices = tf.tile(tf.expand_dims(batched_perm_ids, 2),
                      [1, 1, batched_x.shape[2]])
    # First and third indices: batch and feature coordinates
    i1, i2 = tf.meshgrid(tf.range(batched_x.shape[0]),
                         tf.range(batched_x.shape[2]), indexing="ij")
    i1 = tf.tile(i1[:, tf.newaxis, :], [1, batched_x.shape[1], 1])
    i2 = tf.tile(i2[:, tf.newaxis, :], [1, batched_x.shape[1], 1])
    # Full index tensor: temp[b, perm[b, s], d] = batched_x[b, s, d]
    idx = tf.stack([i1, indices, i2], axis=-1)
    temp = tf.scatter_nd(idx, batched_x, batched_x.shape)
    return temp
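If I read the function right, it scatters each row of batched_x to the position given by batched_perm_ids, i.e. out[b, perm[b, s], :] = x[b, s, :]. A small NumPy sketch of that assumed semantics (names and sizes below are illustrative, not from the answer):

```python
import numpy as np

def permute_batched_np(x, perm):
    # Scatter rows: out[b, perm[b, s], :] = x[b, s, :]
    out = np.zeros_like(x)
    out[np.arange(x.shape[0])[:, None], perm, :] = x
    return out

x = np.arange(12.0).reshape(2, 3, 2)       # [batch=2, seq=3, feat=2]
perm = np.array([[2, 0, 1], [1, 2, 0]])    # one permutation per batch row
y = permute_batched_np(x, perm)
# e.g. row 0 of x[0] lands at position 2 of y[0]
```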
Source: https://stackoverflow.com/questions/45162998/proper-usage-of-tf-scatter-nd-in-tensorflow-r1-2