Compute the pairwise distance between each pair of the two collections of inputs in TensorFlow

轻奢々 2020-12-19 21:25

I have two collections. One consists of m1 points in k dimensions and another one of m2 points in k dimensions. I n

2 answers
  • 2020-12-19 21:51

    After a few hours I finally worked out how to do this in TensorFlow. My solution works only for Euclidean distance and is fairly verbose. I also do not have a mathematical proof (just a lot of hand-waving, which I hope to make more rigorous):

    import tensorflow as tf
    import numpy as np
    from scipy.spatial.distance import cdist
    
    M1, M2, K = 3, 4, 2
    
    # Scipy calculation
    a = np.random.rand(M1, K).astype(np.float32)
    b = np.random.rand(M2, K).astype(np.float32)
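    print(cdist(a, b, 'euclidean'), '\n')
    
    # TF calculation (TensorFlow 1.x graph mode)
    A = tf.Variable(a)
    B = tf.Variable(b)
    
    # p1[i, j] = ||A[i]||^2, repeated across M2 columns -> shape (M1, M2)
    p1 = tf.matmul(
        tf.expand_dims(tf.reduce_sum(tf.square(A), 1), 1),
        tf.ones(shape=(1, M2))
    )
    # p2[i, j] = ||B[j]||^2, repeated down M1 rows -> shape (M1, M2)
    p2 = tf.transpose(tf.matmul(
        tf.reshape(tf.reduce_sum(tf.square(B), 1), shape=[-1, 1]),
        tf.ones(shape=(M1, 1)),
        transpose_b=True
    ))
    
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b; clamp at 0 because rounding can
    # leave tiny negative values, which tf.sqrt would turn into NaN
    sq_dists = tf.add(p1, p2) - 2 * tf.matmul(A, B, transpose_b=True)
    res = tf.sqrt(tf.maximum(sq_dists, 0.0))
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(res))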
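
The identity this answer relies on can be written out explicitly. The following is a short derivation sketch; the symbols r_A and r_B (vectors of squared row norms) and D (the distance matrix) are notation introduced here, not part of the original answer.

```latex
% a_i is the i-th row of A (M_1 x K), b_j is the j-th row of B (M_2 x K).
\begin{aligned}
\lVert a_i - b_j \rVert^2
  &= (a_i - b_j)^\top (a_i - b_j) \\
  &= a_i^\top a_i \;-\; 2\, a_i^\top b_j \;+\; b_j^\top b_j .
\end{aligned}

% Stacking over i = 1..M_1 and j = 1..M_2, with
% (r_A)_i = a_i^\top a_i and (r_B)_j = b_j^\top b_j:
D_{ij}^2 = \bigl( r_A \mathbf{1}_{M_2}^\top
         + \mathbf{1}_{M_1} r_B^\top
         - 2\, A B^\top \bigr)_{ij}
```

In the code above, the three terms correspond to p1, p2, and tf.matmul(A, B, transpose_b=True) respectively, and the final result is the elementwise square root.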
    
  • 2020-12-19 22:00

    This handles tensors of arbitrary rank, i.e. of shape (..., N, d). Note that it does not compute distances between two collections (as scipy.spatial.distance.cdist does); instead it computes them within a single batch of vectors (like scipy.spatial.distance.pdist), although it returns the full (..., N, N) matrix rather than pdist's condensed form.

    import tensorflow as tf
    import string
    
    def pdist(arr):
        """Pairwise Euclidean distances between vectors contained at the back of tensors.
    
        Uses expansion: (x - y)^T (x - y) = x^Tx - 2x^Ty + y^Ty 
    
        :param arr: (..., N, d) tensor
        :returns: (..., N, N) tensor of pairwise distances between vectors in the second-to-last dim.
        :rtype: tf.Tensor
    
        """
        shape = tuple(arr.get_shape().as_list())
        rank_ = len(shape)
        N, d = shape[-2:]
    
        # Build a prefix from the array without the indices we'll use later.
        pref = string.ascii_lowercase[:rank_ - 2]
    
        # Gram matrix of pairwise inner products x_n . x_m, shape (..., N, N)
        xxT = tf.einsum('{0}ni,{0}mi->{0}nm'.format(pref), arr, arr)
    
        # Squared norm x_n . x_n of each point, shape (..., N)
        xTx = tf.einsum('{0}ni,{0}ni->{0}n'.format(pref), arr, arr)
    
        # (..., N, N) inner products tiled.
        xTx_tile = tf.tile(xTx[..., None], (1,) * (rank_ - 1) + (N,))
    
        # Build the permuter. (sigh, no tf.swapaxes yet)
        permute = list(range(rank_))
        permute[-2], permute[-1] = permute[-1], permute[-2]
    
        # dists = (x^Tx - 2x^Ty + y^Ty)^(1/2). The axis swap is needed to pair x^Tx (rows) with y^Ty (columns).
        return tf.sqrt(xTx_tile - 2 * xxT + tf.transpose(xTx_tile, permute))
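
Since the question asks about two collections rather than one, the same expansion can be adapted to a cdist-style function using broadcasting instead of tiling. The sketch below is written in NumPy so that it runs without a TensorFlow session; the name cdist_expand is hypothetical, and the clamp at zero is an added safeguard against rounding error rather than something taken from either answer. The same pattern would presumably carry over to the equivalent tf ops, but that is not verified here.

```python
import numpy as np


def cdist_expand(x, y):
    """Euclidean distances between x (..., N, d) and y (..., M, d).

    Returns an array of shape (..., N, M), using
    ||x - y||^2 = x.x - 2 x.y + y.y.
    """
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)

    # Squared norms, shaped so that they broadcast against (..., N, M).
    xx = np.sum(x * x, axis=-1)[..., :, None]   # (..., N, 1)
    yy = np.sum(y * y, axis=-1)[..., None, :]   # (..., 1, M)

    # Pairwise inner products between the two collections.
    xy = np.matmul(x, np.swapaxes(y, -1, -2))   # (..., N, M)

    # Rounding can leave tiny negative values; clamp before the square root.
    sq = np.maximum(xx - 2.0 * xy + yy, 0.0)
    return np.sqrt(sq)


rng = np.random.default_rng(0)
a = rng.random((3, 2))   # M1 = 3 points in K = 2 dimensions
b = rng.random((4, 2))   # M2 = 4 points in K = 2 dimensions

dists = cdist_expand(a, b)

# Reference: subtract every pair directly and take the norm.
direct = np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)
print(np.allclose(dists, direct))
```

Broadcasting avoids materialising the tiled (..., N, N) copies of the squared norms, and passing the same array twice gives the pdist-style full matrix from the answer above.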
    