Optimized method for calculating cosine distance in Python

被撕碎了的回忆 2021-02-14 21:27

I wrote a method to calculate the cosine distance between two arrays:

def cosine_distance(a, b):
    if len(a) != len(b):
        return False
    numerator = 0
         


        
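For reference, cosine distance is 1 - (a · b) / (‖a‖ ‖b‖), i.e. one minus the cosine of the angle between the two vectors. One way to speed up a plain-Python version without third-party packages is to let built-ins do the per-element work in C. The sketch below is an illustration under stated assumptions, not the asker's code. `cosine_distance_stdlib` is a hypothetical name. It needs Python 3.8+ for the multi-argument form of `math.hypot`. Whether it actually beats an explicit loop should be checked with a benchmark on your own inputs.

```python
import math
import operator


def cosine_distance_stdlib(a, b):
    """Return 1 - (a . b) / (|a| * |b|) for two equal-length sequences.

    Hypothetical helper: map(operator.mul, ...) and math.hypot(*v) do their
    per-element work in C instead of an explicit Python-level loop.
    Requires Python 3.8+ for the multi-argument form of math.hypot.
    """
    if len(a) != len(b):
        raise ValueError("a and b must have the same length")
    dot = sum(map(operator.mul, a, b))        # a . b
    norms = math.hypot(*a) * math.hypot(*b)   # |a| * |b|
    if norms == 0:
        raise ValueError("cosine distance is undefined for an all-zero vector")
    return 1 - dot / norms


print(cosine_distance_stdlib([1, 0], [0, 1]))  # orthogonal vectors → 1.0
```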
  •  终归单人心
    2021-02-14 21:50

    Using the C code inside SciPy wins big for long input arrays, while simple, direct Python wins for short ones; among those, Darius Bacon's izip()-based code benchmarked best. So the ultimate solution is to decide at runtime which one to use, based on the length of the input arrays:

    from math import sqrt
    
    from scipy.spatial.distance import cosine as scipy_cos_dist
    
    def cosine_distance(a, b):
        len_a = len(a)
        if len_a != len(b):
            raise ValueError("a and b must have the same length")
        if len_a > 200:  # 200 is a magic value found by benchmark
            return scipy_cos_dist(a, b)
        # The loop below is basically just Darius Bacon's code; Python 3's
        # built-in zip() is lazy, like Python 2's itertools.izip().
        ab_sum = a_sum = b_sum = 0
        for ai, bi in zip(a, b):
            ab_sum += ai * bi
            a_sum += ai * ai
            b_sum += bi * bi
        # Raises ZeroDivisionError if either input is all zeros.
        return 1 - ab_sum / sqrt(a_sum * b_sum)
    
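The length check above can also be packaged as a small factory, so the threshold becomes a parameter that can be re-tuned on another machine instead of a hard-coded 200. This is a sketch with hypothetical names (`make_length_dispatcher`, `small_fn`, `large_fn`). Any two functions taking `(a, b)` can be plugged in, and the toy lambdas below only report which branch ran.

```python
def make_length_dispatcher(small_fn, large_fn, threshold):
    """Return a function that calls small_fn(a, b) when len(a) <= threshold
    and large_fn(a, b) otherwise (hypothetical helper)."""

    def dispatch(a, b):
        n = len(a)
        if n != len(b):
            raise ValueError("a and b must have the same length")
        # Beyond the extra function call, the overhead is the length check
        # and one comparison against the threshold.
        return large_fn(a, b) if n > threshold else small_fn(a, b)

    return dispatch


# Toy stand-ins that just report which branch ran.
pick = make_length_dispatcher(lambda a, b: "small", lambda a, b: "large", threshold=200)
print(pick([0.0] * 3, [0.0] * 3))      # → small
print(pick([0.0] * 500, [0.0] * 500))  # → large
```

Passing a pure-Python loop as `small_fn`, `scipy.spatial.distance.cosine` as `large_fn`, and `threshold=200` mirrors the function above.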

    I wrote a test harness that timed the functions on inputs of different lengths and found that the SciPy function started to win at around length 200; the longer the input arrays, the bigger its advantage. For very short arrays (say, length 3), the simpler code wins. The combined function adds a tiny amount of overhead to decide which path to take, then takes the faster one.
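Benchmark crossover points depend on the machine and on the Python and SciPy versions, so it can be worth re-measuring rather than reusing 200. A minimal sketch, assuming the standard-library `timeit` module is an acceptable timer: `find_crossover` (a hypothetical helper) times both functions on the same random inputs at each length and returns the first length where the "large-input" function is faster.

```python
import random
import timeit


def find_crossover(small_fn, large_fn, lengths, number=1000):
    """Return the first length at which large_fn runs faster than small_fn,
    or None if it never does (hypothetical helper).

    The result is machine-dependent, so rerun it wherever the code will run.
    """
    for n in lengths:
        a = [random.random() for _ in range(n)]
        b = [random.random() for _ in range(n)]
        # Taking the min() of a few repeats damps noise from other processes.
        t_small = min(timeit.repeat(lambda: small_fn(a, b), number=number, repeat=3))
        t_large = min(timeit.repeat(lambda: large_fn(a, b), number=number, repeat=3))
        if t_large < t_small:
            return n
    return None
```

With the functions from the test harness below, `find_crossover(fn_darius2, fn_scipy, range(50, 500, 10))` would scan the same lengths the harness prints.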

    In case you are interested, here is the test harness:

    import random
    import time
    
    from scipy.spatial.distance import cosine as fn_scipy
    
    # darius2 and ult are local modules (not shown), assumed to hold
    # Darius Bacon's function and the combined function above.
    from darius2 import cosine_distance as fn_darius2
    from ult import cosine_distance as fn_ult
    
    # (name, function) pairs, so the imported functions are not mutated.
    lst_fn = [("fn_darius2", fn_darius2), ("fn_scipy", fn_scipy), ("fn_ult", fn_ult)]
    
    def run_test(fn, lst0, lst1, test_len):
        """Return the seconds taken by test_len calls of fn(lst0, lst1)."""
        start = time.perf_counter()
        for _ in range(test_len):
            fn(lst0, lst1)
        return time.perf_counter() - start
    
    test_len = 10**3
    for data_len in range(50, 500, 10):
        a = [random.random() for _ in range(data_len)]
        b = [random.random() for _ in range(data_len)]
        print("len(a) ==", len(a))
        for n, fn in lst_fn:
            r = fn(a, b)
            t = run_test(fn, a, b, test_len)
            print("%s:\t%f seconds, result %f" % (n, t, r))
    
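The harness prints each result so the functions can be compared by eye. A sketch of an automated check, assuming agreement within a relative tolerance of 1e-9 is close enough: `find_mismatches` (hypothetical) runs every function on the same random inputs and reports any that disagree with the first one according to `math.isclose`. The two implementations at the bottom are illustrative stand-ins, not the answer's code.

```python
import math
import random


def find_mismatches(fns, lengths, rel_tol=1e-9, abs_tol=1e-12):
    """Return (name, length, got, expected) for every function in fns[1:]
    whose result is not close to fns[0]'s on the same random inputs."""
    mismatches = []
    reference, others = fns[0], fns[1:]
    for n in lengths:
        a = [random.random() for _ in range(n)]
        b = [random.random() for _ in range(n)]
        expected = reference(a, b)
        for fn in others:
            got = fn(a, b)
            if not math.isclose(got, expected, rel_tol=rel_tol, abs_tol=abs_tol):
                mismatches.append((fn.__name__, n, got, expected))
    return mismatches


def loop_version(a, b):
    # Single pass accumulating a . b, |a|^2 and |b|^2.
    ab = aa = bb = 0.0
    for x, y in zip(a, b):
        ab += x * y
        aa += x * x
        bb += y * y
    return 1 - ab / math.sqrt(aa * bb)


def hypot_version(a, b):
    # Same formula, with norms from math.hypot (Python 3.8+).
    return 1 - sum(x * y for x, y in zip(a, b)) / (math.hypot(*a) * math.hypot(*b))


print(find_mismatches([loop_version, hypot_version], [3, 50, 200]))  # → []
```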
