Question
I'm working on parallelizing Dijkstra's algorithm. For each node, threads are created to process all the edges of the current node. Creating those threads adds so much overhead that the parallel version takes longer than the sequential version of the algorithm.
I added a ThreadPool to solve this problem, but I'm having trouble waiting until the tasks are done before moving on to the next iteration. We should move on only after all tasks for one node are done, because I need the results of all tasks before I can search for the next closest node.
I tried executor.shutdown(), but with this approach the executor won't accept new tasks. How can I wait in the loop until every task is finished without having to create a new ThreadPoolExecutor every time? Re-creating it would defeat the purpose of using a pool to get less overhead than regular threads.
One thing I thought about was a BlockingQueue that holds the tasks (edges). But with that solution I'm also stuck on waiting for the tasks to finish without shutdown().
public void apply(int numberOfThreads) {
    ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(numberOfThreads);

    class DijkstraTask implements Runnable {
        private String name;

        public DijkstraTask(String name) {
            this.name = name;
        }

        public String getName() {
            return name;
        }

        @Override
        public void run() {
            calculateShortestDistances(numberOfThreads);
        }
    }

    // Visit every node, in order of stored distance
    for (int i = 0; i < this.nodes.length; i++) {
        // Add task for each node
        for (int t = 0; t < numberOfThreads; t++) {
            executor.execute(new DijkstraTask("Task " + t));
        }
        // Wait until finished?
        while (executor.getActiveCount() > 0) {
            System.out.println("Active count: " + executor.getActiveCount());
        }
        // Look through the results of the tasks and get the next node that is closest by
        currentNode = getNodeShortestDistanced();
        // Reset the threadCounter for next iteration
        this.setCount(0);
    }
}
The number of edges is divided by the number of threads, so 8 edges and 2 threads means each thread deals with 4 edges in parallel.
public void calculateShortestDistances(int numberOfThreads) {
    int threadCounter = this.getCount();
    this.setCount(count + 1);
    // Loop round the edges that are joined to the current node
    currentNodeEdges = this.nodes[currentNode].getEdges();
    int edgesPerThread = currentNodeEdges.size() / numberOfThreads;
    int modulo = currentNodeEdges.size() % numberOfThreads;
    this.nodes[0].setDistanceFromSource(0);
    // Process the edges per thread
    for (int joinedEdge = (edgesPerThread * threadCounter); joinedEdge < (edgesPerThread * (threadCounter + 1)); joinedEdge++) {
        System.out.println("Start: " + (edgesPerThread * threadCounter) + ". End: " + (edgesPerThread * (threadCounter + 1) + ".JoinedEdge: " + joinedEdge) + ". Total: " + currentNodeEdges.size());
        // Determine the joined edge neighbour of the current node
        int neighbourIndex = currentNodeEdges.get(joinedEdge).getNeighbourIndex(currentNode);
        // Only interested in an unvisited neighbour
        if (!this.nodes[neighbourIndex].isVisited()) {
            // Calculate the tentative distance for the neighbour
            int tentative = this.nodes[currentNode].getDistanceFromSource() + currentNodeEdges.get(joinedEdge).getLength();
            // Overwrite if the tentative distance is less than what's currently stored
            if (tentative < nodes[neighbourIndex].getDistanceFromSource()) {
                nodes[neighbourIndex].setDistanceFromSource(tentative);
            }
        }
    }
    // If we have a modulo above 0, the last thread will process the remaining edges
    if (modulo > 0 && numberOfThreads == (threadCounter + 1)) {
        for (int joinedEdge = (edgesPerThread * threadCounter); joinedEdge < (edgesPerThread * (threadCounter) + modulo); joinedEdge++) {
            // Determine the joined edge neighbour of the current node
            int neighbourIndex = currentNodeEdges.get(joinedEdge).getNeighbourIndex(currentNode);
            // Only interested in an unvisited neighbour
            if (!this.nodes[neighbourIndex].isVisited()) {
                // Calculate the tentative distance for the neighbour
                int tentative = this.nodes[currentNode].getDistanceFromSource() + currentNodeEdges.get(joinedEdge).getLength();
                // Overwrite if the tentative distance is less than what's currently stored
                if (tentative < nodes[neighbourIndex].getDistanceFromSource()) {
                    nodes[neighbourIndex].setDistanceFromSource(tentative);
                }
            }
        }
    }
    // All neighbours are checked so this node is now visited
    nodes[currentNode].setVisited(true);
}
Thanks for helping me!
Answer 1:
You should look into CyclicBarrier or a CountDownLatch. Both of these allow you to prevent threads from proceeding until other threads have signaled that they're done. The difference between them is that CyclicBarrier is reusable, i.e. it can be used multiple times, while CountDownLatch is one-shot: you cannot reset the count.
Paraphrasing from the Javadocs:
A CountDownLatch is a synchronization aid that allows one or more threads to wait until a set of operations being performed in other threads completes.
A CyclicBarrier is a synchronization aid that allows a set of threads to all wait for each other to reach a common barrier point. CyclicBarriers are useful in programs involving a fixed sized party of threads that must occasionally wait for each other. The barrier is called cyclic because it can be re-used after the waiting threads are released.
https://docs.oracle.com/en/java/javase/11/docs/api/java.base/java/util/concurrent/CyclicBarrier.html
https://docs.oracle.com/en/java/javase/11/docs/api/java.base/java/util/concurrent/CountDownLatch.html
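As a rough illustration of the reusable option, here is a minimal sketch of one CyclicBarrier shared across all iterations of a loop, with the pool created only once. The class name BarrierPerIteration, the method runRounds, and the counter that stands in for the per-edge work are assumptions made for this example and are not from the question. It also assumes the pool has at least as many threads as there are tasks per round, since otherwise the waiting tasks would block the queued ones.

```java
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

public class BarrierPerIteration {

    // Runs `iterations` rounds of `numberOfThreads` tasks on one reused pool
    // and returns how many tasks completed in total.
    public static int runRounds(int numberOfThreads, int iterations)
            throws InterruptedException, BrokenBarrierException {
        ExecutorService executor = Executors.newFixedThreadPool(numberOfThreads);
        // Parties: every worker task of a round plus the coordinating thread.
        CyclicBarrier barrier = new CyclicBarrier(numberOfThreads + 1);
        AtomicInteger completed = new AtomicInteger();
        try {
            for (int i = 0; i < iterations; i++) {
                for (int t = 0; t < numberOfThreads; t++) {
                    executor.execute(() -> {
                        completed.incrementAndGet(); // hypothetical stand-in for the edge work
                        try {
                            barrier.await(); // signal "done" and wait for the other parties
                        } catch (InterruptedException e) {
                            Thread.currentThread().interrupt();
                        } catch (BrokenBarrierException e) {
                            throw new IllegalStateException(e);
                        }
                    });
                }
                // Returns only after all workers of this round have arrived.
                barrier.await();
                // ... select the next closest node here, then start the next round ...
            }
        } finally {
            executor.shutdown(); // once, after the last round
        }
        return completed.get();
    }

    public static void main(String[] args) throws Exception {
        System.out.println(runRounds(4, 3)); // 4 tasks x 3 rounds = 12
    }
}
```

The barrier resets itself after each trip, so the same instance serves every round.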
Answer 2:
Here is a simple demo of using CountDownLatch to wait for all threads in the pool:
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

public class WaitForAllThreadsInPool {

    private static final int MAX_CYCLES = 10;

    public static void main(String[] args) {
        new WaitForAllThreadsInPool().apply(4);
    }

    public void apply(int numberOfThreads) {
        ExecutorService executor = Executors.newFixedThreadPool(numberOfThreads);
        // One count per task; await() returns once every task has counted down
        CountDownLatch cdl = new CountDownLatch(numberOfThreads);

        class DijkstraTask implements Runnable {
            private final String name;
            private final CountDownLatch cdl;
            private final Random rnd = new Random();

            public DijkstraTask(String name, CountDownLatch cdl) {
                this.name = name;
                this.cdl = cdl;
            }

            @Override
            public void run() {
                calculateShortestDistances(1 + rnd.nextInt(MAX_CYCLES), cdl, name);
            }
        }

        for (int t = 0; t < numberOfThreads; t++) {
            executor.execute(new DijkstraTask("Task " + t, cdl));
        }
        // Wait for all threads to finish
        try {
            cdl.await();
            System.out.println("-all done-");
        } catch (InterruptedException ex) {
            ex.printStackTrace();
        }
        // Demo only: shut the pool down so the JVM can exit.
        // Leave it running if further rounds of tasks will follow.
        executor.shutdown();
    }

    public void calculateShortestDistances(int numberOfWorkCycles, CountDownLatch cdl, String name) {
        // Simulate a long process
        for (int cycle = 1; cycle <= numberOfWorkCycles; cycle++) {
            System.out.println(name + " cycle " + cycle + "/" + numberOfWorkCycles);
            try {
                TimeUnit.MILLISECONDS.sleep(1000);
            } catch (InterruptedException ex) {
                ex.printStackTrace();
            }
        }
        cdl.countDown(); // thread finished
    }
}
Output sample:
Task 0 cycle 1/3
Task 1 cycle 1/2
Task 3 cycle 1/9
Task 2 cycle 1/3
Task 0 cycle 2/3
Task 1 cycle 2/2
Task 2 cycle 2/3
Task 3 cycle 2/9
Task 0 cycle 3/3
Task 2 cycle 3/3
Task 3 cycle 3/9
Task 3 cycle 4/9
Task 3 cycle 5/9
Task 3 cycle 6/9
Task 3 cycle 7/9
Task 3 cycle 8/9
Task 3 cycle 9/9
-all done-
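The demo above waits only once. Because a CountDownLatch cannot be reset, waiting inside a loop would need a fresh latch per iteration while the pool itself is reused. The following is a minimal sketch of that idea. The class name LatchPerIteration, the method runRounds, and the counter used as placeholder work are assumptions made for illustration and are not part of the original answer.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

public class LatchPerIteration {

    // Runs `iterations` rounds of `numberOfThreads` tasks on one reused pool
    // and returns how many tasks completed in total.
    public static int runRounds(int numberOfThreads, int iterations) throws InterruptedException {
        ExecutorService executor = Executors.newFixedThreadPool(numberOfThreads);
        AtomicInteger completed = new AtomicInteger();
        try {
            for (int i = 0; i < iterations; i++) {
                // A latch is one-shot, so each round gets its own instance.
                CountDownLatch roundDone = new CountDownLatch(numberOfThreads);
                for (int t = 0; t < numberOfThreads; t++) {
                    executor.execute(() -> {
                        try {
                            completed.incrementAndGet(); // hypothetical stand-in for the edge work
                        } finally {
                            roundDone.countDown(); // count down even if the work throws
                        }
                    });
                }
                // Blocks until every task of this round has counted down.
                roundDone.await();
                // ... select the next closest node here, then start the next round ...
            }
        } finally {
            executor.shutdown(); // once, after the last round
        }
        return completed.get();
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(runRounds(4, 3)); // 4 tasks x 3 rounds = 12
    }
}
```

Creating a latch is cheap compared with creating threads, so allocating one per round should not bring back the overhead the pool was meant to remove.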
Answer 3:
You can use invokeAll:
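// Add task for each node
Collection<Callable<Object>> tasks = new ArrayList<>(numberOfThreads);
for (int t = 0; t < numberOfThreads; t++) {
    tasks.add(Executors.callable(new DijkstraTask("Task " + t)));
}
// Wait until finished: invokeAll blocks until every task has completed
// (it throws InterruptedException, which the caller must handle or declare)
executor.invokeAll(tasks);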
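To show how this might fit into a loop that reuses the pool, here is a self-contained sketch. The class name InvokeAllPerIteration, the method runRounds, and the tasks that simply return their slice index are assumptions made for this example. In the real algorithm each Callable would process its share of the edges.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class InvokeAllPerIteration {

    // Runs `iterations` rounds of `numberOfThreads` tasks on one reused pool
    // and returns the sum of all task results.
    public static int runRounds(int numberOfThreads, int iterations)
            throws InterruptedException, ExecutionException {
        ExecutorService executor = Executors.newFixedThreadPool(numberOfThreads);
        int sum = 0;
        try {
            for (int i = 0; i < iterations; i++) {
                List<Callable<Integer>> tasks = new ArrayList<>(numberOfThreads);
                for (int t = 0; t < numberOfThreads; t++) {
                    final int slice = t;
                    tasks.add(() -> slice); // hypothetical stand-in for processing one slice of edges
                }
                // invokeAll blocks until every task in the list has completed.
                for (Future<Integer> result : executor.invokeAll(tasks)) {
                    sum += result.get(); // already done; rethrows a task's exception, if any
                }
                // ... select the next closest node here, then start the next round ...
            }
        } finally {
            executor.shutdown(); // once, after the last round
        }
        return sum;
    }

    public static void main(String[] args) throws Exception {
        System.out.println(runRounds(4, 3)); // (0 + 1 + 2 + 3) x 3 rounds = 18
    }
}
```

Compared with a latch or barrier, this needs no extra synchronization object, and the returned Future objects also surface any exception thrown inside a task.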
Source: https://stackoverflow.com/questions/56974820/java-wait-in-a-loop-until-tasks-of-threadpoolexecutor-are-done-before-continuin