Is there a simple way to track the overall progress of a joblib.Parallel execution?
I have a long-running execution composed of thousands of jobs, which I want to track.
Going one step further than dano's and Connor's answers, you can wrap the whole thing in a context manager:
import contextlib
import joblib
from tqdm import tqdm
from joblib import Parallel, delayed
@contextlib.contextmanager
def tqdm_joblib(tqdm_object):
    """Context manager to patch joblib to report into tqdm progress bar given as argument.

    While the ``with`` body runs, ``joblib.parallel.BatchCompletionCallBack``
    is replaced by a subclass that calls ``tqdm_object.update`` with the
    callback's ``batch_size`` every time the callback is invoked. On exit,
    including when the body raises, the original class is put back and
    ``tqdm_object.close()`` is called.

    Args:
        tqdm_object: Progress bar to drive. It must provide ``update(n=...)``
            and ``close()``.

    Yields:
        The same ``tqdm_object`` that was passed in.
    """

    class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
        # No __init__ override is needed: the inherited constructor is used
        # as-is, so the callback accepts exactly what the base class accepts.

        def __call__(self, *args, **kwargs):
            # Advance the bar first, then delegate to the base callback.
            # ``batch_size`` is an attribute supplied by the joblib base class.
            tqdm_object.update(n=self.batch_size)
            return super().__call__(*args, **kwargs)

    # Remember the class currently installed so it can be restored afterwards.
    old_batch_callback = joblib.parallel.BatchCompletionCallBack
    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
    try:
        yield tqdm_object
    finally:
        # Always undo the monkey patch and close the bar, even on error.
        joblib.parallel.BatchCompletionCallBack = old_batch_callback
        tqdm_object.close()
Then you can use it like this, without leaving monkey-patched code behind once you're done:
from math import sqrt  # sqrt is called below but was never imported

# total=10 matches the ten jobs generated by range(10).
with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar:
    Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10))
I think this is neat, and it looks similar to tqdm's pandas integration.
Here's another answer to your question with the following syntax:
# Usage sketch: ParallelExecutor and func are defined elsewhere (see the linked answer).
# One executor object is created once and then called again for every run.
aprun = ParallelExecutor(n_jobs=5)
# total matches the number of jobs the generator yields: 5 x 5 = 25.
a1 = aprun(total=25)(delayed(func)(i ** 2 + j) for i in range(5) for j in range(5))
# 4 x 4 = 16 jobs. Note that a2 is rebound by each call below, so only the last result is kept.
a2 = aprun(total=16)(delayed(func)(i ** 2 + j) for i in range(4) for j in range(4))
# NOTE(review): bar='txt' presumably selects a text-style progress bar -- confirm against ParallelExecutor.
a2 = aprun(bar='txt')(delayed(func)(i ** 2 + j) for i in range(4) for j in range(4))
# NOTE(review): bar=None presumably disables the progress bar -- confirm against ParallelExecutor.
a2 = aprun(bar=None)(delayed(func)(i ** 2 + j) for i in range(4) for j in range(4))
https://stackoverflow.com/a/40415477/232371