Evaluate all pair combinations of rows of two tensors in tensorflow

误落风尘 2020-12-16 13:40

I am trying to define a custom op in TensorFlow, in which at one point I need to construct a matrix (z) that would contain the sums of all combinations of pairs of rows of two tensors (x and y).

2 answers
  • 2020-12-16 14:07

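    You can simply use TensorFlow's broadcasting. Insert a new axis into each tensor so that their sum has one entry per pair of rows, then flatten the two pair axes.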

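    import tensorflow as tf  # TF 1.x graph-mode API (tf.Session)
    
    x = tf.constant([[0, 1],[2, 3],[4, 5],[6, 7]], dtype=tf.float32)  # shape (4, 2)
    y = tf.constant([[0, 1],[2, 3]], dtype=tf.float32)                # shape (2, 2)
    
    x_ = tf.expand_dims(x, 0)  # shape (1, 4, 2)
    y_ = tf.expand_dims(y, 1)  # shape (2, 1, 2)
    # The sum broadcasts to shape (2, 4, 2); flattening the first two axes
    # gives all 8 pair sums, grouped by row of y (y-major order).
    z = tf.reshape(tf.add(x_, y_), [-1, 2])
    # or more succinctly
    z = tf.reshape(x[None] + y[:, None], [-1, 2])
    # For x-major order (each row of x paired with every row of y in turn):
    # z = tf.reshape(x[:, None] + y[None], [-1, 2])
    
    with tf.Session() as sess:
        print(sess.run(z))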
    
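The row order of the broadcast result is worth checking. `x[None] + y[:, None]` has shape (2, 4, 2), so after flattening, the rows are grouped by row of y. The other answer prints rows grouped by row of x instead. The sketch below uses NumPy only so that it runs without a TensorFlow session. That substitution is an assumption: NumPy's broadcasting rules match TensorFlow's, and the same indexing expressions work on TF 2.x eager tensors. The names `z_y_major` and `z_x_major` are illustrative.

```python
import numpy as np

# Same data as the TensorFlow example above.
x = np.array([[0, 1], [2, 3], [4, 5], [6, 7]], dtype=np.float32)
y = np.array([[0, 1], [2, 3]], dtype=np.float32)
dim = x.shape[1]

# As written in the answer: (1, 4, 2) + (2, 1, 2) -> (2, 4, 2),
# so the flattened rows come out grouped by row of y (y-major).
z_y_major = (x[None] + y[:, None]).reshape(-1, dim)

# Swapping which axis is inserted: (4, 1, 2) + (1, 2, 2) -> (4, 2, 2),
# which flattens in x-major order, like itertools.product(rows_x, rows_y).
z_x_major = (x[:, None] + y[None]).reshape(-1, dim)

print(z_y_major)
print(z_x_major)
```

Choose whichever order your op expects. Both layouts need only one add and one reshape.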
  • 2020-12-16 14:07

    Option 1

    Defining z as a variable and updating its rows one at a time:

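    import tensorflow as tf  # TF 1.x API (tf.Session, tf.scatter_update)
    from itertools import product
    
    
    x = tf.constant([[0, 1],[2, 3],[4, 5],[6, 7]], dtype=tf.float32)
    y = tf.constant([[0, 1],[2, 3]], dtype=tf.float32)
    
    # Static shapes as plain Python ints
    rows_x, dim = x.get_shape().as_list()
    rows_y = y.get_shape().as_list()[0]
    
    # Pre-allocate z with one row per (x row, y row) pair
    z = tf.Variable(initial_value=tf.zeros([rows_x * rows_y, dim]), dtype=tf.float32)
    # Row i receives x[x_id] + y[y_id]; each update is chained on the previous one
    for i, (x_id, y_id) in enumerate(product(range(rows_x), range(rows_y))):
        z = tf.scatter_update(z, i, x[x_id] + y[y_id])
    
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        z_val = sess.run(z)  # running the last update runs the whole chain
        print(z_val)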
    

    This prints

    [[  0.   2.]
     [  2.   4.]
     [  2.   4.]
     [  4.   6.]
     [  4.   6.]
     [  6.   8.]
     [  6.   8.]
     [  8.  10.]]
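The output is x-major because the loop writes the sum for `(x_id, y_id)` into row `i = x_id * rows_y + y_id`. That is the row-major flattening of a `(rows_x, rows_y)` grid of pairs, so the result equals `tf.reshape(x[:, None] + y[None], [-1, dim])`. The following standard-library check confirms the index relation for the example shapes. It is a sketch that hard-codes those shapes.

```python
from itertools import product

# Shapes of the example tensors: x is (4, 2), y is (2, 2).
rows_x, rows_y = 4, 2

# The same pair ordering that Option 1 loops over.
pairs = list(enumerate(product(range(rows_x), range(rows_y))))

# Row i holds x[x_id] + y[y_id] with i = x_id * rows_y + y_id,
# i.e. the row-major flat index of (x_id, y_id).
for i, (x_id, y_id) in pairs:
    assert i == x_id * rows_y + y_id

print(pairs[:3])
```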
    

    Option 2

    Creating z through a list comprehension:

    import tensorflow as tf  # TF 1.x API (tf.Session)
    
    
    x = tf.constant([[0, 1],[2, 3],[4, 5],[6, 7]], dtype=tf.float32)
    y = tf.constant([[0, 1],[2, 3]], dtype=tf.float32)
    
    # Static shapes as plain Python ints
    rows_x, dim = x.get_shape().as_list()
    rows_y = y.get_shape().as_list()[0]
    
    # One (dim,)-shaped sum per pair, in x-major order
    z = [x[x_id] + y[y_id] for x_id in range(rows_x) for y_id in range(rows_y)]
    # Stack the rows into a single (rows_x * rows_y, dim) matrix
    z = tf.stack(z)
    
    with tf.Session() as sess:
        z_val = sess.run(z)
        print(z_val)
    

    Comparison: the second solution is roughly 1.5 times faster, measuring only the construction of z in both solutions. The timings are 0.211 seconds for the first solution versus 0.137 seconds for the second.
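The timings above cover only TF 1.x graph construction. On eager data, a related comparison is a per-pair loop (Option 2's pattern) against a single broadcast add. The hedged sketch below lets you time that on your own machine. Its assumptions:

- NumPy stands in for TensorFlow.
- `pair_sums_loop` and `pair_sums_broadcast` are hypothetical helper names.
- The array sizes are arbitrary.

No reference numbers are claimed because results depend on hardware. Before timing anything, the script asserts that both helpers produce identical x-major rows.

```python
import timeit

import numpy as np


def pair_sums_loop(x, y):
    """Hypothetical helper: one add per (x row, y row) pair, then stack."""
    return np.stack([xi + yj for xi in x for yj in y])


def pair_sums_broadcast(x, y):
    """Hypothetical helper: one broadcast add, then flatten the pair axes (x-major)."""
    return (x[:, None] + y[None]).reshape(-1, x.shape[1])


rng = np.random.default_rng(0)
x = rng.standard_normal((200, 8)).astype(np.float32)
y = rng.standard_normal((100, 8)).astype(np.float32)

# Both construction styles must agree row for row before timing them.
assert np.array_equal(pair_sums_loop(x, y), pair_sums_broadcast(x, y))

t_loop = timeit.timeit(lambda: pair_sums_loop(x, y), number=5)
t_vec = timeit.timeit(lambda: pair_sums_broadcast(x, y), number=5)
print(f"loop: {t_loop:.4f} s, broadcast: {t_vec:.4f} s")
```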
