Tracking progress of joblib.Parallel execution

隐瞒了意图╮ 2020-12-24 12:08

Is there a simple way to track the overall progress of a joblib.Parallel execution?

I have a long-running execution composed of thousands of jobs, which I want to track.

  • 2020-12-24 12:18

    In Jupyter, plain tqdm starts a new line in the output every time it updates, so for Jupyter Notebook use tqdm_notebook instead:

    from joblib import Parallel, delayed
    # Newer tqdm releases deprecate tqdm_notebook in favor of tqdm.notebook.tqdm
    from tqdm import tqdm_notebook

    def myfun(x):
        return x**2

    results = Parallel(n_jobs=8)(delayed(myfun)(i) for i in tqdm_notebook(range(1000)))
    # Output:
    # 100% 1000/1000 [00:06<00:00, 143.70it/s]
    
  • 2020-12-24 12:22

    Why can't you simply use tqdm? The following worked for me:

    from joblib import Parallel, delayed
    from tqdm import tqdm

    def myfun(x):
        return x**2

    # Note the closing parenthesis for the Parallel(...)(...) call
    results = Parallel(n_jobs=8)(delayed(myfun)(i) for i in tqdm(range(1000)))
    # Output:
    # 100%|██████████| 1000/1000 [00:00<00:00, 10563.37it/s]
    
  • 2020-12-24 12:23

    The documentation you linked to states that Parallel has an optional progress meter. It's implemented by using the callback keyword argument provided by multiprocessing.Pool.apply_async:

    # This is inside a dispatch function
    self._lock.acquire()
    job = self._pool.apply_async(SafeFunction(func), args,
                kwargs, callback=CallBack(self.n_dispatched, self))
    self._jobs.append(job)
    self.n_dispatched += 1
    

    ...

    class CallBack(object):
        """ Callback used by parallel: it is used for progress reporting, and
            to add data to be processed
        """
        def __init__(self, index, parallel):
            self.parallel = parallel
            self.index = index
    
        def __call__(self, out):
            self.parallel.print_progress(self.index)
            if self.parallel._original_iterable:
                self.parallel.dispatch_next()
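    To see the callback mechanism in isolation, here is a minimal sketch using only the standard library. It assumes multiprocessing.pool.ThreadPool as a stand-in for the process pool joblib used: ThreadPool has the same apply_async(func, args, kwds, callback) signature and avoids pickling. The names square and on_done are hypothetical; the counter plays the role the question wants from CallBack.

```python
from multiprocessing.pool import ThreadPool

# Stand-in for joblib's pool: ThreadPool exposes the same apply_async API as Pool.
completed = 0

def on_done(result):
    # The pool runs callbacks in its single result-handler thread,
    # so this counter is never updated concurrently.
    global completed
    completed += 1
    print(f"done with {completed}")

def square(x):
    return x * x

with ThreadPool(processes=4) as pool:
    async_results = [pool.apply_async(square, (i,), callback=on_done)
                     for i in range(10)]
    # In CPython, a job's callback runs before its get() unblocks,
    # so once this loop finishes every callback has fired.
    results = [r.get() for r in async_results]

print(results)
```

    Unlike the excerpt above, the counter records how many jobs have finished, regardless of which ones.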
    

    And here's print_progress:

    def print_progress(self, index):
        elapsed_time = time.time() - self._start_time
    
        # This is heuristic code to print only 'verbose' times a messages
        # The challenge is that we may not know the queue length
        if self._original_iterable:
            if _verbosity_filter(index, self.verbose):
                return
            self._print('Done %3i jobs       | elapsed: %s',
                        (index + 1,
                         short_format_time(elapsed_time),
                        ))
        else:
            # We are finished dispatching
            queue_length = self.n_dispatched
            # We always display the first loop
            if not index == 0:
                # Display depending on the number of remaining items
                # A message as soon as we finish dispatching, cursor is 0
                cursor = (queue_length - index + 1
                          - self._pre_dispatch_amount)
                frequency = (queue_length // self.verbose) + 1
                is_last_item = (index + 1 == queue_length)
                if (is_last_item or cursor % frequency):
                    return
            remaining_time = (elapsed_time / (index + 1) *
                        (self.n_dispatched - index - 1.))
            self._print('Done %3i out of %3i | elapsed: %s remaining: %s',
                        (index + 1,
                         queue_length,
                         short_format_time(elapsed_time),
                         short_format_time(remaining_time),
                        ))
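    The remaining-time figure in print_progress is a linear extrapolation: the average time per finished job multiplied by the number of jobs still outstanding. A worked version of that arithmetic, with estimate_remaining as a hypothetical helper name:

```python
def estimate_remaining(elapsed, index, n_dispatched):
    """Same formula as the excerpt: elapsed / (index + 1) * (n_dispatched - index - 1)."""
    done = index + 1                  # jobs assumed finished
    per_job = elapsed / done          # average seconds per finished job
    return per_job * (n_dispatched - done)

# 5 jobs done (index 4) after 10 s, 20 jobs dispatched in total:
print(estimate_remaining(10.0, 4, 20))  # 2.0 s/job * 15 jobs = 30.0
```

    The estimate is only as good as index + 1 is as a count of finished jobs.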
    

    The way they implement this is kind of weird, to be honest: it seems to assume that tasks will always complete in the order they were started. The index passed to print_progress is just the value of self.n_dispatched at the time the job was dispatched, so the first job launched will always finish with an index of 0, even if, say, the third job finished first. It also means they don't actually keep track of the number of completed jobs, so there's no instance variable for you to monitor.
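    The ordering problem is easy to reproduce. This sketch again assumes ThreadPool as a stand-in for joblib's pool: each callback remembers its dispatch index the way the old CallBack did, and the first job dispatched is deliberately slow (the sleep durations are arbitrary).

```python
import time
from multiprocessing.pool import ThreadPool

finish_order = []  # dispatch indices, in the order their jobs completed

def job(delay):
    time.sleep(delay)
    return delay

def make_callback(index):
    # Like the old joblib CallBack, each callback remembers its dispatch index.
    def callback(result):
        finish_order.append(index)
    return callback

delays = [0.3, 0.0, 0.1]  # the first job dispatched is the slowest
with ThreadPool(processes=3) as pool:
    handles = [pool.apply_async(job, (d,), callback=make_callback(i))
               for i, d in enumerate(delays)]
    for h in handles:
        h.get()

print(finish_order)        # typically [1, 2, 0]: completion order != dispatch order
print(len(finish_order))   # counting callbacks gives the true number of finished jobs
```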

    I think your best bet is to write your own CallBack class and monkey patch it into joblib.parallel:

    from math import sqrt
    from collections import defaultdict
    from joblib import Parallel, delayed
    
    class CallBack(object):
        completed = defaultdict(int)
    
        def __init__(self, index, parallel):
            self.index = index
            self.parallel = parallel
    
        def __call__(self, out):  # receives the finished job's result
            CallBack.completed[self.parallel] += 1
            print("done with {}".format(CallBack.completed[self.parallel]))
            if self.parallel._original_iterable:
                self.parallel.dispatch_next()
    
    import joblib.parallel
    joblib.parallel.CallBack = CallBack
    
    if __name__ == "__main__":
        print(Parallel(n_jobs=2)(delayed(sqrt)(i**2) for i in range(10)))
    

    Output:

    done with 1
    done with 2
    done with 3
    done with 4
    done with 5
    done with 6
    done with 7
    done with 8
    done with 9
    done with 10
    [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
    

    That way, your callback, rather than the default one, gets called whenever a job completes.
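    Assigning to joblib.parallel.CallBack works because the dispatch code looks the name up in the module's globals each time it runs. A self-contained illustration with a throwaway module (fake_parallel and MyCallBack are hypothetical stand-ins, not joblib code):

```python
import types

# A throwaway module standing in for joblib.parallel (hypothetical contents).
fake_parallel = types.ModuleType("fake_parallel")
exec(
    """
class CallBack:
    def __call__(self, out):
        return "default"

def dispatch():
    # 'CallBack' is resolved in this module's globals on every call.
    return CallBack()(None)
""",
    fake_parallel.__dict__,
)

class MyCallBack:
    def __call__(self, out):
        return "patched"

before = fake_parallel.dispatch()
fake_parallel.CallBack = MyCallBack  # same move as joblib.parallel.CallBack = CallBack
after = fake_parallel.dispatch()
print(before, after)  # default patched
```

    One caveat: code that bound the class earlier with `from joblib.parallel import CallBack` keeps the old object, because rebinding the module attribute does not update names already copied elsewhere.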

  • 2020-12-24 12:26

    Expanding on dano's answer for newer versions of joblib, where the internal implementation changed in a couple of ways:

    from joblib import Parallel, delayed
    from collections import defaultdict

    # patch joblib progress callback
    class BatchCompletionCallBack(object):
        completed = defaultdict(int)

        def __init__(self, dispatch_timestamp, batch_size, parallel):
            self.batch_size = batch_size
            self.parallel = parallel

        def __call__(self, out):
            # The callback fires once per batch, so count tasks, not calls
            BatchCompletionCallBack.completed[self.parallel] += self.batch_size
            print("done with {}".format(BatchCompletionCallBack.completed[self.parallel]))
            if self.parallel._original_iterator is not None:
                self.parallel.dispatch_next()

    import joblib.parallel
    joblib.parallel.BatchCompletionCallBack = BatchCompletionCallBack
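    Why the batch size matters: in these joblib versions a callback fires once per batch, and a batch can hold several tasks (by default joblib picks the batch size automatically). The sketch below simulates that with ThreadPool and a fixed batch_size of 3, both assumptions for illustration; counting callbacks yields the number of batches, while summing batch sizes yields the number of tasks.

```python
from multiprocessing.pool import ThreadPool

tasks = list(range(10))
batch_size = 3  # fixed for the demo; joblib's default is batch_size='auto'
batches = [tasks[i:i + batch_size] for i in range(0, len(tasks), batch_size)]

n_callbacks = 0   # what counting calls measures: batches
n_tasks_done = 0  # what a per-task progress bar needs

def run_batch(batch):
    return [x * x for x in batch]

def make_callback(size):
    def callback(out):
        global n_callbacks, n_tasks_done
        n_callbacks += 1
        n_tasks_done += size
    return callback

with ThreadPool(processes=2) as pool:
    handles = [pool.apply_async(run_batch, (b,), callback=make_callback(len(b)))
               for b in batches]
    results = [y for h in handles for y in h.get()]

print(n_callbacks, n_tasks_done)  # 4 10
```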
    
  • 2020-12-24 12:27

    Text progress bar

    One more variant for those who want a text progress bar without additional modules like tqdm. It works with joblib 0.11 and Python 3.5.2 on Linux (as of 16.04.2018) and shows progress upon subtask completion.

    Redefine the native class:

    import time
    import joblib.parallel

    class BatchCompletionCallBack(object):
        def __init__(self, dispatch_timestamp, batch_size, parallel):
            self.dispatch_timestamp = dispatch_timestamp
            self.batch_size = batch_size
            self.parallel = parallel

        def __call__(self, out):
            self.parallel.n_completed_tasks += self.batch_size
            this_batch_duration = time.time() - self.dispatch_timestamp

            self.parallel._backend.batch_completed(self.batch_size,
                                                   this_batch_duration)
            self.parallel.print_progress()
            # Added code - start
            # total_n_jobs is a module-level global defined before Parallel runs;
            # reading it needs no `global` statement
            progress = self.parallel.n_completed_tasks / total_n_jobs
            print(
                "\rProgress: [{0:50s}] {1:.1f}%".format('#' * int(progress * 50), progress * 100),
                end="", flush=True)
            if self.parallel.n_completed_tasks == total_n_jobs:
                print()  # end the progress line once all jobs are done
            # Added code - end
            if self.parallel._original_iterator is not None:
                self.parallel.dispatch_next()

    joblib.parallel.BatchCompletionCallBack = BatchCompletionCallBack
    

    Before calling Parallel, define a global total_n_jobs holding the total number of jobs:

    total_n_jobs = 10
    

    This will result in something like this:

    Progress: [########################################          ] 80.0%
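    The formatting part of this answer can be pulled out into a small pure function, which also removes the need for a global: the total is passed in as an argument. render_bar is a hypothetical name; the output format matches the answer's.

```python
import sys

def render_bar(done, total, width=50):
    """Return '\\rProgress: [###   ] 80.0%'-style text for done/total."""
    filled = done * width // total  # integer math avoids float rounding surprises
    return "\rProgress: [{0:{1}s}] {2:.1f}%".format("#" * filled, width, 100 * done / total)

# Redraw the same line as work progresses, then finish it with a newline.
for done in range(0, 11, 2):
    sys.stdout.write(render_bar(done, 10))
    sys.stdout.flush()
sys.stdout.write("\n")
```

    Inside the patched callback, the print call would become a call to render_bar(self.parallel.n_completed_tasks, total) with the total supplied by whatever code constructs the bar.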
    
  • 2020-12-24 12:28

    TLDR solution:

    Works with joblib 0.14.0 and tqdm 4.46.0 using Python 3.5. Credit to frenzykryger for the contextlib suggestion, and to dano and Connor for the monkey-patching idea.

    import contextlib
    import joblib
    from tqdm import tqdm
    from joblib import Parallel, delayed
    
    @contextlib.contextmanager
    def tqdm_joblib(tqdm_object):
        """Context manager to patch joblib to report into tqdm progress bar given as argument"""
    
        def tqdm_print_progress(self):
            if self.n_completed_tasks > tqdm_object.n:
                n_completed = self.n_completed_tasks - tqdm_object.n
                tqdm_object.update(n=n_completed)
    
        original_print_progress = joblib.parallel.Parallel.print_progress
        joblib.parallel.Parallel.print_progress = tqdm_print_progress
    
        try:
            yield tqdm_object
        finally:
            joblib.parallel.Parallel.print_progress = original_print_progress
            tqdm_object.close()
    

    You can use it the same way as described by frenzykryger:

    import time
    def some_method(wait_time):
        time.sleep(wait_time)
    
    with tqdm_joblib(tqdm(desc="My method", total=10)) as progress_bar:
        Parallel(n_jobs=2)(delayed(some_method)(0.2) for i in range(10))
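    The core of this answer is a patch-and-restore context manager. The same pattern on a toy class (Runner is a hypothetical stand-in for Parallel, not joblib's API) shows both halves: the method is looked up on the class at call time, so patching the class takes effect for every instance, and the finally clause restores the original even if the body raises.

```python
import contextlib

class Runner:
    """Hypothetical stand-in for joblib.parallel.Parallel."""
    def __init__(self):
        self.n_completed_tasks = 0

    def print_progress(self):
        pass  # the default does nothing here

    def run(self, n):
        for _ in range(n):
            self.n_completed_tasks += 1
            self.print_progress()  # resolved through the class at call time

@contextlib.contextmanager
def patched_progress(seen):
    def record(self):
        seen.append(self.n_completed_tasks)
    original = Runner.print_progress
    Runner.print_progress = record
    try:
        yield
    finally:
        Runner.print_progress = original  # restored even if the body raises

seen = []
with patched_progress(seen):
    Runner().run(3)
print(seen)  # [1, 2, 3]

original_fn = Runner.print_progress
try:
    with patched_progress(seen):
        raise RuntimeError("simulated job failure")
except RuntimeError:
    pass
print(Runner.print_progress is original_fn)  # True
```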
    

    Longer explanation:

    The solution by Jon is simple to implement, but it only measures dispatched tasks. If a task takes a long time, the bar will sit at 100% while waiting for the last dispatched tasks to finish executing.
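    To see why an input-side bar runs ahead, compare how many inputs a pool has pulled with how many results exist. This stdlib sketch uses concurrent.futures.ThreadPoolExecutor.map as a stand-in for joblib; Executor.map collects its input iterable immediately, so the counter reaches 10 before any result is read. joblib is less extreme, since by default it only pre-dispatches about 2 * n_jobs tasks ahead (the pre_dispatch parameter), but a bar wrapped around the input still counts dispatch, not completion. The counting generator is a hypothetical helper that mimics what tqdm(range(...)) counts.

```python
import time
from concurrent.futures import ThreadPoolExecutor

consumed = 0  # how many inputs the pool has pulled

def counting(iterable):
    global consumed
    for item in iterable:
        consumed += 1
        yield item

def slow_square(x):
    time.sleep(0.05)
    return x * x

with ThreadPoolExecutor(max_workers=2) as pool:
    results_iter = pool.map(slow_square, counting(range(10)))
    consumed_at_submit = consumed  # every input already pulled...
    results = list(results_iter)   # ...while the work is still running

print(consumed_at_submit, results)
```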

    The context manager approach by frenzykryger, which builds on dano's and Connor's answers, is better, but BatchCompletionCallBack can also be called with ImmediateResult before the task completes (see Intermediate results from joblib). That can push the count past 100%.

    Instead of monkey patching BatchCompletionCallBack, we can just patch the print_progress method of Parallel, which BatchCompletionCallBack already calls anyway. If verbose is set (e.g. Parallel(n_jobs=2, verbose=100)), print_progress prints out the completed tasks, though not as nicely as tqdm. Looking at the code, print_progress is an instance method of Parallel, so it has access to self.n_completed_tasks, which holds the number we want. All we have to do is compare it with the current count of the tqdm bar and update only if there is a difference.
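    The "update only by the difference" step is what keeps the bar from overshooting when print_progress fires more than once for the same count. A sketch with FakeBar, a hypothetical object exposing only the n attribute and update(n) method that the patch uses from tqdm:

```python
class FakeBar:
    """Hypothetical minimal stand-in for a tqdm bar: only .n and .update(n)."""
    def __init__(self, total):
        self.total = total
        self.n = 0

    def update(self, n=1):
        self.n += n

def sync(bar, n_completed_tasks):
    # Same logic as tqdm_print_progress: advance only by the difference,
    # so repeated calls with an unchanged count are harmless.
    if n_completed_tasks > bar.n:
        bar.update(n=n_completed_tasks - bar.n)

bar = FakeBar(total=10)
for count in [1, 1, 3, 3, 3, 7, 10, 10]:  # print_progress may fire repeatedly
    sync(bar, count)
print(bar.n)  # 10
```

    Blindly calling update(1) on every invocation would have ended at 8 here, not 10, and could just as easily overshoot when calls outnumber completed tasks.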

    This was tested with joblib 0.14.0 and tqdm 4.46.0 using Python 3.5.
