How to fetch specific rows from a tensor in Tensorflow?

倾然丶 夕夏残阳落幕 提交于 2021-02-19 01:30:09

问题


I have a tensor defined as follows:

temp_var = tf.Variable(initial_value=np.asarray([[1, 2, 3],[4, 5, 6],[7, 8, 9],[10, 11, 12]]))

I also have an array of indexes of rows to be fetched from tensor:

idx = tf.constant([0, 2])

Now I want to take a subset of temp_var at those indexes i.e. idx

I know that to take a single index or a slice, we can do something like

temp_var[single_row_index, :]

or

temp_var[start:end, :]

But how to fetch rows indicated by idx array? Something like temp_var[idx, :] ?


回答1:


The tf.gather() op does exactly what you need: it selects rows from a matrix (or in general (N-1)-dimensional slices from an N-dimensional tensor). Here's how it would work in your case:

temp_var = tf.Variable([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]))
idx = tf.constant([0, 2])

rows = tf.gather(temp_var, idx)

init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

print(sess.run(rows))  # ==> [[1, 2, 3], [7, 8, 9]]


来源:https://stackoverflow.com/questions/38743538/how-to-fetch-specific-rows-from-a-tensor-in-tensorflow

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!