Broadcast Annoy object in Spark (for nearest neighbors)?

庸人自扰 2021-01-03 05:05

As Spark's MLlib doesn't have nearest-neighbors functionality, I'm trying to use Annoy for approximate nearest neighbors. I try to broadcast the Annoy object and pass it to the workers, but I run into pickling errors.

2 Answers
  • 2021-01-03 06:01

    Just in case anyone else is following along here like I was: you'll need to import Annoy inside the mapPartitions function, or you'll still get pickling errors. Here's my completed example based on the other answer:

    from annoy import AnnoyIndex
    
    from pyspark import SparkFiles
    from pyspark import SparkContext
    from pyspark import SparkConf
    
    import random
    random.seed(42)
    
    f = 1024                        # vector dimensionality
    t = AnnoyIndex(f, 'angular')    # newer Annoy versions require an explicit metric
    allvectors = []
    for i in range(100):
        v = [random.gauss(0, 1) for z in range(f)]
        t.add_item(i, v)
        allvectors.append((i, v))
    
    t.build(10)          # build a forest of 10 trees
    t.save("index.ann")
    
    def find_neighbors(rows):
        # import on the worker, otherwise the closure triggers pickling errors
        from annoy import AnnoyIndex
        ai = AnnoyIndex(f, 'angular')
        ai.load(SparkFiles.get("index.ann"))   # mmap the copy shipped by addFile
        return (ai.get_nns_by_vector(vector=x[1], n=5) for x in rows)
    
    with SparkContext(conf=SparkConf().setAppName("myannoy")) as sc:
        sc.addFile("index.ann")    # distribute the saved index to every worker
        sparkvectors = sc.parallelize(allvectors)
        print(sparkvectors.mapPartitions(find_neighbors).first())
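
    If you need to know which input each neighbor list belongs to, a small variation on the function above (my own sketch, reusing f, SparkFiles, and the saved index.ann) pairs each input id with its result:

    def find_neighbors_keyed(rows):
        from annoy import AnnoyIndex
        ai = AnnoyIndex(f, 'angular')
        ai.load(SparkFiles.get("index.ann"))
        # yield (input_id, [neighbor_ids]) instead of bare neighbor lists
        return ((x[0], ai.get_nns_by_vector(vector=x[1], n=5)) for x in rows)
    
    # e.g. inside the SparkContext block:
    # sparkvectors.mapPartitions(find_neighbors_keyed).collectAsMap()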
    
  • 2021-01-03 06:11

    I've never used Annoy but I am pretty sure that the package description explains what is going on here:

    It also creates large read-only file-based data structures that are mmapped into memory so that many processes may share the same data.

    Since it uses memory-mapped, file-based indexes, when you serialize the object and pass it to the workers, the underlying data is lost along the way.
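
    For illustration, the broadcast attempt the question describes would look roughly like this (a sketch reusing sc and the index t built earlier; it typically fails right at sc.broadcast with a pickling error, and even if serialization succeeded, the mmapped data would not travel with the object):

    annoy_broadcast = sc.broadcast(t)   # typically raises: can't pickle the Annoy extension object
    
    def find_neighbors(rows):
        ai = annoy_broadcast.value      # never yields a usable index
        return (ai.get_nns_by_vector(vector=x[1], n=5) for x in rows)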

    Try something like this instead:

    from pyspark import SparkFiles
    
    t.save("index.ann")
    sc.addFile("index.ann")   # addFile (not addPyFile) ships a plain data file to the workers
    
    def find_neighbors(rows):
        # re-create the index on the worker; must match the dimensionality
        # (and, in newer Annoy versions, the metric) used at build time
        t = AnnoyIndex(f)
        t.load(SparkFiles.get("index.ann"))   # mmap the local copy of the index
        return (t.get_nns_by_vector(vector=x[1], n=5) for x in rows)
    
    sparkvectors.mapPartitions(find_neighbors).first()
    ## [0, 13, 12, 6, 4]
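
    If you also need the distances, get_nns_by_vector accepts an include_distances flag; a sketch along the same lines (assuming the same f and saved index as above):

    def find_neighbors_with_distances(rows):
        t = AnnoyIndex(f)
        t.load(SparkFiles.get("index.ann"))
        # include_distances=True makes each result a ([ids], [distances]) pair
        return (t.get_nns_by_vector(x[1], 5, include_distances=True) for x in rows)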
    