Question
I have the following TensorFlow tensors.
tensor1 = tf.constant(np.random.randint(0,255, (2,512,512,1)), dtype='int32') #All elements in range [0,254] (upper bound is exclusive)
tensor2 = tf.constant(np.random.randint(0,255, (2,512,512,1)), dtype='int32') #All elements in range [0,254] (upper bound is exclusive)
tensor3 = tf.keras.backend.flatten(tensor1)
tensor4 = tf.keras.backend.flatten(tensor2)
tensor5 = tf.constant(np.random.randint(0,255, (255,255)), dtype='int32') #All elements in range [0,254] (upper bound is exclusive)
I wish to use the values stored in tensor3 and tensor4 to make a tuple and query the element of tensor5 at the position given by that tuple. For example, say the 0th elements are tensor3[0]=5 and tensor4[0]=99, so the tuple becomes (5,99). I then wish to look up the value of element (5,99) in tensor5. I wish to do this for all elements of tensor3 and tensor4 in a batched manner; that is, I do not want to loop over all values in range(len(tensor3)). I did the following to achieve this.
tensor6 = tensor5[tensor3[0],tensor4[0]]
But tensor6 has the shape (255,255), whereas I was hoping to get a tensor of shape (len(tensor3),len(tensor3)). I wanted to evaluate tensor5 at all possible locations in len(tensor3), that is, at (0,0),...(1000,1000),....(2000,2000),... I am using TensorFlow version 1.12.0. How can I achieve this?
Answer 1:
I have managed to get something working in TensorFlow v1.12, but do let me know whether it is what you expected:
import numpy as np
import tensorflow as tf
print(tf.__version__)  # written for TensorFlow 1.12 (graph mode)
# np.random.randint excludes the upper bound, so all values lie in [0,254]
tensor1 = tf.constant(np.random.randint(0,255, (2,512,512,1)), dtype='int32')
tensor2 = tf.constant(np.random.randint(0,255, (2,512,512,1)), dtype='int32')
tensor3 = tf.keras.backend.flatten(tensor1)  # row indices
tensor4 = tf.keras.backend.flatten(tensor2)  # column indices
tensor5 = tf.constant(np.random.randint(0,255, (255,255)), dtype='int32')  # lookup table
elems = (tensor3, tensor4)
# For every i, look up tensor5[tensor3[i], tensor4[i]]
a = tf.map_fn(lambda x: tensor5[x[0], x[1]], elems, dtype=tf.int32)
print(tf.Session().run(a))
Based on the comment below, I'd like to add an explanation of the map_fn used in the code. Since Python for loops over a tensor are not supported without eager execution, map_fn is (sort of) equivalent to a for loop.
map_fn has the following parameters: operation_performed, input_arguments, optional_dtype. What happens under the hood is that a for loop is run along the length of the values in input_arguments (which must contain an iterable object), and operation_performed is applied to each value obtained. For further clarification, please refer to the docs.
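To make the loop analogy concrete, here is a minimal plain-Python/NumPy sketch of that behaviour for the tuple-of-inputs case used in this answer. The helper map_fn_sketch is hypothetical and is not part of TensorFlow; it only mimics the idea of applying a function to matching slices of each input, and it ignores everything else the real map_fn does.

```python
import numpy as np

def map_fn_sketch(operation_performed, input_arguments):
    """Hypothetical stand-in for the looping idea behind map_fn.

    input_arguments is assumed to be a tuple of equal-length arrays.
    """
    results = []
    # Walk over the inputs in lockstep, one element from each array per step.
    for slices in zip(*input_arguments):
        results.append(operation_performed(slices))
    return np.array(results)

# Small lookup table: table[r, c] == 4 * r + c
table = np.arange(12).reshape(3, 4)
rows = np.array([0, 2, 1])
cols = np.array([3, 0, 2])

out = map_fn_sketch(lambda x: table[x[0], x[1]], (rows, cols))
print(out)  # [3 8 6]
```

Each step receives one (row, column) pair, which is why the lambda in the answer indexes tensor5 with x[0] and x[1].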
The names given to the arguments of the function are my own way of interpreting them, as I like to understand it, and are not the names used in the official docs. :)
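As a possible alternative, the same lookup can usually be written without a per-element loop by stacking the two index vectors into (row, column) pairs and passing them to tf.gather_nd. The following is a sketch under that assumption and is not the method from the answer above. The names table, rows, cols and indices are made up for illustration, and the eager/graph branch is only there so that the snippet can run on both TensorFlow 1.x and 2.x.

```python
import numpy as np
import tensorflow as tf

rng = np.random.RandomState(0)
# Same shapes as in the question; randint excludes the upper bound.
rows_np = rng.randint(0, 255, (2, 512, 512, 1)).astype(np.int32)
cols_np = rng.randint(0, 255, (2, 512, 512, 1)).astype(np.int32)
table_np = rng.randint(0, 255, (255, 255)).astype(np.int32)

table = tf.constant(table_np)
rows = tf.reshape(tf.constant(rows_np), [-1])  # flattened row indices
cols = tf.reshape(tf.constant(cols_np), [-1])  # flattened column indices

# Pair up the indices: shape (N, 2), one (row, col) pair per element.
indices = tf.stack([rows, cols], axis=1)
# Gather table[row, col] for every pair in one op: shape (N,).
looked_up = tf.gather_nd(table, indices)

if tf.executing_eagerly():
    # TensorFlow 2.x default: tensors can be converted directly.
    result = looked_up.numpy()
else:
    # TensorFlow 1.x graph mode, as in the question.
    with tf.Session() as sess:
        result = sess.run(looked_up)

print(result.shape)
```

The result has one value per element of the flattened index tensors, so it can be reshaped back to (2,512,512,1) if the original layout is needed.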
Source: https://stackoverflow.com/questions/61835494/access-elements-of-a-tensor