LINQ performance Count vs Where and Count

礼貌的吻别 2020-12-08 10:14
public class Group
{
   public string Name { get; set; }
}  

Test:

List<Group> _groups = new List<Group>();

for (i         


        
6 Answers
  • 2020-12-08 10:27

    Following on from Matthew Watson's answer:

    The reason iterating over a List<T> generates call instructions rather than callvirt, as used for IEnumerable<T>, is that the C# foreach statement is duck-typed.

    The C# Language Specification, section 8.8.4, says that the compiler 'determines whether the type X has an appropriate GetEnumerator method'. This pattern is used in preference to an enumerable interface. Therefore the foreach statement here uses the public List<T>.GetEnumerator method, which returns the struct List<T>.Enumerator, rather than the explicit interface implementations that return IEnumerator<T> or plain IEnumerator.

    The compiler also checks that the type returned by GetEnumerator has a Current property and a MoveNext method that takes no arguments. List<T>.Enumerator is a struct, so these members are not virtual and the compiler can emit a direct call. Through IEnumerator<T>, in contrast, they are interface methods, so the compiler must emit a callvirt instruction and the runtime has to dispatch every call through the interface. The overhead of that indirect call on each iteration explains the difference in performance.
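    To illustrate the duck typing described above, here is a minimal sketch. The Countdown type is hypothetical, made up for this example: it implements no interfaces at all, yet foreach accepts it because it exposes a public GetEnumerator() whose return type has a Current property and a parameterless MoveNext(). Because the enumerator is a struct, those members are non-virtual, so the compiler can call them directly, just as it does for List<T>.Enumerator.

```csharp
using System;

// Hypothetical type: it implements neither IEnumerable nor IEnumerable<T>,
// but foreach still works because the compiler looks for the GetEnumerator pattern.
public sealed class Countdown
{
    private readonly int start;

    public Countdown(int start) { this.start = start; }

    // Returns a struct, in the same way List<T>.GetEnumerator() does.
    public Enumerator GetEnumerator() { return new Enumerator(start); }

    // A struct cannot declare virtual members, so MoveNext/Current are direct calls.
    public struct Enumerator
    {
        private int remaining;

        public Enumerator(int start) { remaining = start + 1; }

        public int Current { get { return remaining; } }

        public bool MoveNext() { return --remaining > 0; }
    }
}

public static class CountdownDemo
{
    public static void Main()
    {
        foreach (int n in new Countdown(3))
            Console.Write(n + " ");
        Console.WriteLine(); // prints 3 2 1
    }
}
```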

  • 2020-12-08 10:32

    It looks to me like the difference lies in how the LINQ extension methods are coded. I suspect Where uses optimizations specific to the List<> class to speed things up, while Count just iterates through an IEnumerable<>.

    If you run the same test against an IEnumerable<Group> instead of the List<Group>, both methods take about the same time, with Where being slightly slower.

    using System;
    using System.Collections.Generic;
    using System.Diagnostics;
    using System.Linq;
    
    List<Group> _groups = new List<Group>();
    
    for (int i = 0; i < 10000; i++)
    {
        var group = new Group();
    
        group.Name = i + "asdasdasd";
        _groups.Add(group);
    }
    
    // The query wraps _groups, so the result is no longer a List<Group>
    IEnumerable<Group> _groupsEnumerable = from g in _groups select g;
    
    Stopwatch _stopwatch2 = new Stopwatch();
    
    _stopwatch2.Start();
    foreach (var group in _groups)
    {
        var count = _groupsEnumerable.Count(x => x.Name == group.Name);
    }
    _stopwatch2.Stop();
    
    Console.WriteLine(_stopwatch2.ElapsedMilliseconds);
    Stopwatch _stopwatch = new Stopwatch();
    
    _stopwatch.Start();
    foreach (var group in _groups)
    {
        var count = _groupsEnumerable.Where(x => x.Name == group.Name).Count();
    }
    _stopwatch.Stop();
    
    Console.WriteLine(_stopwatch.ElapsedMilliseconds);
    

    Where extension method. Notice the if (source is List<TSource>) case:

    public static IEnumerable<TSource> Where<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
    {
        if (source == null)
        {
            throw Error.ArgumentNull("source");
        }
        if (predicate == null)
        {
            throw Error.ArgumentNull("predicate");
        }
        if (source is Enumerable.Iterator<TSource>)
        {
            return ((Enumerable.Iterator<TSource>)source).Where(predicate);
        }
        if (source is TSource[])
        {
            return new Enumerable.WhereArrayIterator<TSource>((TSource[])source, predicate);
        }
        if (source is List<TSource>)
        {
            return new Enumerable.WhereListIterator<TSource>((List<TSource>)source, predicate);
        }
        return new Enumerable.WhereEnumerableIterator<TSource>(source, predicate);
    }
    

    The Count method, which just iterates through the IEnumerable:

    public static int Count<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
    {
        if (source == null)
        {
            throw Error.ArgumentNull("source");
        }
        if (predicate == null)
        {
            throw Error.ArgumentNull("predicate");
        }
        int num = 0;
        checked
        {
            foreach (TSource current in source)
            {
                if (predicate(current))
                {
                    num++;
                }
            }
            return num;
        }
    }
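    One way to observe this special-casing from outside the library is to compare the runtime type Where returns for a List<T> source with the type it returns for an arbitrary sequence. This is a minimal sketch; the helper names are hypothetical, and the concrete iterator type names (for example WhereListIterator`1 on .NET Framework) are internal implementation details that differ between runtime versions, so the code only prints them rather than relying on specific names.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class WhereTypeDemo
{
    // A plain iterator method: its result is neither a List<T> nor an array,
    // so Where cannot pick a list-specific implementation for it.
    public static IEnumerable<T> AsPlainSequence<T>(IEnumerable<T> source)
    {
        foreach (T item in source)
            yield return item;
    }

    // Returns the runtime type of the object that Where hands back for this source.
    public static Type WhereTypeFor(IEnumerable<int> source)
    {
        return source.Where(x => x % 2 == 0).GetType();
    }

    public static void Main()
    {
        var list = new List<int> { 1, 2, 3, 4 };

        Console.WriteLine("List<T> source:   " + WhereTypeFor(list).Name);
        Console.WriteLine("Generic sequence: " + WhereTypeFor(AsPlainSequence(list)).Name);
    }
}
```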
    
  • 2020-12-08 10:37

    My guess:

    .Where() uses a special WhereListIterator to iterate over the elements, while Count() does not, as Wyatt Earp pointed out. Interestingly, the iterator's constructor is marked as "NGen-able":

     [TargetedPatchingOptOut("Performance critical to inline this type of method across NGen image boundaries")]
     public WhereListIterator(List<TSource> source, Func<TSource, bool> predicate)
     {
       this.source = source;
       this.predicate = predicate;
     }
    

    This would probably mean that the iterator code is allowed to be inlined across NGen image boundaries, while Count() is not. I don't know whether that makes sense or how to prove it, but that's my two cents.

    Also, if you rewrite Count() to handle List<T> specifically, you can make it just as fast or even faster:

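    using System;
    using System.Collections.Generic;
    using System.Linq;
    
    public static class TestExt
    {
        public static int CountFaster<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
        {
            if (source == null) throw new ArgumentNullException(nameof(source));
            if (predicate == null) throw new ArgumentNullException(nameof(predicate));
    
            // Fast path for List<T>: index the list directly instead of enumerating it
            // through IEnumerable<T>, which avoids interface calls to MoveNext()/Current.
            if (source is List<TSource> list)
            {
                int finalCount = 0;
                int count = list.Count;
                for (int j = 0; j < count; j++)
                {
                    if (predicate(list[j]))
                        finalCount++;
                }
                return finalCount;
            }
    
            // Any other sequence: fall back to the standard LINQ implementation.
            return source.Count(predicate);
        }
    }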

    In my tests, once I started using CountFaster(), whichever method was called LATER won (because of cold-start effects).

  • 2020-12-08 10:38

    Following up on @Matthew Watson's post, I checked some more behaviour. In my example, Where always returned an empty collection, so Count() was never really exercised through the IEnumerable interface (which is significantly slower than enumerating the List's elements directly). So instead of giving every group a different name, I gave all items the same name. With that data, Count(predicate) is faster than Where(predicate).Count(). This is because with Count(predicate) we enumerate all items once through the IEnumerable interface. With Where(predicate).Count(), if all items match, Where yields the whole collection (as an IEnumerable) and Count() still has to enumerate every item through that interface, so the Where call is redundant; it only slows things down.

    All in all, the specific situation in my example led me to conclude that Where(predicate).Count() is always faster, but that is not true. If Where returns a collection that is not much smaller than the original one, the Where(predicate).Count() approach will be slower.
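    A minimal sketch of the experiment described above, assuming a list in which every element matches the predicate. The list size of 10,000, the shared name "same", the repetition count and the helper names are arbitrary choices for illustration. Both expressions must return the same count; only their timings are of interest, and those vary by machine and runtime, so no numbers are claimed here.

```csharp
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;

// Same shape as the Group class in the question.
public class Group
{
    public string Name { get; set; }
}

public static class AllMatchDemo
{
    // Builds n groups that all share the same name, so the predicate is always true.
    public static List<Group> MakeGroups(int n, string name)
    {
        var groups = new List<Group>(n);
        for (int i = 0; i < n; i++)
            groups.Add(new Group { Name = name });
        return groups;
    }

    public static int CountOnly(List<Group> groups, string name)
    {
        return groups.Count(g => g.Name == name);
    }

    public static int WhereThenCount(List<Group> groups, string name)
    {
        return groups.Where(g => g.Name == name).Count();
    }

    private static long Time(Func<int> action, int repetitions)
    {
        action(); // warm-up call so JIT compilation is not part of the measurement
        var sw = Stopwatch.StartNew();
        for (int i = 0; i < repetitions; i++)
            action();
        return sw.ElapsedMilliseconds;
    }

    public static void Main()
    {
        List<Group> groups = MakeGroups(10000, "same");

        Console.WriteLine("Count(pred):         " + Time(() => CountOnly(groups, "same"), 1000) + " ms");
        Console.WriteLine("Where(pred).Count(): " + Time(() => WhereThenCount(groups, "same"), 1000) + " ms");
    }
}
```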

  • 2020-12-08 10:39

    The crucial thing is in the implementation of Where(), which casts the IEnumerable<T> to a List<T> if it can. Note the cast where the WhereListIterator is constructed (this is an abridged excerpt of the .NET source code, obtained via reflection):

    public static IEnumerable<TSource> Where<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate) {
        if (source is List<TSource>) return new WhereListIterator<TSource>((List<TSource>)source, predicate);
        return new WhereEnumerableIterator<TSource>(source, predicate);
    }
    

    I have verified this by copying (and simplifying where possible) the .NET implementations.

    Crucially, I implemented two versions of Count(): one called TestCount(), which uses IEnumerable<T>, and one called TestListCount(), which casts the enumerable to List<T> before counting the items.

    This gives the same speedup as we see for the Where() operator, which (as shown above) also casts to List<T> where it can.

    (This should be tried with a release build without a debugger attached.)

    This demonstrates that it is faster to use foreach to iterate over a List<T> than over the same sequence exposed as an IEnumerable<T>.

    Firstly, here's the complete test code:

    using System;
    using System.Collections;
    using System.Collections.Generic;
    using System.Diagnostics;
    using System.Linq;
    
    namespace Demo
    {
        public class Group
        {
            public string Name
            {
                get;
                set;
            }
        }
    
        internal static class Program
        {
            static void Main()
            {
                int dummy = 0;
                List<Group> groups = new List<Group>();
    
                for (int i = 0; i < 10000; i++)
                {
                    var group = new Group();
    
                    group.Name = i + "asdasdasd";
                    groups.Add(group);
                }
    
                Stopwatch stopwatch = new Stopwatch();
    
                for (int outer = 0; outer < 4; ++outer)
                {
                    stopwatch.Restart();
    
                    foreach (var group in groups)
                        dummy += TestWhere(groups, x => x.Name == group.Name).Count();
    
                    Console.WriteLine("Using TestWhere(): " + stopwatch.ElapsedMilliseconds);
    
                    stopwatch.Restart();
    
                    foreach (var group in groups)
                        dummy += TestCount(groups, x => x.Name == group.Name);
    
                    Console.WriteLine("Using TestCount(): " + stopwatch.ElapsedMilliseconds);
    
                    stopwatch.Restart();
    
                    foreach (var group in groups)
                        dummy += TestListCount(groups, x => x.Name == group.Name);
    
                    Console.WriteLine("Using TestListCount(): " + stopwatch.ElapsedMilliseconds);
                }
    
                Console.WriteLine("Total = " + dummy);
            }
    
            public static int TestCount<TSource>(IEnumerable<TSource> source, Func<TSource, bool> predicate)
            {
                int count = 0;
    
                foreach (TSource element in source)
                {
                    if (predicate(element)) 
                        count++;
                }
    
                return count;
            }
    
            public static int TestListCount<TSource>(IEnumerable<TSource> source, Func<TSource, bool> predicate)
            {
                return testListCount((List<TSource>) source, predicate);
            }
    
            private static int testListCount<TSource>(List<TSource> source, Func<TSource, bool> predicate)
            {
                int count = 0;
    
                foreach (TSource element in source)
                {
                    if (predicate(element))
                        count++;
                }
    
                return count;
            }
    
            public static IEnumerable<TSource> TestWhere<TSource>(IEnumerable<TSource> source, Func<TSource, bool> predicate)
            {
                return new WhereListIterator<TSource>((List<TSource>)source, predicate);
            }
        }
    
        class WhereListIterator<TSource>: Iterator<TSource>
        {
            readonly Func<TSource, bool> predicate;
            List<TSource>.Enumerator enumerator;
    
            public WhereListIterator(List<TSource> source, Func<TSource, bool> predicate)
            {
                this.predicate = predicate;
                this.enumerator = source.GetEnumerator();
            }
    
            public override bool MoveNext()
            {
                while (enumerator.MoveNext())
                {
                    TSource item = enumerator.Current;
                    if (predicate(item))
                    {
                        current = item;
                        return true;
                    }
                }
                Dispose();
    
                return false;
            }
        }
    
        abstract class Iterator<TSource>: IEnumerable<TSource>, IEnumerator<TSource>
        {
            internal TSource current;
    
            public TSource Current
            {
                get
                {
                    return current;
                }
            }
    
            public virtual void Dispose()
            {
                current = default(TSource);
            }
    
            public IEnumerator<TSource> GetEnumerator()
            {
                return this;
            }
    
            public abstract bool MoveNext();
    
            object IEnumerator.Current
            {
                get
                {
                    return Current;
                }
            }
    
            IEnumerator IEnumerable.GetEnumerator()
            {
                return GetEnumerator();
            }
    
            void IEnumerator.Reset()
            {
                throw new NotImplementedException();
            }
        }
    }
    

    Now here's the IL generated for the two crucial methods, TestCount() and testListCount(). Remember that the only difference between them is that TestCount() uses the IEnumerable<T>, while testListCount() uses the same enumerable cast to its underlying List<T> type:

    TestCount():
    
    .method public hidebysig static int32 TestCount<TSource>(class [mscorlib]System.Collections.Generic.IEnumerable`1<!!TSource> source, class [mscorlib]System.Func`2<!!TSource, bool> predicate) cil managed
    {
        .maxstack 8
        .locals init (
            [0] int32 count,
            [1] !!TSource element,
            [2] class [mscorlib]System.Collections.Generic.IEnumerator`1<!!TSource> CS$5$0000)
        L_0000: ldc.i4.0 
        L_0001: stloc.0 
        L_0002: ldarg.0 
        L_0003: callvirt instance class [mscorlib]System.Collections.Generic.IEnumerator`1<!0> [mscorlib]System.Collections.Generic.IEnumerable`1<!!TSource>::GetEnumerator()
        L_0008: stloc.2 
        L_0009: br L_0025
        L_000e: ldloc.2 
        L_000f: callvirt instance !0 [mscorlib]System.Collections.Generic.IEnumerator`1<!!TSource>::get_Current()
        L_0014: stloc.1 
        L_0015: ldarg.1 
        L_0016: ldloc.1 
        L_0017: callvirt instance !1 [mscorlib]System.Func`2<!!TSource, bool>::Invoke(!0)
        L_001c: brfalse L_0025
        L_0021: ldloc.0 
        L_0022: ldc.i4.1 
        L_0023: add.ovf 
        L_0024: stloc.0 
        L_0025: ldloc.2 
        L_0026: callvirt instance bool [mscorlib]System.Collections.IEnumerator::MoveNext()
        L_002b: brtrue.s L_000e
        L_002d: leave L_003f
        L_0032: ldloc.2 
        L_0033: brfalse L_003e
        L_0038: ldloc.2 
        L_0039: callvirt instance void [mscorlib]System.IDisposable::Dispose()
        L_003e: endfinally 
        L_003f: ldloc.0 
        L_0040: ret 
        .try L_0009 to L_0032 finally handler L_0032 to L_003f
    }
    
    
    testListCount():
    
    .method private hidebysig static int32 testListCount<TSource>(class [mscorlib]System.Collections.Generic.List`1<!!TSource> source, class [mscorlib]System.Func`2<!!TSource, bool> predicate) cil managed
    {
        .maxstack 8
        .locals init (
            [0] int32 count,
            [1] !!TSource element,
            [2] valuetype [mscorlib]System.Collections.Generic.List`1/Enumerator<!!TSource> CS$5$0000)
        L_0000: ldc.i4.0 
        L_0001: stloc.0 
        L_0002: ldarg.0 
        L_0003: callvirt instance valuetype [mscorlib]System.Collections.Generic.List`1/Enumerator<!0> [mscorlib]System.Collections.Generic.List`1<!!TSource>::GetEnumerator()
        L_0008: stloc.2 
        L_0009: br L_0026
        L_000e: ldloca.s CS$5$0000
        L_0010: call instance !0 [mscorlib]System.Collections.Generic.List`1/Enumerator<!!TSource>::get_Current()
        L_0015: stloc.1 
        L_0016: ldarg.1 
        L_0017: ldloc.1 
        L_0018: callvirt instance !1 [mscorlib]System.Func`2<!!TSource, bool>::Invoke(!0)
        L_001d: brfalse L_0026
        L_0022: ldloc.0 
        L_0023: ldc.i4.1 
        L_0024: add.ovf 
        L_0025: stloc.0 
        L_0026: ldloca.s CS$5$0000
        L_0028: call instance bool [mscorlib]System.Collections.Generic.List`1/Enumerator<!!TSource>::MoveNext()
        L_002d: brtrue.s L_000e
        L_002f: leave L_0042
        L_0034: ldloca.s CS$5$0000
        L_0036: constrained [mscorlib]System.Collections.Generic.List`1/Enumerator<!!TSource>
        L_003c: callvirt instance void [mscorlib]System.IDisposable::Dispose()
        L_0041: endfinally 
        L_0042: ldloc.0 
        L_0043: ret 
        .try L_0009 to L_0034 finally handler L_0034 to L_0042
    }
    

    I think the important lines here are the calls to get_Current() and MoveNext().

    In the first case it is:

    callvirt instance !0 [mscorlib]System.Collections.Generic.IEnumerator`1<!!TSource>::get_Current()
    callvirt instance bool [mscorlib]System.Collections.IEnumerator::MoveNext()
    

    And in the second case it is:

    call instance !0 [mscorlib]System.Collections.Generic.List`1/Enumerator<!!TSource>::get_Current()
    call instance bool [mscorlib]System.Collections.Generic.List`1/Enumerator<!!TSource>::MoveNext()
    

    Importantly, in the second case non-virtual calls are made, which can be significantly faster than virtual (interface) calls when they sit inside a loop, as they do here.
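    A quick way to confirm that both loops walk the same enumerator type, just reached in different ways: asking a List<T> for its enumerator through the IEnumerable<T> interface hands back a boxed List<T>.Enumerator, whose MoveNext/Current can then only be reached through interface calls. This is a minimal sketch for illustration; the type check relies on List<T>'s implementation for a non-empty list rather than on a documented guarantee.

```csharp
using System;
using System.Collections.Generic;

public static class EnumeratorBoxingDemo
{
    public static void Main()
    {
        var list = new List<int> { 10, 20, 30 };

        // Direct call: binds to List<T>.GetEnumerator(), which returns a struct.
        List<int>.Enumerator direct = list.GetEnumerator();

        // Through the interface: the same enumerator type, but boxed and typed as
        // IEnumerator<int>, so every MoveNext/Current goes through interface dispatch.
        IEnumerable<int> sequence = list;
        IEnumerator<int> viaInterface = sequence.GetEnumerator();

        Console.WriteLine(direct.GetType() == viaInterface.GetType()); // True for a non-empty list
        Console.WriteLine(viaInterface is List<int>.Enumerator);        // True: the struct was boxed
    }
}
```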

  • 2020-12-08 10:50

    Sarge Borsch gave the proper answer in the comments, but without further explanation.

    The problem lies in the fact that the IL must be compiled to native code (x86, for example) by the JIT compiler on the first run. As a result, your measurement includes both what you wanted to test and the compilation time. And since most of the things used by the second test (the list enumerator, the Name property getter, etc.) will already have been compiled during the first test, the first test is more affected by compilation.

    The solution is to do a "warm-up": run your code once without measuring it, usually with just one iteration, simply to get it compiled. Then start the stopwatch and run it again for real, with as many iterations as needed to get a long enough duration (one second, for example).
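    The warm-up procedure described above can be sketched as a small helper. The name MeasureAfterWarmUp and its parameters are hypothetical, chosen for this example: it runs the action once untimed so that JIT compilation happens outside the measurement, then times the requested number of iterations with Stopwatch.

```csharp
using System;
using System.Diagnostics;

public static class WarmUpBenchmark
{
    // Hypothetical helper: runs `action` once untimed (to trigger JIT compilation),
    // then returns the elapsed time for `iterations` further runs.
    public static TimeSpan MeasureAfterWarmUp(Action action, int iterations)
    {
        if (action == null) throw new ArgumentNullException(nameof(action));
        if (iterations <= 0) throw new ArgumentOutOfRangeException(nameof(iterations));

        action(); // warm-up: not timed

        var stopwatch = Stopwatch.StartNew();
        for (int i = 0; i < iterations; i++)
            action();
        stopwatch.Stop();

        return stopwatch.Elapsed;
    }

    public static void Main()
    {
        int calls = 0;
        TimeSpan elapsed = MeasureAfterWarmUp(() => calls++, 1000);

        // The action ran 1001 times: once for the warm-up and 1000 times while timed.
        Console.WriteLine("calls = " + calls + ", elapsed = " + elapsed.TotalMilliseconds + " ms");
    }
}
```

    In a real comparison, each candidate (for example Count(predicate) and Where(predicate).Count()) would get its own call to the helper, so that neither one benefits from code the other has already compiled.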
