How can we use tqdm in a parallel execution with joblib?

后端 未结 6 678
轮回少年
轮回少年 2021-02-05 01:25

I want to use joblib to run a function in parallel and wait until all parallel workers have finished, as in this example:

from math import sqrt
from joblib import Par         


        
6条回答
  •  臣服心动
    2021-02-05 02:08

    Here's a possible workaround:

    def func(x):
        """Return *x* unchanged after sleeping for a random whole 1-10 seconds.

        Stands in for a slow task in the demo runs at the end of the example.
        """
        delay = random.randint(1, 10)
        time.sleep(delay)
        return x
    
    def text_progessbar(seq, total=None):
        """Yield every item of *seq*, printing one plain-text progress line per item.

        Used as the dependency-free 'txt' bar type alongside 'tqdm'.

        Args:
            seq: Any iterable (list, generator, ...); it is not required to be
                an iterator.
            total: Optional expected number of items, printed as ``of <total>``.

        Yields:
            The items of *seq*, unchanged and in order.
        """
        # Built once; '%d' is a valid %-format code ('%n' is not and raises).
        total_str = 'of %d' % total if total else ''
        start = time.monotonic()
        # A for-loop stops cleanly when *seq* is exhausted; calling next()
        # inside a generator would turn StopIteration into RuntimeError
        # (PEP 479).
        for step, item in enumerate(seq, start=1):
            yield item
            # Measured once the consumer has come back for more, so *step*
            # items have been handed out at this point.
            time_diff = time.monotonic() - start
            # Items per second, matching the 'iter/sec' label; guard against a
            # zero elapsed time on very fast consumers.
            avg_speed = step / time_diff if time_diff > 0 else float('inf')
            print('step', step, '%.2f' % time_diff,
                  'avg: %.2f iter/sec' % avg_speed, total_str)
    
    def _passthrough_bar(_args):
        """Bar factory for 'no progress bar': ignore the options, return ``iter``."""
        return iter

    # Bar-type name -> factory. Each factory takes a dict of bar options and
    # returns a one-argument wrapper that is applied to the task iterable.
    all_bar_funcs = {
        'tqdm': lambda args: (lambda x: tqdm(x, **args)),
        'txt': lambda args: (lambda x: text_progessbar(x, **args)),
        # Keys for str(False) / str(None), so bar=False and bar=None disable
        # progress output.
        'False': _passthrough_bar,
        'None': _passthrough_bar,
    }
    
    def ParallelExecutor(use_bar='tqdm', **joblib_args):
        """Build a factory for progress-bar-wrapped ``Parallel`` runs.

        Args:
            use_bar: Default bar type for runs made by the returned factory.
                After ``str()`` it must be a key of ``all_bar_funcs``
                ('tqdm', 'txt', 'False' or 'None').
            **joblib_args: Forwarded unchanged to ``Parallel(...)``.

        Returns:
            ``aprun(bar=use_bar, **tq_args)``. It returns a callable that takes
            an iterable of ``delayed(...)`` tasks and returns
            ``Parallel(**joblib_args)`` applied to the bar-wrapped iterable.

        NOTE(review): the bar wraps the task iterator that ``Parallel``
        consumes. It therefore presumably advances as tasks are dispatched,
        not as they finish. Confirm against joblib's ``pre_dispatch``
        behaviour.
        """
        def aprun(bar=use_bar, **tq_args):
            """Return a runner using bar type *bar*; *tq_args* go to the bar.

            Raises:
                ValueError: if ``str(bar)`` is not a supported bar type.
            """
            # str() lets bar=None / bar=False match the 'None' / 'False' keys.
            # Looked up here so an unsupported bar type fails immediately.
            try:
                make_bar = all_bar_funcs[str(bar)]
            except KeyError:
                raise ValueError(
                    "Value %s not supported as bar type" % (bar,)) from None
            bar_func = make_bar(tq_args)

            def tmp(op_iter):
                return Parallel(**joblib_args)(bar_func(op_iter))
            return tmp
        return aprun
    
    # Shared factory: every run below calls Parallel(n_jobs=5).
    aprun = ParallelExecutor(n_jobs=5)

    # tqdm bar; each total matches the number of tasks generated
    # (5 x 5 = 25 and 4 x 4 = 16).
    a1 = aprun(total=25)(delayed(func)(i ** 2 + j) for i in range(5) for j in range(5))
    a2 = aprun(total=16)(delayed(func)(i ** 2 + j) for i in range(4) for j in range(4))
    # Plain-text bar (text_progessbar) instead of tqdm. Stored under its own
    # name so it does not overwrite a2.
    a3 = aprun(bar='txt')(delayed(func)(i ** 2 + j) for i in range(4) for j in range(4))
    # bar=None maps to the 'None' entry: no progress reporting.
    a4 = aprun(bar=None)(delayed(func)(i ** 2 + j) for i in range(4) for j in range(4))
    

提交回复
热议问题