I have some code of the following form:
static async Task DoSomething(int n)
{
...
}
static void RunThreads(int totalThreads, int throttle)
{
var task
Microsoft's Reactive Extensions (Rx), from the NuGet package "Rx-Main", solve this problem very nicely. (Current versions are published as the "System.Reactive" package.)
Just do this:
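using System.Reactive.Linq;

static void RunThreads(int totalThreads, int throttle)
{
    Observable
        .Range(0, totalThreads)
        // FromAsync defers calling DoSomething until the inner observable is subscribed to
        .Select(n => Observable.FromAsync(() => DoSomething(n)))
        // Merge subscribes to at most `throttle` inner observables at a time
        .Merge(throttle)
        // Block until every operation has finished (throws if the sequence is empty)
        .Wait();
}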
Job done.
First, abstract away from threads. Especially since your operation is asynchronous, you shouldn't be thinking about "threads" at all. In the asynchronous world, you have tasks, and you can have a huge number of tasks compared to threads.
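To make the tasks-versus-threads point concrete, here is a small sketch (the class and method names are hypothetical) that starts 10,000 timer-based waits at once. Because Task.Delay does not block a thread while it waits, all of them can be pending at the same time on a handful of thread-pool threads.

```csharp
using System.Linq;
using System.Threading.Tasks;

public static class TasksVersusThreads
{
    // Starts `count` asynchronous waits at the same time and returns how many finished.
    // Task.Delay is timer-based, so no thread sits blocked while a delay is pending.
    public static async Task<int> RunManyDelays(int count, int delayMs)
    {
        var tasks = Enumerable.Range(0, count)
            .Select(async i =>
            {
                await Task.Delay(delayMs);
                return i;
            })
            .ToList(); // materialize so every task is started exactly once

        int[] results = await Task.WhenAll(tasks);
        return results.Length;
    }
}
```

Awaiting TasksVersusThreads.RunManyDelays(10_000, 1000) typically completes in a little over one second, not 10,000 seconds, and without creating 10,000 threads.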
Throttling asynchronous code can be done using SemaphoreSlim:
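using System.Linq;
using System.Threading;
using System.Threading.Tasks;

// Assumes DoSomething is defined elsewhere:
// static async Task DoSomething(int n) { ... }
static void RunConcurrently(int total, int throttle)
{
    // Lets at most `throttle` callers past WaitAsync at any one time
    var mutex = new SemaphoreSlim(throttle);
    var tasks = Enumerable.Range(0, total).Select(async item =>
    {
        await mutex.WaitAsync();
        try { await DoSomething(item); }
        finally { mutex.Release(); } // free the slot even if DoSomething throws
    });
    // Blocks the calling thread; in async code, prefer "await Task.WhenAll(tasks);"
    Task.WhenAll(tasks).Wait();
}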
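A quick way to convince yourself that the semaphore really caps concurrency is to count how many operations overlap. The sketch below is an assumption-laden demo, not the answer's code: DoSomething is a hypothetical stub that records the peak number of concurrent calls under a lock, and RunConcurrentlyAsync is the same SemaphoreSlim pattern, awaited instead of blocked on.

```csharp
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

public static class SemaphoreThrottleDemo
{
    static readonly object Sync = new object();
    static int _current; // operations currently running
    static int _peak;    // highest value _current has reached

    // Hypothetical stand-in for DoSomething that records how many copies run at once.
    static async Task DoSomething(int n)
    {
        lock (Sync) { _current++; if (_current > _peak) _peak = _current; }
        await Task.Delay(20);
        lock (Sync) { _current--; }
    }

    // The SemaphoreSlim pattern from the answer above, awaited instead of blocked on.
    public static async Task<int> RunConcurrentlyAsync(int total, int throttle)
    {
        _current = 0;
        _peak = 0;
        using var gate = new SemaphoreSlim(throttle);
        var tasks = Enumerable.Range(0, total).Select(async item =>
        {
            await gate.WaitAsync();
            try { await DoSomething(item); }
            finally { gate.Release(); }
        });
        await Task.WhenAll(tasks);
        return _peak; // highest number of DoSomething calls that overlapped
    }
}
```

The reported peak should never exceed the throttle value, however many items are processed.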
Here are some extension-method variations that build on Sriram Sakthivel's answer.
In the usage example, calls to DoSomething are wrapped in an explicitly cast closure so that arguments can be passed.
public static async Task RunMyThrottledTasks()
{
var myArgsSource = new[] { 1, 2, 3, 4, 5, 6, 7, 8, 9 };
await myArgsSource
.Select(a => (Func<Task<object>>)(() => DoSomething(a)))
.Throttle(2);
}
public static async Task<object> DoSomething(int arg)
{
// Await some async calls that need arg,
// then return the result.
return new object();
}
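using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

// Extension methods must live in a non-generic static class.
public static class ThrottleExtensions
{
    // Runs at most throttleTo tasks at once; results come back in completion order.
    public static async Task<IEnumerable<T>> Throttle<T>(this IEnumerable<Func<Task<T>>> toRun, int throttleTo)
    {
        var running = new List<Task<T>>(throttleTo);
        var completed = new List<Task<T>>();
        foreach (var taskToRun in toRun)
        {
            running.Add(taskToRun());
            if (running.Count == throttleTo)
            {
                // Wait for a free slot before starting the next task
                var comTask = await Task.WhenAny(running);
                running.Remove(comTask);
                completed.Add(comTask);
            }
        }
        // Wait for the tasks still running so their results are included too
        await Task.WhenAll(running);
        completed.AddRange(running);
        return completed.Select(t => t.Result).ToList();
    }

    public static async Task Throttle(this IEnumerable<Func<Task>> toRun, int throttleTo)
    {
        var running = new List<Task>(throttleTo);
        foreach (var taskToRun in toRun)
        {
            running.Add(taskToRun());
            if (running.Count == throttleTo)
            {
                var comTask = await Task.WhenAny(running);
                running.Remove(comTask);
            }
        }
        // Wait for the remaining tasks before reporting completion
        await Task.WhenAll(running);
    }
}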
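If the caller needs results in the same order as the inputs, the sliding window can keep a second list of every started task and finish with Task.WhenAll, which returns results in the order of the tasks it is given. This is a sketch: ThrottleInOrder and OrderedThrottleExtensions are hypothetical names, and it assumes throttleTo is at least 1.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

public static class OrderedThrottleExtensions
{
    // Hypothetical variant: same sliding window, but results are returned
    // in input order rather than completion order.
    public static async Task<T[]> ThrottleInOrder<T>(
        this IEnumerable<Func<Task<T>>> toRun, int throttleTo)
    {
        if (throttleTo < 1) throw new ArgumentOutOfRangeException(nameof(throttleTo));

        var all = new List<Task<T>>();               // every started task, in input order
        var running = new List<Task<T>>(throttleTo); // tasks currently in flight
        foreach (var taskToRun in toRun)
        {
            var task = taskToRun();
            all.Add(task);
            running.Add(task);
            if (running.Count == throttleTo)
            {
                // Free one slot before starting the next function
                running.Remove(await Task.WhenAny(running));
            }
        }
        // WhenAll returns results in the order the tasks appear in `all`
        return await Task.WhenAll(all);
    }
}
```

With delays of 30, 10, 20 and 0 ms and a throttle of 2, the results still come back as 30, 10, 20, 0.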
The simplest option IMO is to use TPL Dataflow. You just create an ActionBlock, limit it to the desired degree of parallelism and start posting items into it. It makes sure only a certain number of tasks run at the same time, and when a task completes, it starts executing the next item:
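using System.Threading.Tasks;
using System.Threading.Tasks.Dataflow; // TPL Dataflow (System.Threading.Tasks.Dataflow package)

async Task RunAsync(int totalThreads, int throttle)
{
    var block = new ActionBlock<int>(
        DoSomething,
        new ExecutionDataflowBlockOptions { MaxDegreeOfParallelism = throttle });
    for (var n = 0; n < totalThreads; n++)
    {
        block.Post(n); // the buffer is unbounded by default, so Post accepts every item
    }
    block.Complete();       // signal that no more items will be posted
    await block.Completion; // wait until every posted item has been processed
}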
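With the default options the block accepts every posted item immediately, so for a very large totalThreads all items sit in its buffer up front. A variant that bounds the buffer with BoundedCapacity and feeds it with SendAsync (both part of TPL Dataflow) is sketched below; the DoSomething stub and the capacity of throttle * 2 are assumptions for illustration.

```csharp
using System.Threading.Tasks;
using System.Threading.Tasks.Dataflow; // System.Threading.Tasks.Dataflow package

public static class BoundedDataflowThrottle
{
    // Hypothetical stand-in for the question's DoSomething.
    static Task DoSomething(int n) => Task.Delay(10);

    public static async Task RunAsync(int totalThreads, int throttle)
    {
        var block = new ActionBlock<int>(
            DoSomething,
            new ExecutionDataflowBlockOptions
            {
                MaxDegreeOfParallelism = throttle,
                // Assumed limit on how many items the block holds at once
                BoundedCapacity = throttle * 2
            });

        for (var n = 0; n < totalThreads; n++)
        {
            // SendAsync waits asynchronously while the block is full;
            // Post would return false instead and the item would be dropped.
            await block.SendAsync(n);
        }

        block.Complete();
        await block.Completion;
    }
}
```

The producer loop now slows down to the pace of the consumers instead of queuing everything in memory.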
If I understand correctly, you want to start a limited number of tasks (given by the throttle parameter) and wait for them to finish before starting the next ones.
To wait for all started tasks to complete before starting new tasks, use the following implementation.
static async Task RunThreads(int totalThreads, int throttle)
{
var tasks = new List<Task>();
for (var n = 0; n < totalThreads; n++)
{
var task = DoSomething(n);
tasks.Add(task);
if (tasks.Count == throttle)
{
await Task.WhenAll(tasks);
tasks.Clear();
}
}
await Task.WhenAll(tasks); // wait for remaining
}
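On .NET 6 or later, the same batch-at-a-time behavior can be written more compactly with Enumerable.Chunk. This sketch uses a hypothetical DoSomething stub that just counts completed calls. Keep in mind the trade-off of batching: each batch waits for its slowest task, so free slots sit idle until it finishes.

```csharp
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

public static class BatchedRunner
{
    static int _processed;

    // Hypothetical stand-in for DoSomething.
    static async Task DoSomething(int n)
    {
        await Task.Delay(5);
        Interlocked.Increment(ref _processed);
    }

    // Each batch of up to `throttle` tasks must finish completely before the
    // next batch starts. Enumerable.Chunk requires .NET 6 or later.
    public static async Task<int> RunInBatchesAsync(int total, int throttle)
    {
        _processed = 0;
        foreach (int[] batch in Enumerable.Range(0, total).Chunk(throttle))
        {
            await Task.WhenAll(batch.Select(DoSomething));
        }
        return _processed;
    }
}
```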
To start a new task as soon as any running task completes, you can use the following code:
static async Task RunThreads(int totalThreads, int throttle)
{
var tasks = new List<Task>();
for (var n = 0; n < totalThreads; n++)
{
var task = DoSomething(n);
tasks.Add(task);
if (tasks.Count == throttle)
{
var completed = await Task.WhenAny(tasks);
tasks.Remove(completed);
}
}
await Task.WhenAll(tasks); // wait for the tasks that are still running
}
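On .NET 6 or later, the framework's Parallel.ForEachAsync gives the same "start the next item as soon as one finishes" behavior out of the box. This is a different, built-in technique rather than the code above; the sketch again uses a hypothetical counting DoSomething stub.

```csharp
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

public static class ForEachAsyncRunner
{
    static int _processed;

    // Hypothetical stand-in for DoSomething.
    static async Task DoSomething(int n)
    {
        await Task.Delay(5);
        Interlocked.Increment(ref _processed);
    }

    // Parallel.ForEachAsync (.NET 6+) keeps up to MaxDegreeOfParallelism
    // operations in flight, pulling the next item whenever one finishes.
    public static async Task<int> RunAsync(int totalThreads, int throttle)
    {
        _processed = 0;
        var options = new ParallelOptions { MaxDegreeOfParallelism = throttle };
        await Parallel.ForEachAsync(
            Enumerable.Range(0, totalThreads),
            options,
            async (n, cancellationToken) => await DoSomething(n));
        return _processed;
    }
}
```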
Stephen Toub gives the following example of throttling in his document "The Task-based Asynchronous Pattern":
const int CONCURRENCY_LEVEL = 15;
Uri [] urls = …;
int nextIndex = 0;
var imageTasks = new List<Task<Bitmap>>();
while(nextIndex < CONCURRENCY_LEVEL && nextIndex < urls.Length)
{
imageTasks.Add(GetBitmapAsync(urls[nextIndex]));
nextIndex++;
}
while(imageTasks.Count > 0)
{
try
{
Task<Bitmap> imageTask = await Task.WhenAny(imageTasks);
imageTasks.Remove(imageTask);
Bitmap image = await imageTask;
panel.AddImage(image);
}
catch(Exception exc) { Log(exc); }
if (nextIndex < urls.Length)
{
imageTasks.Add(GetBitmapAsync(urls[nextIndex]));
nextIndex++;
}
}
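The same loop can be made reusable and runnable by passing in the operation, the result handler and the logger. This is a sketch under assumptions: SlidingWindow, RunAsync, fetchAsync, onResult and log are hypothetical names standing in for GetBitmapAsync, panel.AddImage and Log from the example.

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

public static class SlidingWindow
{
    // Generic form of the loop above. Failures are logged and the loop moves on.
    public static async Task RunAsync<TIn, TOut>(
        IReadOnlyList<TIn> inputs, int concurrencyLevel,
        Func<TIn, Task<TOut>> fetchAsync, Action<TOut> onResult, Action<Exception> log)
    {
        int nextIndex = 0;
        var tasks = new List<Task<TOut>>();
        // Prime the window with up to concurrencyLevel operations
        while (nextIndex < concurrencyLevel && nextIndex < inputs.Count)
        {
            tasks.Add(fetchAsync(inputs[nextIndex]));
            nextIndex++;
        }
        while (tasks.Count > 0)
        {
            try
            {
                Task<TOut> finished = await Task.WhenAny(tasks);
                tasks.Remove(finished);
                onResult(await finished); // rethrows if the operation failed
            }
            catch (Exception exc) { log(exc); }
            // Refill the slot that just freed up
            if (nextIndex < inputs.Count)
            {
                tasks.Add(fetchAsync(inputs[nextIndex]));
                nextIndex++;
            }
        }
    }
}
```

With five inputs, a window of two and one failing operation, onResult is called four times and log once.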