How can I load data that can't be pickled in each Spark executor?


Question


I'm using the NoAho library which is written in Cython. Its internal trie cannot be pickled: if I load it on the master node, I never get matches for operations that execute in workers.

Since I would like to use the same trie in each Spark executor, I found a way to load the trie lazily, inspired by this spaCy on Spark issue.

global trie

def get_match(text):
    # 1. Load trie if needed
    global trie
    try:
        trie
    except NameError:
        from noaho import NoAho

        trie = NoAho()
        trie.add(key_text='ms windows', payload='Windows 2000')
        trie.add(key_text='ms windows 2000', payload='Windows 2000')
        trie.add(key_text='windows 2k', payload='Windows 2000')
        ...

    # 2. Find an actual match to get the payload back
    return trie.findall_long(text)

While this works, all the .add() calls are performed for every Spark job, which takes around one minute. Since I'm not sure "Spark job" is the correct term, I'll be more explicit: I use Spark in a Jupyter notebook, and every time I run a cell that uses the get_match() function, the trie is never cached and takes about a minute to reload, which dominates the run time.

Is there anything I can do to ensure the trie gets cached? Or is there a better solution to my problem?


Answer 1:


One thing you can try is to use a singleton module to load and initialize the trie. Basically all you need is a separate module with something like this:

  • trie_loader.py

    from noaho import NoAho
    
    def load():
        trie = NoAho()
        trie.add('ms windows', 'Windows 2000')
        trie.add('ms windows 2000', 'Windows 2000')
        trie.add('windows 2k', 'Windows 2000')
        return trie
    
    trie = load()
    

and distribute this using standard Spark tools:

sc.addPyFile("trie_loader.py")
import trie_loader

rdd = sc.parallelize(["ms windows", "Debian GNU/Linux"])
rdd.map(lambda x: (x, trie_loader.trie.find_long(x))).collect()
## [('ms windows', (0, 10, 'Windows 2000')),
##  ('Debian GNU/Linux', (None, None, None))]

This should load the required data every time a Python executor process is started, instead of loading it when the data is accessed. I am not sure if it will help here, but it is worth a try.
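How much this saves depends on whether Python worker processes are reused between tasks. As a minimal sketch, assuming a standard PySpark setup (the app name below is just a placeholder), you can make sure worker reuse is enabled so module-level state such as trie_loader.trie survives across tasks:

from pyspark import SparkConf, SparkContext

# Keep Python worker processes alive between tasks so that module-level
# state (like trie_loader.trie) is built once per worker, not once per task.
# spark.python.worker.reuse already defaults to "true" in recent Spark versions.
conf = (SparkConf()
        .setAppName("noaho-trie")  # placeholder app name
        .set("spark.python.worker.reuse", "true"))
sc = SparkContext(conf=conf)

sc.addPyFile("trie_loader.py")

With worker reuse on, the load() call in trie_loader runs once per executor process rather than once per notebook cell.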



Source: https://stackoverflow.com/questions/35500196/how-can-i-load-data-that-cant-be-pickled-in-each-spark-executor
