Cartesian Product in TensorFlow

醉梦人生 2020-12-02 01:56

Is there an easy way to compute a Cartesian product in TensorFlow, like itertools.product? I want to get all combinations of the elements of two tensors (a and b).
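For reference, this is what itertools.product returns for plain Python lists. The example values a = [1, 2, 3] and b = [4, 5, 6, 7] are assumptions, chosen to match two of the answers. Note that the last iterable varies fastest, which matters when comparing the row order of the TensorFlow results.

```python
from itertools import product

# Hypothetical example values (the same ones used in two of the answers).
a = [1, 2, 3]
b = [4, 5, 6, 7]

# product(a, b) yields every (x, y) with x from a and y from b;
# the last iterable (b) varies fastest.
pairs = list(product(a, b))
print(len(pairs))  # 12
print(pairs[:5])   # [(1, 4), (1, 5), (1, 6), (1, 7), (2, 4)]
```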

4 Answers
  • 2020-12-02 02:20

    A shorter solution to the same problem, relying on broadcasting in tf.add() (written here with the + operator) (tested):

    import tensorflow as tf
    
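    a = tf.constant([1,2,3]) 
    b = tf.constant([4,5,6,7]) 
    
    # a varies along axis 1 and b along axis 0: shapes (1, 3, 1) and (4, 1, 1)
    a, b = a[ None, :, None ], b[ :, None, None ]
    # adding zeros broadcasts both operands to (4, 3, 1); concat gives (4, 3, 2)
    cartesian_product = tf.concat( [ a + tf.zeros_like( b ),
                                     tf.zeros_like( a ) + b ], axis = 2 )
    
    # TF 1.x session API (in TF 2.x, use tf.compat.v1.Session or print the eager tensor)
    with tf.Session() as sess:
        print( sess.run( cartesian_product ) )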
    

    will output:

    [[[1 4]
      [2 4]
      [3 4]]

     [[1 5]
      [2 5]
      [3 5]]

     [[1 6]
      [2 6]
      [3 6]]

     [[1 7]
      [2 7]
      [3 7]]]

  • 2020-12-02 02:21

    I'm going to assume here that both a and b are 1-D tensors.

    To get the cartesian product of the two, I would use a combination of tf.expand_dims and tf.tile:

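    import tensorflow as tf
    
    a = tf.constant([1,2,3]) 
    b = tf.constant([4,5,6,7]) 
    
    # repeat each element of a across len(b) columns, then add a last axis: (3, 4, 1)
    tile_a = tf.tile(tf.expand_dims(a, 1), [1, tf.shape(b)[0]])  
    tile_a = tf.expand_dims(tile_a, 2) 
    # repeat b across len(a) rows, then add a last axis: (3, 4, 1)
    tile_b = tf.tile(tf.expand_dims(b, 0), [tf.shape(a)[0], 1]) 
    tile_b = tf.expand_dims(tile_b, 2) 
    
    # pair the two along the last axis: (3, 4, 2)
    cartesian_product = tf.concat([tile_a, tile_b], axis=2) 
    
    # TF 1.x session API
    cart = tf.Session().run(cartesian_product) 
    
    print(cart.shape) 
    print(cart) 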
    

    You end up with a tensor of shape (len(a), len(b), 2), here (3, 4, 2), in which cart[i, j] is the pair [a[i], b[j]], so every combination of an element of a with an element of b is represented along the last dimension.
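As a quick check of that layout, here is a sketch assuming TensorFlow 2.x (eager execution, so no session) that flattens the result and compares it with itertools.product. Because a is on the first axis, row-major flattening gives exactly the order that itertools.product(a, b) produces.

```python
import itertools
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

a = tf.constant([1, 2, 3])
b = tf.constant([4, 5, 6, 7])

# Same tiling as in the answer above, evaluated eagerly.
tile_a = tf.expand_dims(tf.tile(tf.expand_dims(a, 1), [1, tf.shape(b)[0]]), 2)  # (3, 4, 1)
tile_b = tf.expand_dims(tf.tile(tf.expand_dims(b, 0), [tf.shape(a)[0], 1]), 2)  # (3, 4, 1)
cart = tf.concat([tile_a, tile_b], axis=2)                                       # (3, 4, 2)

# cart[i, j] == [a[i], b[j]]; row-major flattening puts a in the outer loop.
pairs = [tuple(row) for row in tf.reshape(cart, (-1, 2)).numpy().tolist()]
expected = list(itertools.product(a.numpy().tolist(), b.numpy().tolist()))

print(cart.numpy().shape)  # (3, 4, 2)
print(pairs == expected)   # True
```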

  • 2020-12-02 02:23

    Inspired by Jaba's answer: if you want the Cartesian product of the rows of two 2-D tensors, you can do it as follows.

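    Given inputs a of shape [N, L] and b of shape [M, L], this produces a tensor of shape [N*M, 2L], where each row is a row of a concatenated with a row of b: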

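    import tensorflow as tf
    
    # hypothetical example inputs: N = 2 rows and M = 3 rows, each of length L = 2
    a = tf.constant([[1, 2], [3, 4]])
    b = tf.constant([[5, 6], [7, 8], [9, 10]])
    N, M = 2, 3
    
    tile_a = tf.tile(tf.expand_dims(a, 1), [1, M, 1])  # [N, M, L]
    tile_b = tf.tile(tf.expand_dims(b, 0), [N, 1, 1])  # [N, M, L]
    
    cartesian_product = tf.concat([tile_a, tile_b], axis=2)  # [N, M, 2L]
    cartesian = tf.reshape(cartesian_product, [N*M, -1])     # [N*M, 2L]
    
    # TF 1.x session API
    cart = tf.Session().run(cartesian) 
    
    print(cart.shape)
    print(cart) 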
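If N and M are not available as Python integers (for example, when the number of rows varies between calls), they can be read at run time with tf.shape. A minimal sketch, assuming TensorFlow 2.x; the helper name cartesian_rows and the example values are hypothetical. Since tf.concat joins the tiles along the last axis, the two inputs need the same dtype but not necessarily the same row length; the result has La + Lb columns.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

def cartesian_rows(a, b):
    """Hypothetical helper: pair every row of a [N, La] with every row of b [M, Lb].

    Returns a tensor of shape [N*M, La + Lb]; N and M are read with tf.shape,
    so they do not have to be known in advance.
    """
    n, m = tf.shape(a)[0], tf.shape(b)[0]
    tile_a = tf.tile(tf.expand_dims(a, 1), [1, m, 1])  # [N, M, La]
    tile_b = tf.tile(tf.expand_dims(b, 0), [n, 1, 1])  # [N, M, Lb]
    pairs = tf.concat([tile_a, tile_b], axis=2)        # [N, M, La + Lb]
    return tf.reshape(pairs, [n * m, -1])

# Hypothetical example inputs: N = 2, M = 3, L = 2.
a = tf.constant([[1, 2], [3, 4]])
b = tf.constant([[10, 20], [30, 40], [50, 60]])
result = cartesian_rows(a, b).numpy()
print(result.tolist())
# [[1, 2, 10, 20], [1, 2, 30, 40], [1, 2, 50, 60],
#  [3, 4, 10, 20], [3, 4, 30, 40], [3, 4, 50, 60]]
```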
    
  • 2020-12-02 02:30
    import tensorflow as tf
    
    a = tf.constant([0, 1, 2])
    b = tf.constant([2, 3])
    # indexing='ij' makes a vary along axis 0 and b along axis 1
    c = tf.stack(tf.meshgrid(a, b, indexing='ij'), axis=-1)  # shape (3, 2, 2)
    c = tf.reshape(c, (-1, 2))                                # shape (6, 2)
    with tf.Session() as sess:
        print(sess.run(c))
    

    Output:

    [[0 2]
     [0 3]
     [1 2]
     [1 3]
     [2 2]
     [2 3]]
    

    credit to jdehesa: link
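The meshgrid approach also extends to more than two tensors, which is the closest analogue of itertools.product with several arguments. With indexing='ij', the k-th grid varies along axis k, so row-major flattening makes the last tensor vary fastest, the same order itertools.product uses. A sketch assuming TensorFlow 2.x and 1-D inputs sharing one dtype; the helper name cartesian_product and the third tensor are hypothetical.

```python
import itertools
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

def cartesian_product(*tensors):
    """Hypothetical helper: all combinations of 1-D tensors, one combination per row."""
    grids = tf.meshgrid(*tensors, indexing='ij')  # each grid: (len(t0), len(t1), ...)
    stacked = tf.stack(grids, axis=-1)            # (len(t0), len(t1), ..., k)
    return tf.reshape(stacked, (-1, len(tensors)))

a = tf.constant([0, 1, 2])
b = tf.constant([2, 3])
c = tf.constant([7, 8])  # hypothetical third input

rows = cartesian_product(a, b, c).numpy().tolist()
expected = [list(t) for t in itertools.product([0, 1, 2], [2, 3], [7, 8])]
print(rows[:3])          # [[0, 2, 7], [0, 2, 8], [0, 3, 7]]
print(rows == expected)  # True
```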
