Computing the mean of a list efficiently in Haskell

傲寒 2021-02-04 11:21

I've designed a function to compute the mean of a list. Although it works fine, I think it may not be the best solution because it uses two functions rather than one. Is it …

6 answers
  • 2021-02-04 11:37

    For those who are curious to know what glowcoder's and Assaf's approach would look like in Haskell, here's one translation:

    avg [] = 0
    avg x@(t:ts) = let xlen = toRational $ length x
                       tslen = toRational $ length ts
                       prevAvg = avg ts
                   in (toRational t) / xlen + prevAvg * tslen / xlen
    

    This ensures that each step has the "average so far" correctly calculated, but at the cost of a lot of redundant multiplying and dividing by lengths, and of recomputing length at every step, which makes the whole function quadratic in the length of the list. No seasoned Haskeller would write it this way.
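    A runnable check of this first version; the type signature is an addition (the answer gives none). It shows that the versions in this answer return an exact Rational, which can be converted with fromRational when a Double is wanted:

    ```haskell
    -- avg as written above, with an explicit signature added (an assumption):
    -- it accepts any Real element type and returns an exact Rational.
    avg :: Real a => [a] -> Rational
    avg [] = 0
    avg x@(t:ts) =
      let xlen    = toRational (length x)   -- recomputed at every step: O(n^2) overall
          tslen   = toRational (length ts)
          prevAvg = avg ts
      in toRational t / xlen + prevAvg * tslen / xlen

    main :: IO ()
    main = do
      print (avg [1, 2, 3, 4 :: Int])                          -- 5 % 2
      print (fromRational (avg [1, 2, 3, 4 :: Int]) :: Double) -- 2.5
    ```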

    An only slightly better way is:

    avg2 [] = 0
    avg2 x = fst $ avg_ x
        where 
          avg_ [] = (toRational 0, toRational 0)
          avg_ (t:ts) = let
               (prevAvg, prevLen) = avg_ ts
               curLen = prevLen + 1
               curAvg = (toRational t) / curLen + prevAvg * prevLen / curLen
            in (curAvg, curLen)
    

    This avoids repeated length calculation. But it requires a helper function, which is precisely what the original poster is trying to avoid. And it still requires a whole bunch of canceling out of length terms.

    To avoid the cancelling out of lengths, we can just build up the sum and length and divide at the end:

    avg3 [] = 0
    avg3 x = (toRational total) / (toRational len)
        where 
          (total, len) = avg_ x
          avg_ [] = (0, 0)
          avg_ (t:ts) = let 
              (prevSum, prevLen) = avg_ ts
           in (prevSum + t, prevLen + 1)
    

    And this can be much more succinctly written as a foldr:

    avg4 [] = 0
    avg4 x = (toRational total) / (toRational len)
        where
          (total, len) = foldr avg_ (0,0) x
          avg_ t (prevSum, prevLen) = (prevSum + t, prevLen + 1)
    

    which can be further simplified as per the posts above.

    Fold really is the way to go here.
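    As a sketch of where the fold leads, here is a strict left-fold variant of avg4; the name avg5 and the use of foldl' with bang patterns are additions, not part of the answer. Because avg4's foldr pattern-matches on the pair, it has to reach the end of the list before combining anything, so it uses stack space proportional to the list length; foldl' with forced accumulator components avoids that build-up:

    ```haskell
    {-# LANGUAGE BangPatterns #-}
    import Data.List (foldl')

    -- Hypothetical strict variant of avg4: foldl' walks the list from the left,
    -- and the bang patterns force the running sum and count at every step.
    avg5 :: Real a => [a] -> Rational
    avg5 [] = 0
    avg5 xs = toRational total / toRational len
      where
        (total, len) = foldl' step (0, 0 :: Integer) xs
        step (!s, !n) t = (s + t, n + 1)

    main :: IO ()
    main = print (fromRational (avg5 [1 .. 1000000 :: Int]) :: Double) -- 500000.5
    ```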

  • 2021-02-04 11:38

    While I am not sure whether or not it would be 'best' to write it in one function, it can be done as follows:

    If you know the length (let's call it 'n' here) in advance, it's easy: you can calculate how much each value 'adds' to the average, which is value/length. This works because avg(x1, x2, x3) = sum(x1, x2, x3)/length = (x1 + x2 + x3)/3 = x1/3 + x2/3 + x3/3.

    If you don't know the length in advance, it's a little trickier.

    Let's say we use the list {x1, x2, x3} without knowing that n = 3.

    The first iteration just gives x1 (since we assume n = 1). The second iteration divides the existing average by 2 and adds x2/2, so now we have x1/2 + x2/2.

    At the third iteration we have n = 3 and want x1/3 + x2/3 + x3/3, but we have x1/2 + x2/2, so we multiply by (n-1) and divide by n to get x1/3 + x2/3, and then add the current value (x3) divided by n to end up with x1/3 + x2/3 + x3/3.

    Generally:

    Given the average (arithmetic mean, avg) of n-1 items, if you want to add one item (newval) to the average, the equation is:

    avg*(n-1)/n + newval/n. The equation can be proven mathematically using induction.
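    Writing \bar{x}_{n-1} for avg and x_n for newval, the update step can be checked directly against the definition of the mean (a worked derivation added for illustration):

    ```latex
    \bar{x}_n = \frac{1}{n}\sum_{i=1}^{n} x_i
              = \frac{(n-1)\,\bar{x}_{n-1} + x_n}{n}
              = \bar{x}_{n-1}\cdot\frac{n-1}{n} + \frac{x_n}{n}
    ```

    The middle step uses \sum_{i=1}^{n-1} x_i = (n-1)\bar{x}_{n-1}. Starting from \bar{x}_1 = x_1 and applying the update once per element, induction on n shows the final value is the mean of the whole list.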

    Hope this helps.

    Note: this solution is less efficient than simply summing the values and dividing by the total length, as you do in your example.
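    A sketch of the running-mean update as a strict left fold over a list of Double values; the name runningMean is illustrative. The accumulator is the pair (mean so far, count so far), and each step applies avg*(n-1)/n + newval/n:

    ```haskell
    import Data.List (foldl')

    -- Hypothetical running (incremental) mean:
    --   avg_n = avg_{n-1} * (n-1)/n + x_n/n
    runningMean :: [Double] -> Double
    runningMean = fst . foldl' step (0, 0 :: Int)
      where
        step (avg, n) x =
          let n'   = n + 1
              k    = fromIntegral n'
              avg' = avg * (k - 1) / k + x / k
          -- forcing avg' also forces k and therefore n', so neither piles up thunks
          in avg' `seq` (avg', n')

    main :: IO ()
    main = print (runningMean [1, 2, 3, 4]) -- 2.5
    ```

    Unlike dividing a sum by a length, this returns the initial accumulator value 0 for an empty list instead of NaN, and, as noted above, it does more arithmetic per element.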

  • 2021-02-04 11:43

    About the best you can do is this version:

    import qualified Data.Vector.Unboxed as U
    
    data Pair = Pair {-# UNPACK #-}!Int {-# UNPACK #-}!Double
    
    mean :: U.Vector Double -> Double
    mean xs = s / fromIntegral n
      where
        Pair n s       = U.foldl' k (Pair 0 0) xs
        k (Pair n s) x = Pair (n+1) (s+x)
    
    main = print (mean $ U.enumFromN 1 (10^7))
    

    It fuses to an optimal loop in Core (the best Haskell you could write):

    main_$s$wfoldlM'_loop :: Int#
                                  -> Double#
                                  -> Double#
                                  -> Int#
                                  -> (# Int#, Double# #)    
    main_$s$wfoldlM'_loop =
      \ (sc_s1nH :: Int#)
        (sc1_s1nI :: Double#)
        (sc2_s1nJ :: Double#)
        (sc3_s1nK :: Int#) ->
        case ># sc_s1nH 0 of _ {
          False -> (# sc3_s1nK, sc2_s1nJ #);
          True ->
            main_$s$wfoldlM'_loop
              (-# sc_s1nH 1)
              (+## sc1_s1nI 1.0)
              (+## sc2_s1nJ sc1_s1nI)
              (+# sc3_s1nK 1)
        }
    

    And the following assembly:

    Main_mainzuzdszdwfoldlMzqzuloop_info:
    .Lc1pN:
            testq %r14,%r14
            jg .Lc1pQ
            movq %rsi,%rbx
            movsd %xmm6,%xmm5
            jmp *(%rbp)
    .Lc1pQ:
            leaq 1(%rsi),%rax
            movsd %xmm6,%xmm0
            addsd %xmm5,%xmm0
            movsd %xmm5,%xmm7
            addsd .Ln1pS(%rip),%xmm7
            decq %r14
            movsd %xmm7,%xmm5
            movsd %xmm0,%xmm6
            movq %rax,%rsi
            jmp Main_mainzuzdszdwfoldlMzqzuloop_info
    

    This relies on Data.Vector. For example:

    $ ghc -Odph --make A.hs -fforce-recomp
    [1 of 1] Compiling Main             ( A.hs, A.o )
    Linking A ...
    $ time ./A
    5000000.5
    ./A  0.04s user 0.00s system 93% cpu 0.046 total
    

    See the efficient implementations in the statistics package.

  • 2021-02-04 11:47

    Your solution is good; using two functions is no worse than using one. Still, you might put the tail-recursive helper in a where clause.
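    A minimal sketch of that suggestion. The question's own code is not shown here, so this assumes a hypothetical two-function solution (a tail-recursive sum-and-count helper plus a wrapper); the names meanList and go are illustrative. The helper simply becomes local to the top-level function:

    ```haskell
    {-# LANGUAGE BangPatterns #-}

    -- One top-level function; the tail-recursive helper lives in a where clause.
    meanList :: [Double] -> Double
    meanList = go 0 0
      where
        -- go carries the running sum and element count; the bang patterns keep
        -- both evaluated so no chain of thunks builds up.
        go :: Double -> Int -> [Double] -> Double
        go !s !n []       = s / fromIntegral n
        go !s !n (x : xs) = go (s + x) (n + 1) xs

    main :: IO ()
    main = print (meanList [1, 2, 3, 4]) -- 2.5
    ```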

    But if you want to do it in one line:

    calcMeanList = uncurry (/) . foldr (\e (s,c) -> (e+s,c+1)) (0,0)
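    A usage sketch of this one-liner; the type signature is an addition, not part of the answer. Because the count has the same type as the elements, the element type must be Fractional, and an empty list yields 0 / 0, which is NaN for Double:

    ```haskell
    -- The one-liner above, with a type signature added.
    calcMeanList :: Fractional a => [a] -> a
    calcMeanList = uncurry (/) . foldr (\e (s, c) -> (e + s, c + 1)) (0, 0)

    main :: IO ()
    main = do
      print (calcMeanList [1, 2, 3, 4 :: Double]) -- 2.5
      print (calcMeanList ([] :: [Double]))       -- NaN (0 / 0)
    ```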
    
  • 2021-02-04 11:49

    To follow up on Don's 2010 reply, on GHC 8.0.2 we can do much better. First let's try his version.

    module Main (main) where
    
    import System.CPUTime.Rdtsc (rdtsc)
    import Text.Printf (printf)
    import qualified Data.Vector.Unboxed as U
    
    data Pair = Pair {-# UNPACK #-}!Int {-# UNPACK #-}!Double
    
    mean' :: U.Vector Double -> Double
    mean' xs = s / fromIntegral n
      where
        Pair n s       = U.foldl' k (Pair 0 0) xs
        k (Pair n s) x = Pair (n+1) (s+x)
    
    main :: IO ()
    main = do
      s <- rdtsc
      let r = mean' (U.enumFromN 1 30000000)
      e <- seq r rdtsc
      print (e - s, r)
    

    This gives us

    [nix-shell:/tmp]$ ghc -fforce-recomp -O2 MeanD.hs -o MeanD && ./MeanD +RTS -s
    [1 of 1] Compiling Main             ( MeanD.hs, MeanD.o )
    Linking MeanD ...
    (372877482,1.50000005e7)
         240,104,176 bytes allocated in the heap
               6,832 bytes copied during GC
              44,384 bytes maximum residency (1 sample(s))
              25,248 bytes maximum slop
                 230 MB total memory in use (0 MB lost due to fragmentation)
    
                                         Tot time (elapsed)  Avg pause  Max pause
      Gen  0         1 colls,     0 par    0.000s   0.000s     0.0000s    0.0000s
      Gen  1         1 colls,     0 par    0.006s   0.006s     0.0062s    0.0062s
    
      INIT    time    0.000s  (  0.000s elapsed)
      MUT     time    0.087s  (  0.087s elapsed)
      GC      time    0.006s  (  0.006s elapsed)
      EXIT    time    0.006s  (  0.006s elapsed)
      Total   time    0.100s  (  0.099s elapsed)
    
      %GC     time       6.2%  (6.2% elapsed)
    
      Alloc rate    2,761,447,559 bytes per MUT second
    
      Productivity  93.8% of total user, 93.8% of total elapsed
    

    However, the computation is simple enough that vector should not be needed: optimal code ought to be achievable just by inlining the list generation into the fold. Luckily, GHC can do this for us[0].

    module Main (main) where
    
    import System.CPUTime.Rdtsc (rdtsc)
    import Text.Printf (printf)
    import Data.List (foldl')
    
    data Pair = Pair {-# UNPACK #-}!Int {-# UNPACK #-}!Double
    
    mean' :: [Double] -> Double
    mean' xs = v / fromIntegral l
      where
        Pair l v = foldl' f (Pair 0 0) xs
        f (Pair l' v') x = Pair (l' + 1) (v' + x)
    
    main :: IO ()
    main = do
      s <- rdtsc
      let r = mean' $ fromIntegral <$> [1 :: Int .. 30000000]
          -- This is slow!
          -- r = mean' [1 .. 30000000]
      e <- seq r rdtsc
      print (e - s, r)
    

    This gives us:

    [nix-shell:/tmp]$ ghc -fforce-recomp -O2 MeanD.hs -o MeanD && ./MeanD +RTS -s
    [1 of 1] Compiling Main             ( MeanD.hs, MeanD.o )
    Linking MeanD ...
    (128434754,1.50000005e7)
             104,064 bytes allocated in the heap
               3,480 bytes copied during GC
              44,384 bytes maximum residency (1 sample(s))
              17,056 bytes maximum slop
                   1 MB total memory in use (0 MB lost due to fragmentation)
    
                                         Tot time (elapsed)  Avg pause  Max pause
      Gen  0         0 colls,     0 par    0.000s   0.000s     0.0000s    0.0000s
      Gen  1         1 colls,     0 par    0.000s   0.000s     0.0000s    0.0000s
    
      INIT    time    0.000s  (  0.000s elapsed)
      MUT     time    0.032s  (  0.032s elapsed)
      GC      time    0.000s  (  0.000s elapsed)
      EXIT    time    0.000s  (  0.000s elapsed)
      Total   time    0.033s  (  0.032s elapsed)
    
      %GC     time       0.1%  (0.1% elapsed)
    
      Alloc rate    3,244,739 bytes per MUT second
    
      Productivity  99.8% of total user, 99.8% of total elapsed
    

    [0]: Notice that I had to map fromIntegral over an [Int]: without this, GHC fails to eliminate the intermediate [Double] and the program is much slower. That is somewhat sad; I don't understand why GHC fails to inline (or decides it does not need to) without this. If you have a genuine collection of fractional values, this hack won't work for you and vector may still be necessary.

  • 2021-02-04 12:01

    When I saw your question, I immediately thought "you want a fold there!"

    And sure enough, a similar question has been asked before on StackOverflow, and this answer has a very performant solution, which you can test in an interactive environment like GHCi:

    ghci> import Data.List (foldl')
    ghci> :{
    ghci| let avg l = let (t,n) = foldl' (\(b,c) a -> (a+b,c+1)) (0,0) l
    ghci|             in realToFrac t / realToFrac n
    ghci| :}
    ghci> avg ([1,2,3,4]::[Int])
    2.5
    ghci> avg ([1,2,3,4]::[Double])
    2.5
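    One caveat, stated as a hedge rather than a measured result: foldl' only evaluates the accumulator pair to weak head normal form, so in unoptimised code (such as GHCi) the sum and count inside the pair can still build up as thunks. A sketch that forces both components with bang patterns, written as a compiled module rather than a GHCi session; the Int count and the type signature are additions:

    ```haskell
    {-# LANGUAGE BangPatterns #-}
    import Data.List (foldl')

    -- Same fold as the GHCi snippet, but the bang patterns on the tuple
    -- components force the running sum and count at every step.
    avg :: (Real a, Fractional b) => [a] -> b
    avg l = realToFrac t / realToFrac n
      where
        (t, n) = foldl' (\(!b, !c) a -> (a + b, c + 1 :: Int)) (0, 0) l

    main :: IO ()
    main = do
      print (avg ([1, 2, 3, 4] :: [Int]) :: Double)    -- 2.5
      print (avg ([1, 2, 3, 4] :: [Double]) :: Double) -- 2.5
    ```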
    