I want to run a function in parallel with joblib and wait until all of the parallel jobs are done, as in the following example:
from math import sqrt
from joblib import Par
Below is a modification of nth's great answer. It adds a flag that switches the tqdm progress bar on or off, and lets you specify the total number of tasks ahead of time so that the progress bar fills in correctly.
from tqdm.auto import tqdm
from joblib import Parallel
class ProgressParallel(Parallel):
    """joblib ``Parallel`` subclass that reports progress through a tqdm bar.

    A tqdm bar is opened for the duration of each call. ``print_progress``
    then syncs the bar with the ``n_dispatched_tasks`` / ``n_completed_tasks``
    counters that this class reads from ``Parallel``.
    """

    def __init__(self, use_tqdm=True, total=None, *args, **kwargs):
        """Store the progress-bar options and forward everything else to ``Parallel``.

        Args:
            use_tqdm: If falsy, the bar is created with ``disable=True``, so
                nothing is displayed.
            total: Expected number of tasks. If ``None``, the bar's total is
                updated on every progress callback to the number of tasks
                dispatched so far.
            *args, **kwargs: Passed unchanged to ``Parallel.__init__``.

        Note:
            ``use_tqdm`` and ``total`` are the first two positional
            parameters. Pass ``Parallel`` options such as ``n_jobs`` by
            keyword: ``ProgressParallel(4)`` binds ``4`` to ``use_tqdm``,
            not to ``n_jobs``.
        """
        self._use_tqdm = use_tqdm
        self._total = total
        super().__init__(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        """Run the parallel job while a tqdm bar is open.

        The bar is stored on ``self._pbar`` so ``print_progress`` can update
        it, and it is closed when this method returns.
        NOTE(review): if the installed joblib returns a lazy generator here
        (e.g. ``return_as="generator"``), the bar closes before the results
        are consumed. Confirm before relying on the bar in that mode.
        """
        with tqdm(disable=not self._use_tqdm, total=self._total) as self._pbar:
            # super() keeps the dispatch cooperative with the MRO, matching
            # the super().__init__ call above.
            return super().__call__(*args, **kwargs)

    def print_progress(self):
        """Sync the tqdm bar with joblib's task counters and redraw it.

        This overrides ``Parallel.print_progress``, so the base class's own
        progress output is replaced by the bar (presumably this silences
        joblib's ``verbose`` messages; verify against the installed joblib).
        It expects ``self._pbar`` to exist, i.e. it must run during
        ``__call__``.
        """
        if self._total is None:
            # No fixed total was given: let the bar's total grow with the
            # number of tasks dispatched so far.
            self._pbar.total = self.n_dispatched_tasks
        self._pbar.n = self.n_completed_tasks
        self._pbar.refresh()