I am trying to define a custom op in TensorFlow, in which at one point I need to construct a matrix (z) that would contain sums of all combinations of pairs of
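You could simply use TensorFlow's broadcasting. Expanding x to shape (1, 4, 2) and y to shape (2, 1, 2) makes their sum broadcast to shape (2, 4, 2), which is then flattened to 8 rows of 2 columns: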
import tensorflow as tf
x = tf.constant([[0, 1],[2, 3],[4, 5],[6, 7]], dtype=tf.float32)
y = tf.constant([[0, 1],[2, 3]], dtype=tf.float32)
x_ = tf.expand_dims(x, 0)
y_ = tf.expand_dims(y, 1)
z = tf.reshape(tf.add(x_, y_), [-1, 2])
# or more succinctly
z = tf.reshape(x[None] + y[:, None], [-1, 2])
sess = tf.Session()
sess.run(z)
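To see which row order the broadcast-and-reshape produces, here is a minimal pure-Python sketch that needs no TensorFlow. The helper name pairwise_row_sums is hypothetical, introduced only for illustration. It assumes the usual broadcasting rule that the first axis of the result varies slowest when flattened.

```python
# Pure-Python reference for the row order of the broadcast result.
x = [[0, 1], [2, 3], [4, 5], [6, 7]]
y = [[0, 1], [2, 3]]

def pairwise_row_sums(outer, inner):
    """Sum every row of `outer` with every row of `inner`.

    The index into `outer` varies slowest, matching a flattened
    array of shape (len(outer), len(inner), dim).
    """
    return [[a + b for a, b in zip(o, i)] for o in outer for i in inner]

# x[None] + y[:, None] has shape (len(y), len(x), 2): y is the slow index.
z_y_major = pairwise_row_sums(y, x)

# x[:, None] + y[None] would have shape (len(x), len(y), 2): x is the slow index.
z_x_major = pairwise_row_sums(x, y)

print(z_y_major)
print(z_x_major)
```

Both results contain the same eight rows, only in a different order. If the order matters for the custom op, swapping which operand gets the leading axis is probably the simplest way to control it.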
Option 1
Defining z as a variable and updating its rows:
import tensorflow as tf
from itertools import product
x = tf.constant([[0, 1],[2, 3],[4, 5],[6, 7]],dtype=tf.float32)
y = tf.constant([[0, 1],[2, 3]],dtype=tf.float32)
rows_x,dim=x.get_shape()
rows_y=y.get_shape()[0]
z=tf.Variable(initial_value=tf.zeros([rows_x*rows_y,dim]),dtype=tf.float32)
for i, (x_id, y_id) in enumerate(product(range(rows_x), range(rows_y))):
    z=tf.scatter_update(z,i,x[x_id]+y[y_id])
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    z_val=sess.run(z)
    print(z_val)
This prints
[[ 0. 2.]
 [ 2. 4.]
 [ 2. 4.]
 [ 4. 6.]
 [ 4. 6.]
 [ 6. 8.]
 [ 6. 8.]
 [ 8. 10.]]
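The row index i used in the loop above follows directly from how enumerate(product(...)) orders the pairs: the last range varies fastest, so i = x_id * rows_y + y_id. A small sketch of that mapping, using the same sizes as the example (the dictionary name index_of_pair is only illustrative):

```python
from itertools import product

rows_x, rows_y = 4, 2  # number of rows in x and y in the example

# Map each (x_id, y_id) pair to the row of z that it fills.
index_of_pair = {
    (x_id, y_id): i
    for i, (x_id, y_id) in enumerate(product(range(rows_x), range(rows_y)))
}

# product() varies the last range fastest, so the row index is row-major.
for (x_id, y_id), i in index_of_pair.items():
    assert i == x_id * rows_y + y_id

print(index_of_pair[(3, 1)])  # 7, the last row of z
```

This is why the printed rows are grouped by the row of x: each row of x is paired with every row of y before moving on.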
Option 2
Creating z through a list comprehension:
import tensorflow as tf
from itertools import product
x = tf.constant([[0, 1],[2, 3],[4, 5],[6, 7]],dtype=tf.float32)
y = tf.constant([[0, 1],[2, 3]],dtype=tf.float32)
rows_x,dim=x.get_shape().as_list()
rows_y=y.get_shape().as_list()[0]
z=[x[x_id]+y[y_id] for x_id in range(rows_x) for y_id in range(rows_y)]
z=tf.reshape(z,(rows_x*rows_y,dim))
with tf.Session() as sess:
    z_val=sess.run(z)
    print(z_val)
Comparison: The second solution is roughly 1.5 times faster (measuring only the construction of z in both solutions). In particular, the timings are:
first solution: 0.211 seconds, second solution: 0.137 seconds.
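The answer does not show how the timings were taken. One way such a comparison could be set up is sketched below with the standard timeit module. The helper time_construction and the two builder functions are hypothetical pure-Python stand-ins that mirror the structure of the two options; they do not reproduce the TensorFlow numbers above. To time the real thing, the builders would be replaced with functions that construct z in a fresh graph.

```python
import timeit
from itertools import product

x = [[0, 1], [2, 3], [4, 5], [6, 7]]
y = [[0, 1], [2, 3]]

def build_by_updates():
    """Stand-in for Option 1: preallocate z, then overwrite one row at a time."""
    z = [[0, 0] for _ in range(len(x) * len(y))]
    for i, (x_id, y_id) in enumerate(product(range(len(x)), range(len(y)))):
        z[i] = [a + b for a, b in zip(x[x_id], y[y_id])]
    return z

def build_by_comprehension():
    """Stand-in for Option 2: build all rows in one list comprehension."""
    return [[a + b for a, b in zip(xr, yr)] for xr in x for yr in y]

def time_construction(build, repeats=5):
    """Best wall-clock time in seconds of a single build() call over several runs."""
    return min(timeit.repeat(build, number=1, repeat=repeats))

t_updates = time_construction(build_by_updates)
t_comprehension = time_construction(build_by_comprehension)
print("updates: %.6f s, comprehension: %.6f s" % (t_updates, t_comprehension))
```

Taking the minimum over several runs reduces the influence of one-off noise, which matters when the measured times are as short as a few tenths of a second or less.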