I have a custom "CachedEnumerable" class (inspired by Caching IEnumerable) that I need to make thread safe for my ASP.NET Core web app.
Is the following implementation thread safe?
The access to the cache is thread safe: only one thread at a time can read from the _cache object.
But this way you can't ensure that all threads get elements in the same order in which they called GetEnumerator.
Check these two examples; if this behavior is what you expect, you can use the lock that way (a sketch of the pattern follows the examples).
Example 1:
THREAD1 Calls GetEnumerator
THREAD1 Initialize T current;
THREAD2 Calls GetEnumerator
THREAD2 Initialize T current;
THREAD2 LOCK
THREAD1 WAIT
THREAD2 read from cache safely _cache[0]
THREAD2 index++
THREAD2 UNLOCK
THREAD1 LOCK
THREAD1 read from cache safely _cache[1]
THREAD1 index++
THREAD1 UNLOCK
THREAD2 yield return current;
THREAD1 yield return current;
Example 2:
THREAD2 Initialize T current;
THREAD2 LOCK
THREAD2 read from cache safely
THREAD2 UNLOCK
THREAD1 Initialize T current;
THREAD1 LOCK
THREAD1 read from cache safely
THREAD1 UNLOCK
THREAD1 yield return current;
THREAD2 yield return current;
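For reference, a minimal sketch of the locking pattern these examples trace through (the class from the question is not shown in full here, so the names _locker, _cache and _index are assumptions, and refilling the cache from the source enumerator is omitted):
public IEnumerator<T> GetEnumerator()
{
    while (true)
    {
        T current;
        lock (_locker)                // THREADx LOCK
        {
            if (_index >= _cache.Count) yield break;
            current = _cache[_index]; // THREADx read from cache safely
            _index++;                 // THREADx index++
        }                             // THREADx UNLOCK
        yield return current;         // THREADx yield return current
    }
}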
Your class is not thread safe, because shared state is mutated in unprotected regions inside your class. The unprotected regions are:
1. The constructor
2. The Dispose method
The shared state is:
1. The _enumerator private field
2. The _cache private field
3. The CachingComplete public property
Some other issues regarding your class:
1. Implementing IDisposable creates the responsibility for the caller to dispose your class. There is no need for IEnumerables to be disposable. On the contrary, IEnumerators are disposable, but there is language support for their automatic disposal (a feature of the foreach statement; see the sketch after this list).
2. Your class offers extended functionality not expected from an IEnumerable (ElementAt, Count etc). Maybe you intended to implement a CachedList instead? Without implementing the IList<T> interface, LINQ methods like Count() and ToArray() cannot take advantage of your extended functionality, and will use the slow path like they do with plain vanilla IEnumerables.
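For reference, the foreach statement disposes the enumerator automatically. A foreach loop over an IEnumerable<T> is lowered by the compiler to approximately the following (Consume is a hypothetical placeholder for the loop body):
IEnumerator<T> enumerator = source.GetEnumerator();
try
{
    while (enumerator.MoveNext())
    {
        Consume(enumerator.Current); // the loop body
    }
}
finally
{
    enumerator.Dispose(); // automatic disposal, even if the loop throws
}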
Update: I just noticed another thread-safety issue. This one is related to the public IEnumerator<T> GetEnumerator() method. The enumerator is compiler-generated, since the method is an iterator (it utilizes yield return). Compiler-generated enumerators are not thread safe. Consider this code for example:
var enumerable = Enumerable.Range(0, 1_000_000);
var cachedEnumerable = new CachedEnumerable<int>(enumerable);
var enumerator = cachedEnumerable.GetEnumerator();
var tasks = Enumerable.Range(1, 4).Select(id => Task.Run(() =>
{
    int count = 0;
    while (enumerator.MoveNext())
    {
        count++;
    }
    Console.WriteLine($"Task #{id} count: {count}");
})).ToArray();
Task.WaitAll(tasks);
Four threads are concurrently using the same IEnumerator. The enumerable has 1,000,000 items. You might expect that each thread would enumerate ~250,000 items, but that's not what happens.
Output:
Task #1 count: 0
Task #4 count: 0
Task #3 count: 0
Task #2 count: 1000000
The MoveNext in the line while (enumerator.MoveNext()) is not your safe MoveNext. It is the compiler-generated, unsafe MoveNext. Although unsafe, it includes a mechanism probably intended for dealing with exceptions, which temporarily marks the enumerator as finished before calling the externally provided code. So when multiple threads call MoveNext concurrently, all but the first get a return value of false and terminate the enumeration instantly, having completed zero loops. To solve this you must probably code your own IEnumerator class.
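A minimal sketch of such a hand-coded enumerator (not part of the question's class; it wraps any IEnumerator<T> and guards MoveNext with a lock, so concurrent callers share the sequence instead of all but one terminating immediately):
class SharedEnumerator<T> : IEnumerator<T>
{
    private readonly object _locker = new object();
    private readonly IEnumerator<T> _inner;
    private T _current;

    public SharedEnumerator(IEnumerator<T> inner) => _inner = inner;

    public T Current => _current;
    object System.Collections.IEnumerator.Current => Current;

    public bool MoveNext()
    {
        lock (_locker)
        {
            if (!_inner.MoveNext()) return false;
            _current = _inner.Current; // captured under the lock
            return true;
        }
    }

    public void Reset() => throw new NotSupportedException();
    public void Dispose() { lock (_locker) _inner.Dispose(); }
}
Note that the MoveNext + Current pair is still not atomic for the callers, which is exactly the limitation discussed in the next update.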
Update: Actually my last point about thread-safe enumeration is a bit unfair, because enumerating with the IEnumerator interface is an inherently unsafe operation, which is impossible to fix without the cooperation of the calling code. This is because obtaining the next element is not an atomic operation: it involves two steps (calling MoveNext() and then reading Current). So your thread-safety concerns are limited to protecting the internal state of your class (the fields _enumerator and _cache, and the CachingComplete property). These are left unprotected only in the constructor and in the Dispose method, but I suppose that normal use of your class may not follow the code paths that create the race conditions that would corrupt the internal state.
Personally I would prefer to take care of these code paths too, and not leave it to the whims of chance.
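For example, the Dispose method could take the same lock that guards the rest of the class (a sketch; the field names _locker and _enumerator are assumptions, since only parts of the class are quoted here):
public void Dispose()
{
    lock (_locker)
    {
        _enumerator?.Dispose(); // no thread can be caching from the source while we dispose it
        _enumerator = null;
    }
}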
Update: I wrote a cache for IAsyncEnumerables, to demonstrate an alternative technique. The enumeration of the source IAsyncEnumerable is not driven by the callers, using locks or semaphores to obtain exclusive access, but by a separate worker task. The first caller starts the worker task. Each caller at first yields all items that are already cached, and then awaits more items, or a notification that there are no more items. As the notification mechanism I used a TaskCompletionSource<bool>. A lock is still used to ensure that all access to shared resources is synchronized.
public class CachedAsyncEnumerable<T> : IAsyncEnumerable<T>
{
    private readonly object _locker = new object();
    private IAsyncEnumerable<T> _source;
    private Task _sourceEnumerationTask;
    private List<T> _buffer;
    private TaskCompletionSource<bool> _moveNextTCS;
    private Exception _sourceEnumerationException;
    private int _sourceEnumerationVersion; // Incremented on exception

    public CachedAsyncEnumerable(IAsyncEnumerable<T> source)
    {
        _source = source ?? throw new ArgumentNullException(nameof(source));
    }

    public async IAsyncEnumerator<T> GetAsyncEnumerator(
        CancellationToken cancellationToken = default)
    {
        lock (_locker)
        {
            if (_sourceEnumerationTask == null)
            {
                _buffer = new List<T>();
                _moveNextTCS = new TaskCompletionSource<bool>();
                _sourceEnumerationTask = Task.Run(
                    () => EnumerateSourceAsync(cancellationToken));
            }
        }
        int index = 0;
        int localVersion = -1;
        while (true)
        {
            T current = default;
            Task<bool> moveNextTask = null;
            lock (_locker)
            {
                if (localVersion == -1)
                {
                    localVersion = _sourceEnumerationVersion;
                }
                else if (_sourceEnumerationVersion != localVersion)
                {
                    ExceptionDispatchInfo
                        .Capture(_sourceEnumerationException).Throw();
                }
                if (index < _buffer.Count)
                {
                    current = _buffer[index];
                    index++;
                }
                else
                {
                    moveNextTask = _moveNextTCS.Task;
                }
            }
            if (moveNextTask == null)
            {
                yield return current;
                continue;
            }
            var moved = await moveNextTask;
            if (!moved) yield break;
            lock (_locker)
            {
                current = _buffer[index];
                index++;
            }
            yield return current;
        }
    }

    private async Task EnumerateSourceAsync(CancellationToken cancellationToken)
    {
        TaskCompletionSource<bool> localMoveNextTCS;
        try
        {
            await foreach (var item in _source.WithCancellation(cancellationToken))
            {
                lock (_locker)
                {
                    _buffer.Add(item);
                    localMoveNextTCS = _moveNextTCS;
                    _moveNextTCS = new TaskCompletionSource<bool>();
                }
                localMoveNextTCS.SetResult(true);
            }
            lock (_locker)
            {
                localMoveNextTCS = _moveNextTCS;
                _buffer.TrimExcess();
                _source = null;
            }
            localMoveNextTCS.SetResult(false);
        }
        catch (Exception ex)
        {
            lock (_locker)
            {
                localMoveNextTCS = _moveNextTCS;
                _sourceEnumerationException = ex;
                _sourceEnumerationVersion++;
                _sourceEnumerationTask = null;
            }
            localMoveNextTCS.SetException(ex);
        }
    }
}
This implementation follows a specific strategy for dealing with exceptions. If an exception occurs while enumerating the source IAsyncEnumerable, the exception will be propagated to all current callers, the currently used IAsyncEnumerator will be discarded, and the incomplete cached data will be discarded too. A new worker task may start again later, when the next enumeration request is received.
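For completeness, here is one possible way to exercise the class (a sketch; ProduceAsync is a stand-in for any slow source). Two concurrent consumers share a single cache, and the source is enumerated only once by the worker task:
static async IAsyncEnumerable<int> ProduceAsync()
{
    for (int i = 0; i < 5; i++)
    {
        await Task.Delay(100); // simulate slow production
        yield return i;
    }
}

static async Task Main()
{
    var cached = new CachedAsyncEnumerable<int>(ProduceAsync());
    var consumerA = Task.Run(async () =>
    {
        await foreach (var x in cached) Console.WriteLine($"A: {x}");
    });
    var consumerB = Task.Run(async () =>
    {
        await foreach (var x in cached) Console.WriteLine($"B: {x}");
    });
    await Task.WhenAll(consumerA, consumerB); // both consumers observe all five items
}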