Access elements of a Tensor

Submitted by 删除回忆录丶 on 2020-06-27 18:28:45

Question


I have the following TensorFlow tensors.

tensor1 = tf.constant(np.random.randint(0,255, (2,512,512,1)), dtype='int32') # All elements in range [0,254] (randint's upper bound is exclusive)
tensor2 = tf.constant(np.random.randint(0,255, (2,512,512,1)), dtype='int32') # All elements in range [0,254]

tensor3 = tf.keras.backend.flatten(tensor1)
tensor4 = tf.keras.backend.flatten(tensor2)

tensor5 = tf.constant(np.random.randint(0,255, (255,255)), dtype='int32') # All elements in range [0,254]; shape (255,255)

I wish to use the values stored in tensor3 and tensor4 to form a tuple and query the element of tensor5 at the position given by that tuple. For example, say the 0th elements are tensor3[0]=5 and tensor4[0]=99, so the tuple becomes (5,99). I then want to look up the value of element (5,99) in tensor5. I want to do this for all elements of tensor3 and tensor4 in a batched manner; that is, I do not want to loop over all values in range(len(tensor3)). I tried the following to achieve this:
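The lookup described above pairs element i of tensor3 with element i of tensor4 and reads tensor5 at that (row, column) position. As a minimal sketch of the intended semantics only, assuming NumPy in place of TensorFlow and using small hypothetical arrays instead of the real tensors, NumPy integer-array indexing performs this lookup for all positions at once:

```python
import numpy as np

# Small hypothetical stand-ins for the real tensors (illustration only)
table = np.arange(9).reshape(3, 3)  # plays the role of tensor5: [[0,1,2],[3,4,5],[6,7,8]]
rows = np.array([0, 2, 1, 2])       # plays the role of tensor3
cols = np.array([1, 0, 2, 2])       # plays the role of tensor4

# Integer-array indexing pairs rows[i] with cols[i] and reads table[rows[i], cols[i]]
# for every i in one vectorized operation, without a Python loop over positions.
result = table[rows, cols]
print(result)  # [1 6 5 8]
```

The result holds one value per (row, column) pair, so its shape is (len(rows),).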

tensor6 = tensor5[tensor3[0],tensor4[0]]

But tensor6 has the shape (255,255), whereas I was hoping to get a tensor of shape (len(tensor3),len(tensor3)). I want to evaluate tensor5 at all possible locations in len(tensor3), that is, at (0,0), ..., (1000,1000), ..., (2000,2000), .... I am using TensorFlow version 1.12.0. How can I achieve this?


Answer 1:


I have managed to get something working in TensorFlow v1.12, but do let me know whether it does what you expect:

import tensorflow as tf
print(tf.__version__)
import numpy as np

tensor1 = tf.constant(np.random.randint(0,255, (2,512,512,1)), dtype='int32') # All elements in range [0,254] (randint's upper bound is exclusive)
tensor2 = tf.constant(np.random.randint(0,255, (2,512,512,1)), dtype='int32') # All elements in range [0,254]

tensor3 = tf.keras.backend.flatten(tensor1)
tensor4 = tf.keras.backend.flatten(tensor2)

tensor5 = tf.constant(np.random.randint(0,255, (255,255)), dtype='int32') # All elements in range [0,254]; shape (255,255), so valid indices are 0..254

# map_fn slices both flattened tensors along axis 0 and passes each pair
# x = (tensor3[i], tensor4[i]) to the lambda, which reads tensor5 at that position.
elems = (tensor3, tensor4)
a = tf.map_fn(lambda x: tensor5[x[0], x[1]], elems, dtype=tf.int32)

print(tf.Session().run(a))

Based on the comment below, I'd like to add an explanation of the map_fn used in the code. Python for loops over the elements of a tensor are not supported without eager execution, and map_fn is (sort of) the equivalent of such a loop.

map_fn takes the following parameters: operation_performed, input_arguments and optional_dtype (officially named fn, elems and dtype). Under the hood, each tensor in input_arguments is unpacked along its first dimension, and operation_performed is applied to each resulting slice; when input_arguments is a tuple, as here, operation_performed receives a tuple of the matching slices. For further clarification, please refer to the docs.

The names I gave the arguments reflect my own interpretation, to make them easier to understand; they are not the names used in the official docs. :)
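To illustrate the "sort of a for loop" behaviour described above, here is a hypothetical pure-Python model of how map_fn applies its function when elems is a tuple. It is only a sketch of the semantics: the real tf.map_fn builds a graph loop (a tf.while_loop) rather than running Python code per element, and map_fn_model, table, rows and cols are names invented for this example.

```python
def map_fn_model(fn, elems):
    """Hypothetical pure-Python model of tf.map_fn's semantics (not TF's implementation).

    elems is a tuple of equal-length sequences; fn is called once per index i
    with the tuple (elems[0][i], elems[1][i], ...), mirroring how map_fn slices
    each tensor in a tuple along its first dimension.
    """
    length = len(elems[0])
    if any(len(e) != length for e in elems):
        raise ValueError("all sequences in elems must have the same length")
    return [fn(tuple(e[i] for e in elems)) for i in range(length)]


# Hypothetical lookup table standing in for tensor5
table = [[0, 1, 2],
         [3, 4, 5],
         [6, 7, 8]]
rows = [0, 2, 1]  # stands in for tensor3
cols = [1, 0, 2]  # stands in for tensor4

# Same lambda shape as in the answer: x is the pair (rows[i], cols[i])
result = map_fn_model(lambda x: table[x[0]][x[1]], (rows, cols))
print(result)  # [1, 6, 5]
```

Because the function sees one pair at a time, the output has one entry per position, which matches the (len(tensor3),) shape produced by the TensorFlow code.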



Source: https://stackoverflow.com/questions/61835494/access-elements-of-a-tensor
