Executors: How to synchronously wait until all tasks have finished if tasks are created recursively?

傲寒 2020-12-31 09:37

My question is strongly related to this one here. As was posted there, I would like the main thread to wait until the work queue is empty and all tasks have finished. The pr

9 answers
  • 2020-12-31 10:11

    Thanks a lot for all your suggestions!

    In the end I opted for something that I believe to be reasonably simple. I found out that CountDownLatch is almost what I need: it blocks until the counter reaches 0. The only problem is that it can only count down, not up, so it does not work in my dynamic setting, where tasks can submit new tasks. I therefore implemented a new class, CountLatch, that offers the additional functionality (see below). I then use this class as follows.

    Main thread calls latch.awaitZero(), blocking until latch reaches 0.

    Any thread, before calling executor.execute(..) calls latch.increment().

    Any task, just before completing, calls latch.decrement().

    When the last task terminates, the counter will reach 0 and thus release the main thread.

    Further suggestions and feedback are most welcome!

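    import java.util.concurrent.TimeUnit;
    import java.util.concurrent.locks.AbstractQueuedSynchronizer;

    public class CountLatch {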
    
    @SuppressWarnings("serial")
    private static final class Sync extends AbstractQueuedSynchronizer {
    
        Sync(int count) {
            setState(count);
        }
    
        int getCount() {
            return getState();
        }
    
        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }
    
        protected int acquireNonBlocking(int acquires) {
            // increment count
            for (;;) {
                int c = getState();
                int nextc = c + 1;
                if (compareAndSetState(c, nextc))
                    return 1;
            }
        }
    
        protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c - 1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }
    
    private final Sync sync;
    
    public CountLatch(int count) {
        this.sync = new Sync(count);
    }
    
    public void awaitZero() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }
    
    public boolean awaitZero(long timeout, TimeUnit unit) throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }
    
    public void increment() {
        sync.acquireNonBlocking(1);
    }
    
    public void decrement() {
        sync.releaseShared(1);
    }
    
    public String toString() {
        return super.toString() + "[Count = " + sync.getCount() + "]";
    }
    
    }
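
    The usage protocol described above (increment before execute, decrement just before a task completes, awaitZero in the main thread) can be sketched as follows. So that the example compiles on its own, it uses SimpleCountLatch, a hypothetical monitor-based stand-in for CountLatch with the same three operations. RecursiveTaskDemo, runTree and spawn are likewise names made up for illustration.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

/** Hypothetical monitor-based stand-in for CountLatch (illustration only). */
final class SimpleCountLatch {
    private int count;

    synchronized void increment() {
        count++;
    }

    synchronized void decrement() {
        // Wake all waiters on the transition to zero
        if (count > 0 && --count == 0) {
            notifyAll();
        }
    }

    synchronized void awaitZero() throws InterruptedException {
        while (count > 0) {
            wait();
        }
    }
}

final class RecursiveTaskDemo {

    /** Runs a binary tree of tasks of the given depth; returns how many tasks ran. */
    static int runTree(int depth) throws InterruptedException {
        ExecutorService executor = Executors.newFixedThreadPool(4);
        SimpleCountLatch latch = new SimpleCountLatch();
        AtomicInteger executed = new AtomicInteger();
        try {
            spawn(executor, latch, executed, depth);
            latch.awaitZero(); // main thread blocks until every task has finished
            return executed.get();
        } finally {
            executor.shutdown();
        }
    }

    private static void spawn(ExecutorService executor, SimpleCountLatch latch,
                              AtomicInteger executed, int depth) {
        latch.increment(); // count the task before handing it to the executor
        try {
            executor.execute(() -> {
                try {
                    executed.incrementAndGet();
                    if (depth > 0) {
                        // Children are counted before this task decrements,
                        // so the counter cannot reach zero prematurely
                        spawn(executor, latch, executed, depth - 1);
                        spawn(executor, latch, executed, depth - 1);
                    }
                } finally {
                    latch.decrement(); // last action of every task
                }
            });
        } catch (RuntimeException e) {
            latch.decrement(); // rejected task will never run
            throw e;
        }
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(runTree(3)); // 1 + 2 + 4 + 8 tasks, prints 15
    }
}
```

    The important ordering is that a task increments for its children before it decrements for itself. Otherwise the counter could briefly hit zero while work is still about to be submitted.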
    

    Note that the increment()/decrement() calls can be encapsulated into a customized Executor subclass as was suggested, for instance, by Sami Korhonen, or with beforeExecute and afterExecute as was suggested by impl. See here:

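    import java.util.concurrent.BlockingQueue;
    import java.util.concurrent.ThreadPoolExecutor;
    import java.util.concurrent.TimeUnit;

    public class CountingThreadPoolExecutor extends ThreadPoolExecutor {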
    
    protected final CountLatch numRunningTasks = new CountLatch(0);
    
    public CountingThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit,
            BlockingQueue<Runnable> workQueue) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue);
    }
    
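    @Override
    public void execute(Runnable command) {
        numRunningTasks.increment();
        try {
            super.execute(command);
        } catch (RuntimeException | Error e) {
            // The task was rejected and will never run, so undo the increment
            numRunningTasks.decrement();
            throw e;
        }
    }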
    
    @Override
    protected void afterExecute(Runnable r, Throwable t) {
        numRunningTasks.decrement();
        super.afterExecute(r, t);
    }
    
    /**
     * Awaits the completion of all spawned tasks.
     */
    public void awaitCompletion() throws InterruptedException {
        numRunningTasks.awaitZero();
    }
    
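    /**
     * Awaits the completion of all spawned tasks, waiting at most the given time.
     *
     * @return true if all tasks completed, false if the timeout elapsed first
     */
    public boolean awaitCompletion(long timeout, TimeUnit unit) throws InterruptedException {
        return numRunningTasks.awaitZero(timeout, unit);
    }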
    
    }
    
  • 2020-12-31 10:11

    One of the options suggested in the answers you link to is to use a CompletionService.

    You could replace the busy waiting in your main thread with:

    while (true) {
        Future<?> f = completionService.take(); //blocks until task completes
        if (executor.getQueue().isEmpty()
             && numTasks.longValue() == executor.getCompletedTaskCount()) break;
    }
    

    Note that getCompletedTaskCount returns only an approximation, because task and thread states can change while it is being computed, so you might need to find a better exit condition.
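
    One possible exact exit condition, sketched here under assumptions, is to count submissions yourself and compare that count with the number of futures taken from the CompletionService. CompletionCountDemo, runTree, submit and the submitted counter are hypothetical names. The sketch assumes every task is submitted through the same helper and that no submission is rejected.

```java
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

final class CompletionCountDemo {

    /** Runs a binary tree of tasks of the given depth; returns how many completed. */
    static int runTree(int depth) throws InterruptedException, ExecutionException {
        ExecutorService executor = Executors.newFixedThreadPool(4);
        CompletionService<Integer> completionService =
                new ExecutorCompletionService<>(executor);
        AtomicInteger submitted = new AtomicInteger();
        try {
            submit(completionService, submitted, depth);
            int completed = 0;
            do {
                // take() blocks until some task completes; get() rethrows its failure
                completionService.take().get();
                completed++;
                // A task counts its children before it completes, so once the two
                // numbers match no task is running and none can be submitted anymore
            } while (completed < submitted.get());
            return completed;
        } finally {
            executor.shutdown();
        }
    }

    private static void submit(CompletionService<Integer> completionService,
                               AtomicInteger submitted, int depth) {
        submitted.incrementAndGet(); // count before submitting
        completionService.submit(() -> {
            if (depth > 0) {
                submit(completionService, submitted, depth - 1);
                submit(completionService, submitted, depth - 1);
            }
            return depth;
        });
    }

    public static void main(String[] args) throws Exception {
        System.out.println(runTree(3)); // 1 + 2 + 4 + 8 tasks, prints 15
    }
}
```

    Unlike getCompletedTaskCount, both numbers here are exact, and the main thread still blocks in take() rather than spinning.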

  • 2020-12-31 10:15

    This was actually a rather interesting problem to solve. I must warn you that I have not fully tested the code.

    The idea is simply to track task execution:

    • if a task is successfully queued, the counter is incremented by one
    • if a task is cancelled before it has been executed, the counter is decremented by one
    • if a task has been executed, the counter is decremented by one

    When shutdown is called while tasks are still pending, the delegate does not call shutdown on the actual ExecutorService. It keeps accepting new tasks until the pending task count reaches zero, and only then calls shutdown on the actual ExecutorService.

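    import java.util.Collection;
    import java.util.List;
    import java.util.concurrent.*;
    import java.util.concurrent.atomic.AtomicInteger;
    import java.util.concurrent.locks.*;

    public class ResilientExecutorServiceDelegate implements ExecutorService {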
        private final ExecutorService executorService;
        private final AtomicInteger pendingTasks;
        private final Lock readLock;
        private final Lock writeLock;
        // volatile because isShutdown() reads this flag without holding a lock
        private volatile boolean isShutdown;
    
        public ResilientExecutorServiceDelegate(ExecutorService executorService) {
            ReadWriteLock readWriteLock = new ReentrantReadWriteLock();
            this.pendingTasks = new AtomicInteger();
            this.readLock = readWriteLock.readLock();
            this.writeLock = readWriteLock.writeLock();
            this.executorService = executorService;
            this.isShutdown = false;
        }
    
        private <T> T addTask(Callable<T> task) {
            T result;
            boolean success = false;
            // Increment pending tasks counter
            incrementPendingTaskCount();
            try {
                // Call service
                result = task.call();
                success = true;
            } catch (RuntimeException exception) {
                throw exception;
            } catch (Exception exception) {
                throw new RejectedExecutionException(exception);
            } finally {
                if (!success) {
                    // Decrement pending tasks counter
                    decrementPendingTaskCount();
                }
            }
            return result;
        }
    
        private void incrementPendingTaskCount() {
            pendingTasks.incrementAndGet();
        }
    
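        private void decrementPendingTaskCount() {
            readLock.lock();
            try {
                if (pendingTasks.decrementAndGet() == 0 && isShutdown) {
                    try {
                        // Last pending task finished after shutdown() was requested
                        executorService.shutdown();
                    } catch (Throwable throwable) {
                        // Best effort: a failed shutdown must not fail the task
                    }
                }
            } finally {
                // Always release the lock, even if something above throws
                readLock.unlock();
            }
        }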
    
        @Override
        public void execute(final Runnable task) {
            // Add task
            addTask(new Callable<Object>() {
                @Override
                public Object call() {
                    executorService.execute(new Runnable() {
                        @Override
                        public void run() {
                            try {
                                task.run();
                            } finally {
                                decrementPendingTaskCount();
                            }
                        }
                    });
                    return null;
                }
            });
        }
    
        @Override
        public boolean awaitTermination(long timeout, TimeUnit unit)
                throws InterruptedException {
            // Call service
            return executorService.awaitTermination(timeout, unit);
        }
    
        @Override
        public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks)
                throws InterruptedException {
            // It's ok to increment by just one
            incrementPendingTaskCount();
            try {
                return executorService.invokeAll(tasks);
            } finally {
                decrementPendingTaskCount();
            }
        }
    
        @Override
        public <T> List<Future<T>> invokeAll(
                Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit)
                throws InterruptedException {
            // It's ok to increment by just one
            incrementPendingTaskCount();
            try {
                return executorService.invokeAll(tasks, timeout, unit);
            } finally {
                decrementPendingTaskCount();
            }
        }
    
        @Override
        public <T> T invokeAny(Collection<? extends Callable<T>> tasks)
                throws InterruptedException, ExecutionException {
            // It's ok to increment by just one
            incrementPendingTaskCount();
            try {
                return executorService.invokeAny(tasks);
            } finally {
                decrementPendingTaskCount();
            }
        }
    
        @Override
        public <T> T invokeAny(Collection<? extends Callable<T>> tasks,
                long timeout, TimeUnit unit) throws InterruptedException,
                ExecutionException, TimeoutException {
            incrementPendingTaskCount();
            try {
                return executorService.invokeAny(tasks, timeout, unit);
            } finally {
                decrementPendingTaskCount();
            }
        }
    
        @Override
        public boolean isShutdown() {
            return isShutdown;
        }
    
        @Override
        public boolean isTerminated() {
            return executorService.isTerminated();
        }
    
        @Override
        public void shutdown() {
            // Lock write lock
            writeLock.lock();
            // Set as shutdown
            isShutdown = true;
            try {
                if (pendingTasks.get() == 0) {
                    // Real shutdown
                    executorService.shutdown();
                }
            } finally {
                // Unlock write lock
                writeLock.unlock();
            }
        }
    
        @Override
        public List<Runnable> shutdownNow() {
            // Lock write lock
            writeLock.lock();
            // Set as shutdown
            isShutdown = true;
            // Unlock write lock
            writeLock.unlock();
    
            return executorService.shutdownNow();
        }
    
        @Override
        public <T> Future<T> submit(final Callable<T> task) {
            // Create execution status
            final FutureExecutionStatus futureExecutionStatus = new FutureExecutionStatus();
            // Add task
            return addTask(new Callable<Future<T>>() {
                @Override
                public Future<T> call() {
                    return new FutureDelegate<T>(
                            executorService.submit(new Callable<T>() {
                                @Override
                                public T call() throws Exception {
                                    try {
                                        // Mark as executed
                                        futureExecutionStatus.setExecuted();
                                        // Run the actual task
                                        return task.call();
                                    } finally {
                                        decrementPendingTaskCount();
                                    }
                                }
                            }), futureExecutionStatus);
                }
            });
        }
    
        @Override
        public Future<?> submit(final Runnable task) {
            // Create execution status
            final FutureExecutionStatus futureExecutionStatus = new FutureExecutionStatus();
            // Add task
            return addTask(new Callable<Future<?>>() {
                @Override
                @SuppressWarnings("unchecked")
                public Future<?> call() {
                    return new FutureDelegate<Object>(
                            (Future<Object>) executorService.submit(new Runnable() {
                                @Override
                                public void run() {
                                    try {
                                        // Mark as executed
                                        futureExecutionStatus.setExecuted();
                                        // Run the actual task
                                        task.run();
                                    } finally {
                                        decrementPendingTaskCount();
                                    }
                                }
                            }), futureExecutionStatus);
                }
            });
        }
    
        @Override
        public <T> Future<T> submit(final Runnable task, final T result) {
            // Create execution status
            final FutureExecutionStatus futureExecutionStatus = new FutureExecutionStatus();
            // Add task
            return addTask(new Callable<Future<T>>() {
                @Override
                public Future<T> call() {
                    return new FutureDelegate<T>(executorService.submit(
                            new Runnable() {
                                @Override
                                public void run() {
                                    try {
                                        // Mark as executed
                                        futureExecutionStatus.setExecuted();
                                        // Run the actual task
                                        task.run();
                                    } finally {
                                        decrementPendingTaskCount();
                                    }
                                }
                            }, result), futureExecutionStatus);
                }
            });
        }
    
        private class FutureExecutionStatus {
            private volatile boolean executed;
    
            public FutureExecutionStatus() {
                executed = false;
            }
    
            public void setExecuted() {
                executed = true;
            }
    
            public boolean isExecuted() {
                return executed;
            }
        }
    
        private class FutureDelegate<T> implements Future<T> {
            private Future<T> future;
            private FutureExecutionStatus executionStatus;
    
            public FutureDelegate(Future<T> future,
                    FutureExecutionStatus executionStatus) {
                this.future = future;
                this.executionStatus = executionStatus;
            }
    
            @Override
            public boolean cancel(boolean mayInterruptIfRunning) {
                boolean cancelled = future.cancel(mayInterruptIfRunning);
                if (cancelled) {
                    readLock.lock();
                    try {
                        // A task cancelled before it started never reaches its
                        // own finally block, so release its count here
                        if (!executionStatus.isExecuted()) {
                            decrementPendingTaskCount();
                        }
                    } finally {
                        readLock.unlock();
                    }
                }
                return cancelled;
            }
    
            @Override
            public T get() throws InterruptedException, ExecutionException {
                return future.get();
            }
    
            @Override
            public T get(long timeout, TimeUnit unit) throws InterruptedException,
                    ExecutionException, TimeoutException {
                return future.get(timeout, unit);
            }
    
            @Override
            public boolean isCancelled() {
                return future.isCancelled();
            }
    
            @Override
            public boolean isDone() {
                return future.isDone();
            }
        }
    }
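
    Reduced to execute() and shutdown() only, the deferred-shutdown idea can be sketched as follows. This is a simplified illustration rather than a replacement for the class above: DeferredShutdownExecutor, DeferredShutdownDemo, runTree and spawn are hypothetical names, a plain monitor stands in for the read/write lock, and Futures and cancellation are left out.

```java
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

/** Hypothetical wrapper that defers shutdown until no task is pending. */
final class DeferredShutdownExecutor implements Executor {
    private final ExecutorService delegate;
    private final Object lock = new Object();
    private int pending;               // guarded by lock
    private boolean shutdownRequested; // guarded by lock

    DeferredShutdownExecutor(ExecutorService delegate) {
        this.delegate = delegate;
    }

    @Override
    public void execute(Runnable task) {
        synchronized (lock) {
            pending++;
        }
        try {
            delegate.execute(() -> {
                try {
                    task.run();
                } finally {
                    taskFinished();
                }
            });
        } catch (RuntimeException e) {
            taskFinished(); // rejected task will never run
            throw e;
        }
    }

    private void taskFinished() {
        synchronized (lock) {
            if (--pending == 0 && shutdownRequested) {
                delegate.shutdown(); // the deferred, real shutdown
            }
        }
    }

    void shutdown() {
        synchronized (lock) {
            shutdownRequested = true;
            if (pending == 0) {
                delegate.shutdown();
            }
        }
    }

    boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
        return delegate.awaitTermination(timeout, unit);
    }
}

final class DeferredShutdownDemo {

    /** Returns the number of tasks run, or -1 if the pool did not terminate in time. */
    static int runTree(int depth) throws InterruptedException {
        DeferredShutdownExecutor executor =
                new DeferredShutdownExecutor(Executors.newFixedThreadPool(4));
        AtomicInteger executed = new AtomicInteger();
        spawn(executor, executed, depth);
        executor.shutdown(); // safe to call right away; tasks can still spawn children
        boolean done = executor.awaitTermination(10, TimeUnit.SECONDS);
        return done ? executed.get() : -1;
    }

    private static void spawn(Executor executor, AtomicInteger executed, int depth) {
        executor.execute(() -> {
            executed.incrementAndGet();
            if (depth > 0) {
                spawn(executor, executed, depth - 1);
                spawn(executor, executed, depth - 1);
            }
        });
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(runTree(3)); // 1 + 2 + 4 + 8 tasks, prints 15
    }
}
```

    With this arrangement the main thread waits synchronously through the ordinary shutdown() plus awaitTermination() pair, even though tasks keep submitting new tasks after shutdown() has been requested.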
    