What is the correct way to perform constant-space nested loops in Haskell?

感动是毒 2021-02-02 11:43

There are two obvious, "idiomatic" ways to perform nested loops in Haskell: using the list monad or using forM_ to replace traditional fors. I've se

2 Answers
  • 2021-02-02 12:24

    Writing tight mutating code with GHC can be tricky sometimes. I'm going to write about a couple of different things, probably in a manner that is more rambling and tl;dr than I would prefer.

    For starters, we should use GHC 7.10 in any case, since otherwise the forM_ and list monad solutions never fuse.

    Also, I replaced MV.write with MV.unsafeWrite, partly because it's faster, but more importantly because it removes some of the clutter from the resulting Core. From now on, runtime statistics refer to the code with unsafeWrite.

    The dreaded let floating

    Even with GHC 7.10, we should first notice all those [0..times-1] and [0..side-1] expressions, because they will ruin performance every time unless we take the necessary steps. The issue is that they are constant ranges, and -ffull-laziness (which is enabled by default with -O) floats them out to top level. This prevents list fusion, and iterating over an Int# range is cheaper than iterating over a list of boxed Ints anyway, so it's a really bad optimization.
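
    To make the floating concrete, here is a hand-written, source-level approximation of what -ffull-laziness does to such a loop. It is only a sketch: the real transformation happens on Core, the names sumGrid and sumGridFloated are hypothetical, and an STRef accumulator stands in for the mutable vector so that the example needs nothing beyond base.

    ```haskell
    import Control.Monad (forM_)
    import Control.Monad.ST (runST)
    import Data.STRef (modifySTRef', newSTRef, readSTRef)

    -- What we write: the ranges sit inside the loops, where list fusion can
    -- turn them into plain counting loops.
    sumGrid :: Int -> Int -> Int
    sumGrid times side = runST $ do
      ref <- newSTRef 0
      forM_ [0 .. times - 1] $ \_ ->
        forM_ [0 .. side - 1] $ \y ->
          forM_ [0 .. side - 1] $ \x ->
            modifySTRef' ref (+ (y * side + x))
      readSTRef ref

    -- Roughly what full laziness makes of it: [0 .. side - 1] does not depend
    -- on the loop variables, so it is let-bound once, outside the loops. It is
    -- now a shared list of boxed Ints that every pass walks, and because it is
    -- used more than once GHC will generally not inline it into its consumers,
    -- so fusion cannot happen.
    sumGridFloated :: Int -> Int -> Int
    sumGridFloated times side = runST $ do
      let innerRange = [0 .. side - 1]
      ref <- newSTRef 0
      forM_ [0 .. times - 1] $ \_ ->
        forM_ innerRange $ \y ->
          forM_ innerRange $ \x ->
            modifySTRef' ref (+ (y * side + x))
      readSTRef ref

    main :: IO ()
    main = print (sumGrid 10 100, sumGridFloated 10 100)
    ```

    Both functions compute the same sum. Here side is a parameter, so the range can only float to the start of the function; in the question side and times are constants, which is why the ranges end up at top level (like main4 below).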

    Let's see some runtimes in seconds for the unchanged (aside from using unsafeWrite) code. ghc -O2 -fllvm is used, and I use +RTS -s for timing.

    test_a: 1.6
    test_b: 6.2
    test_c: 0.6
    

    For GHC Core viewing I used ghc -O2 -ddump-simpl -dsuppress-all -dno-suppress-type-signatures.

    In the case of test_a, the [0..99] ranges are lifted out:

    main4 :: [Int]
    main4 = eftInt 0 99 -- means "enumFromTo" for Int.
    

    although the outermost [0..99999] loop is fused into a tail-recursive helper:

    letrec {
              a3_s7xL :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
              a3_s7xL =
                \ (x_X5zl :: Int#) (s1_X4QY :: State# RealWorld) ->
                  case a2_s7xF 0 s1_X4QY of _ { (# ipv2_a4NA, ipv3_a4NB #) ->
                  case x_X5zl of wild_X1S {
                    __DEFAULT -> a3_s7xL (+# wild_X1S 1) ipv2_a4NA;
                    99999 -> (# ipv2_a4NA, () #)
                  }
                  }; }
    

    In the case of test_b, again only the [0..99] ranges are lifted. However, test_b is much slower, because it has to build and sequence actual [IO ()] lists. At least GHC is sensible enough to build only a single [IO ()] for the two inner loops and then sequence it 100000 times.

     let {
              lvl7_s4M5 :: [IO ()]
              lvl7_s4M5 = -- omitted
            letrec {
              a2_s7Av :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
              a2_s7Av =
                \ (x_a5xi :: Int#) (eta_B1 :: State# RealWorld) ->
                  letrec {
                    a3_s7Au
                      :: [IO ()] -> State# RealWorld -> (# State# RealWorld, () #)
                    a3_s7Au =
                      \ (ds_a4Nu :: [IO ()]) (eta1_X1c :: State# RealWorld) ->
                        case ds_a4Nu of _ {
                          [] ->
                            case x_a5xi of wild1_X1y {
                              __DEFAULT -> a2_s7Av (+# wild1_X1y 1) eta1_X1c;
                              99999 -> (# eta1_X1c, () #)
                            };
                          : y_a4Nz ys_a4NA ->
                            case (y_a4Nz `cast` ...) eta1_X1c
                            of _ { (# ipv2_a4Nf, ipv3_a4Ng #) ->
                            a3_s7Au ys_a4NA ipv2_a4Nf
                            }
                        }; } in
                  a3_s7Au lvl7_s4M5 eta_B1; } in
    -- omitted
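
    The structure GHC chose here (one shared list of actions for the two inner loops, sequenced once per outer iteration) can be reproduced by hand. A rough sketch, with the hypothetical names innerActions and sumGridActions and an STRef in place of the vector; the list comprehension is just the list monad in comprehension syntax:

    ```haskell
    import Control.Monad (replicateM_)
    import Control.Monad.ST (ST, runST)
    import Data.STRef (STRef, modifySTRef', newSTRef, readSTRef)

    -- One action per (y, x) pair of the two inner loops.
    innerActions :: Int -> STRef s Int -> [ST s ()]
    innerActions side ref =
      [ modifySTRef' ref (+ (y * side + x)) | y <- [0 .. side - 1], x <- [0 .. side - 1] ]

    sumGridActions :: Int -> Int -> Int
    sumGridActions times side = runST $ do
      ref <- newSTRef 0
      -- The list is materialised on the first pass and kept alive, because
      -- every later pass walks the same cons cells again.
      let acts = innerActions side ref
      replicateM_ times (sequence_ acts)
      readSTRef ref

    main :: IO ()
    main = print (sumGridActions 10 100)
    ```

    Every outer iteration pays for following side * side list cells and calling a boxed action per cell, which is the kind of overhead that makes test_b slow here.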
    

    How can we remedy this? We could nuke the problem with {-# OPTIONS_GHC -fno-full-laziness #-}. This indeed helps a lot in our case:

    test_a: 0.5
    test_b: 0.48
    test_c: 0.5
    

    Alternatively, we could fiddle around with INLINE pragmas. Apparently inlining functions after the let floating is done preserves good performance. I found that GHC inlines our test functions even without a pragma, but an explicit pragma causes it to inline only after let floating. For example, this results in good performance without -fno-full-laziness:

    test_a mvec = 
        forM_ [0..times-1] $ \ n -> 
            forM_ [0..side-1] $ \ y -> 
                forM_ [0..side-1] $ \ x -> 
                    MV.unsafeWrite mvec (y*side+x) 1
    {-# INLINE test_a #-}
    

    But inlining too early results in poor performance:

    test_a mvec = 
        forM_ [0..times-1] $ \ n -> 
            forM_ [0..side-1] $ \ y -> 
                forM_ [0..side-1] $ \ x -> 
                    MV.unsafeWrite mvec (y*side+x) 1
    {-# INLINE [~2] test_a #-} -- "inline before the first phase please"
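
    For reference, the phase annotations follow the GHC user's guide: simplifier phases are numbered downwards and end at phase 0; INLINE [n] f means "do not inline f before phase n, inline it eagerly from phase n on", and INLINE [~n] f means "inline f only before phase n". The module below is just a compilable sketch attaching the three forms to hypothetical helpers (plain, late, early); the pragmas change only when GHC may inline, not what the code computes.

    ```haskell
    import Control.Monad (forM_)
    import Control.Monad.ST (ST, runST)
    import Data.STRef (modifySTRef', newSTRef, readSTRef)

    sumBelow :: Int -> ST s Int
    sumBelow n = do
      ref <- newSTRef 0
      forM_ [0 .. n - 1] $ \i -> modifySTRef' ref (+ i)
      readSTRef ref

    -- Plain INLINE: may be inlined in every phase.
    plain :: Int -> Int
    plain n = runST (sumBelow n)
    {-# INLINE plain #-}

    -- INLINE [2]: not inlined before phase 2, inlined eagerly from phase 2 on.
    late :: Int -> Int
    late n = runST (sumBelow n)
    {-# INLINE [2] late #-}

    -- INLINE [~2]: inlined only before phase 2, never from phase 2 on.
    early :: Int -> Int
    early n = runST (sumBelow n)
    {-# INLINE [~2] early #-}

    main :: IO ()
    main = print (plain 100, late 100, early 100)
    ```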
    

    The problem with this INLINE solution is that it's rather fragile in the face of GHC's floating onslaught. For example, manual inlining does not preserve performance. The following code is slow because, like INLINE [~2], it gives GHC a chance to float the ranges out:

    main = do
        let vec = V.generate (side*side) (const 0)
        mvec <- V.unsafeThaw vec :: IO (MV.MVector (PrimState IO) Int)
        forM_ [0..times-1] $ \ n -> 
            forM_ [0..side-1] $ \ y -> 
                forM_ [0..side-1] $ \ x -> 
                    MV.unsafeWrite mvec (y*side+x) 1    
    

    So what should we do?

    First, I think using -fno-full-laziness is a perfectly viable and even preferable option for those who'd like to write high performance code and have a good idea what they are doing. For example, it's used in unordered-containers. With it we have more precise control over sharing, and we can always just float out or inline manually.
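
    With full laziness off, a loop-invariant value is shared only if we bind it outside the loop ourselves. A minimal sketch of that manual float-out; weights, recomputed and shared are hypothetical names, and what matters is only where the let sits, not the particular computation.

    ```haskell
    {-# OPTIONS_GHC -fno-full-laziness #-}
    import Control.Monad (forM_)
    import Control.Monad.ST (runST)
    import Data.STRef (modifySTRef', newSTRef, readSTRef)

    -- Some loop-invariant value that costs O(side) to compute.
    weights :: Int -> [Int]
    weights side = map (\k -> (k * k) `mod` 7) [0 .. side - 1]

    -- With full laziness off, nothing moves `sum (weights side)` out of the
    -- loop body, so it is presumably recomputed on every outer iteration.
    recomputed :: Int -> Int -> Int
    recomputed times side = runST $ do
      ref <- newSTRef 0
      forM_ [0 .. times - 1] $ \_ ->
        modifySTRef' ref (+ sum (weights side))
      readSTRef ref

    -- Floated out by hand: one thunk, evaluated once, shared by all iterations.
    shared :: Int -> Int -> Int
    shared times side = runST $ do
      let total = sum (weights side)
      ref <- newSTRef 0
      forM_ [0 .. times - 1] $ \_ -> modifySTRef' ref (+ total)
      readSTRef ref

    main :: IO ()
    main = print (recomputed 1000 1000, shared 1000 1000)
    ```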

    For more ordinary code, I believe there's nothing wrong with using Control.Monad.Loop or any other package that provides the functionality. Many Haskell users are not scrupulous about depending on small "fringe" libraries. We can also just write our own for function in whatever generality we need. For instance, the following performs just as well as the other solutions:

    {-# LANGUAGE BangPatterns #-} -- needed for the !i pattern
    
    for :: Monad m => a -> (a -> Bool) -> (a -> a) -> (a -> m ()) -> m ()
    for init while step body = go init where
      go !i | while i = body i >> go (step i)
      go _ = return ()
    {-# INLINE for #-}
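
    As a usage sketch, here is the question's triple loop written with this for, plus a loop with a non-unit step that the start/condition/step interface makes easy. Assumptions: an unboxed ST array from the array package stands in for the question's mutable vector, and fillOnes and markEveryOther are hypothetical names.

    ```haskell
    {-# LANGUAGE BangPatterns #-}
    import Data.Array.ST (newArray, runSTUArray, writeArray)
    import Data.Array.Unboxed (UArray, elems)

    -- The generalised for from above.
    for :: Monad m => a -> (a -> Bool) -> (a -> a) -> (a -> m ()) -> m ()
    for init while step body = go init where
      go !i | while i = body i >> go (step i)
      go _ = return ()
    {-# INLINE for #-}

    -- The question's loop nest: write 1 into every cell, `times` times over.
    fillOnes :: Int -> Int -> UArray Int Int
    fillOnes times side = runSTUArray $ do
      arr <- newArray (0, side * side - 1) 0
      for 0 (< times) (+ 1) $ \_ ->
        for 0 (< side) (+ 1) $ \y ->
          for 0 (< side) (+ 1) $ \x ->
            writeArray arr (y * side + x) 1
      return arr

    -- A non-unit step: mark every other index, counting down from the last one.
    markEveryOther :: Int -> UArray Int Int
    markEveryOther n = runSTUArray $ do
      arr <- newArray (0, n - 1) 0
      for (n - 1) (>= 0) (subtract 2) $ \i -> writeArray arr i 1
      return arr

    main :: IO ()
    main = do
      print (sum (elems (fillOnes 10 100)))
      print (elems (markEveryOther 5))
    ```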
    

    Looping in really constant space

    I was at first very puzzled by the +RTS -s data on heap allocation. test_a allocated non-trivially with -fno-full-laziness, as did test_c with full laziness, and these allocations scaled linearly with the number of outer (times) iterations, whereas test_b and test_c without full laziness allocated only for the vector:

    -- with -fno-full-laziness, no INLINE pragmas
    test_a: 242,521,008 bytes
    test_b: 121,008 bytes
    test_c: 121,008 bytes -- but 240,120,984 with full laziness!
    

    Also, INLINE pragmas for test_c did not help at all in this case.

    I spent some time trying to find signs of heap allocation in the Core for the relevant programs, without success, until the realization struck me: GHC stack frames are on the heap, including the frames of the main thread, and the functions that were doing heap allocation were essentially running the thrice-nested loops in at most three stack frames. The heap allocation registered by +RTS -s is just the constant popping and pushing of stack frames.

    This is pretty much apparent from the Core for the following code:

    {-# OPTIONS_GHC -fno-full-laziness #-}
    
    -- ...
    
    test_a mvec = 
        forM_ [0..times-1] $ \ n -> 
            forM_ [0..side-1] $ \ y -> 
                forM_ [0..side-1] $ \ x -> 
                    MV.unsafeWrite mvec (y*side+x) 1
    main = do
        let vec = V.generate (side*side) (const 0)
        mvec <- V.unsafeThaw vec :: IO (MV.MVector (PrimState IO) Int)
        test_a mvec
    

    I'm including it here in all its glory; feel free to skip it.

    main1 :: State# RealWorld -> (# State# RealWorld, () #)
    main1 =
      \ (s_a5HK :: State# RealWorld) ->
        case divInt# 9223372036854775807 8 of ww4_a5vr { __DEFAULT ->
    
        -- start of vector creation ----------------------
        case tagToEnum# (># 10000 ww4_a5vr) of _ {
          False ->
            case newByteArray# 80000 (s_a5HK `cast` ...)
            of _ { (# ipv_a5fv, ipv1_a5fw #) ->
            letrec {
              $s$wa_s8jS
                :: Int#
                   -> Int#
                   -> State# (PrimState IO)
                   -> (# State# (PrimState IO), Int #)
              $s$wa_s8jS =
                \ (sc_s8jO :: Int#)
                  (sc1_s8jP :: Int#)
                  (sc2_s8jR :: State# (PrimState IO)) ->
                  case tagToEnum# (<# sc1_s8jP 10000) of _ {
                    False -> (# sc2_s8jR, I# sc_s8jO #);
                    True ->
                      case writeIntArray# ipv1_a5fw sc_s8jO 0 (sc2_s8jR `cast` ...)
                      of s'#_a5Gn { __DEFAULT ->
                      $s$wa_s8jS (+# sc_s8jO 1) (+# sc1_s8jP 1) (s'#_a5Gn `cast` ...)
                      }
                  }; } in
            case $s$wa_s8jS 0 0 (ipv_a5fv `cast` ...)
            -- end of vector creation -------------------
    
            of _ { (# ipv6_a4Hv, ipv7_a4Hw #) ->
            letrec {
              a2_s7MJ :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
              a2_s7MJ =
                \ (x_a5Ho :: Int#) (eta_B1 :: State# RealWorld) ->
                  letrec {
                    a3_s7ME :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
                    a3_s7ME =
                      \ (x1_X5Id :: Int#) (eta1_XR :: State# RealWorld) ->
                        case ipv7_a4Hw of _ { I# dt4_a5x6 ->
                        case writeIntArray#
                               (ipv1_a5fw `cast` ...) (*# x1_X5Id 100) 1 (eta1_XR `cast` ...)
                        of s'#_a5Gn { __DEFAULT ->
                        letrec {
                          a4_s7Mz :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
                          a4_s7Mz =
                            \ (x2_X5J8 :: Int#) (eta2_X1U :: State# RealWorld) ->
                              case writeIntArray#
                                     (ipv1_a5fw `cast` ...)
                                     (+# (*# x1_X5Id 100) x2_X5J8)
                                     1
                                     (eta2_X1U `cast` ...)
                              of s'#1_X5Hf { __DEFAULT ->
                              case x2_X5J8 of wild_X2o {
                                __DEFAULT -> a4_s7Mz (+# wild_X2o 1) (s'#1_X5Hf `cast` ...);
                                99 -> (# s'#1_X5Hf `cast` ..., () #)
                              }
                              }; } in
                        case a4_s7Mz 1 (s'#_a5Gn `cast` ...)
                        of _ { (# ipv2_a4QH, ipv3_a4QI #) ->
                        case x1_X5Id of wild_X1e {
                          __DEFAULT -> a3_s7ME (+# wild_X1e 1) ipv2_a4QH;
                          99 -> (# ipv2_a4QH, () #)
                        }
                        }
                        }
                        }; } in
                  case a3_s7ME 0 eta_B1 of _ { (# ipv2_a4QH, ipv3_a4QI #) ->
                  case x_a5Ho of wild_X1a {
                    __DEFAULT -> a2_s7MJ (+# wild_X1a 1) ipv2_a4QH;
                    99999 -> (# ipv2_a4QH, () #)
                  }
                  }; } in
            a2_s7MJ 0 (ipv6_a4Hv `cast` ...)
            }
            };
          True ->
            case error
                   (unpackAppendCString#
                      "Primitive.basicUnsafeNew: length to large: "#
                      (case $wshowSignedInt 0 10000 ([])
                       of _ { (# ww5_a5wm, ww6_a5wn #) ->
                       : ww5_a5wm ww6_a5wn
                       }))
            of wild_00 {
            }
        }
        }
    
    main :: IO ()
    main = main1 `cast` ...
    
    main2 :: State# RealWorld -> (# State# RealWorld, () #)
    main2 = runMainIO1 (main1 `cast` ...)
    
    main :: IO ()
    main = main2 `cast` ...
    

    We can also nicely demonstrate the allocation of frames the following way. Let's change test_a:

    test_a mvec = 
        forM_ [0..times-1] $ \ n -> 
            forM_ [0..side-1] $ \ y -> 
                forM_ [0..side-50] $ \ x -> -- change here
                    MV.unsafeWrite mvec (y*side+x) 1
    

    Now the heap allocation stays exactly the same, because the innermost loop is tail-recursive and uses a single frame. With the following change, the heap allocation halves (to 124,921,008 bytes), because we push and pop half as many frames:

    test_a mvec = 
        forM_ [0..times-1] $ \ n -> 
            forM_ [0..side-50] $ \ y -> -- change here
                forM_ [0..side-1] $ \ x -> 
                    MV.unsafeWrite mvec (y*side+x) 1
    

    test_b and test_c (with no full laziness) instead compile to code that uses a nested case construct inside a single stack frame, and walks over the indices to see which one should be incremented. See the Core for the following main:

    {-# LANGUAGE BangPatterns #-} -- later I'll talk about this
    {-# OPTIONS_GHC -fno-full-laziness #-}
    
    main = do
        let vec = V.generate (side*side) (const 0)
        !mvec <- V.unsafeThaw vec :: IO (MV.MVector (PrimState IO) Int)
        test_c mvec
    

    Voila:

    main1 :: State# RealWorld -> (# State# RealWorld, () #)
    main1 =
      \ (s_a5Iw :: State# RealWorld) ->
        case divInt# 9223372036854775807 8 of ww4_a5vT { __DEFAULT ->
    
        -- start of vector creation ----------------------
        case tagToEnum# (># 10000 ww4_a5vT) of _ {
          False ->
            case newByteArray# 80000 (s_a5Iw `cast` ...)
            of _ { (# ipv_a5g3, ipv1_a5g4 #) ->
            letrec {
              $s$wa_s8ji
                :: Int#
                   -> Int#
                   -> State# (PrimState IO)
                   -> (# State# (PrimState IO), Int #)
              $s$wa_s8ji =
                \ (sc_s8je :: Int#)
                  (sc1_s8jf :: Int#)
                  (sc2_s8jh :: State# (PrimState IO)) ->
                  case tagToEnum# (<# sc1_s8jf 10000) of _ {
                    False -> (# sc2_s8jh, I# sc_s8je #);
                    True ->
                      case writeIntArray# ipv1_a5g4 sc_s8je 0 (sc2_s8jh `cast` ...)
                      of s'#_a5GP { __DEFAULT ->
                      $s$wa_s8ji (+# sc_s8je 1) (+# sc1_s8jf 1) (s'#_a5GP `cast` ...)
                      }
                  }; } in
            case $s$wa_s8ji 0 0 (ipv_a5g3 `cast` ...)
            of _ { (# ipv6_a4MX, ipv7_a4MY #) ->
            case ipv7_a4MY of _ { I# dt4_a5xy ->
            -- end of vector creation
    
            letrec {
              a2_s7Q6 :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
              a2_s7Q6 =
                \ (x_a5HT :: Int#) (eta_B1 :: State# RealWorld) ->
                  letrec {
                    a3_s7Q5 :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
                    a3_s7Q5 =
                      \ (x1_X5J9 :: Int#) (eta1_XP :: State# RealWorld) ->
                        letrec {
                          a4_s7MZ :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
                          a4_s7MZ =
                            \ (x2_X5Jl :: Int#) (s1_X4Xb :: State# RealWorld) ->
                              case writeIntArray#
                                     (ipv1_a5g4 `cast` ...)
                                     (+# (*# x1_X5J9 100) x2_X5Jl)
                                     1
                                     (s1_X4Xb `cast` ...)
                              of s'#_a5GP { __DEFAULT ->
    
                              -- the interesting part! ------------------
                              case x2_X5Jl of wild_X1y {
                                __DEFAULT -> a4_s7MZ (+# wild_X1y 1) (s'#_a5GP `cast` ...);
                                99 ->
                                  case x1_X5J9 of wild1_X1o {
                                    __DEFAULT -> a3_s7Q5 (+# wild1_X1o 1) (s'#_a5GP `cast` ...);
                                    99 ->
                                      case x_a5HT of wild2_X1c {
                                        __DEFAULT -> a2_s7Q6 (+# wild2_X1c 1) (s'#_a5GP `cast` ...);
                                        99999 -> (# s'#_a5GP `cast` ..., () #)
                                      }
                                  }
                              }
                              }; } in
                        a4_s7MZ 0 eta1_XP; } in
                  a3_s7Q5 0 eta_B1; } in
            a2_s7Q6 0 (ipv6_a4MX `cast` ...)
            }
            }
            };
          True ->
            case error
                   (unpackAppendCString#
                      "Primitive.basicUnsafeNew: length to large: "#
                      (case $wshowSignedInt 0 10000 ([])
                       of _ { (# ww5_a5wO, ww6_a5wP #) ->
                       : ww5_a5wO ww6_a5wP
                       }))
            of wild_00 {
            }
        }
        }
    
    main :: IO ()
    main = main1 `cast` ...
    
    main2 :: State# RealWorld -> (# State# RealWorld, () #)
    main2 = runMainIO1 (main1 `cast` ...)
    
    main :: IO ()
    main = main2 `cast` ...
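
    A loose source-level analogy for the two Core shapes above, with hypothetical functions nestedSum and flatSum over a pure accumulator (whether GHC compiles them to exactly these shapes is not guaranteed). In the first, each inner loop is an ordinary call whose result the outer loop waits for, which needs a return frame. In the second, the inner loop's exit tail-calls the next step of the enclosing loop, like the nested case in the test_c Core, so the whole iteration is one tail-recursive loop.

    ```haskell
    {-# LANGUAGE BangPatterns #-}

    -- Shape 1: the row loop calls the column loop and scrutinises its result,
    -- so it has to wait for the inner loop to return (a return frame).
    nestedSum :: Int -> Int -> Int
    nestedSum rows cols = goY 0 0
      where
        goY !y !acc
          | y >= rows = acc
          | otherwise = case goX y 0 acc of
              !acc' -> goY (y + 1) acc'
        goX !y !x !acc
          | x >= cols = acc
          | otherwise = goX y (x + 1) (acc + y * cols + x)

    -- Shape 2: a single loop over (y, x). When x runs out, it tail-calls the
    -- "increment y" step instead of returning, so nothing is pushed per row.
    flatSum :: Int -> Int -> Int
    flatSum rows cols = go 0 0 0
      where
        go !y !x !acc
          | y >= rows = acc
          | x >= cols = go (y + 1) 0 acc
          | otherwise = go y (x + 1) (acc + y * cols + x)

    main :: IO ()
    main = print (nestedSum 100 100, flatSum 100 100)
    ```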
    

    I have to admit that I basically don't know why some code avoids stack frame creation and some doesn't. I suspect that inlining "from the inside out" helps. A quick inspection also told me that Control.Monad.Loop uses a CPS encoding, which might be relevant here. However, the Monad.Loop solution is sensitive to let floating, and I couldn't determine on short notice from the Core why test_c with let floating fails to run in a single stack frame.
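
    For the curious, this is roughly what a CPS-encoded loop looks like. It is a hypothetical, minimal type written for illustration (called Loop here), not the actual definition from Control.Monad.Loop. In its Monad instance, the inner loop's "done" continuation is the outer loop's next iteration, so finishing an inner loop is a tail call into the outer one; that may be related to the single-frame Core above, but this is a guess, not something checked against the Core.

    ```haskell
    {-# LANGUAGE RankNTypes #-}
    import Control.Monad.ST (runST)
    import Data.STRef (modifySTRef', newSTRef, readSTRef)

    -- A loop is told what to do with each element (`yield`, which also receives
    -- the rest of the loop) and what to do when it is finished (`done`).
    newtype Loop m a = Loop { runLoop :: forall r. (a -> m r -> m r) -> m r -> m r }

    instance Functor (Loop m) where
      fmap f (Loop l) = Loop (\yield -> l (yield . f))

    instance Applicative (Loop m) where
      pure a = Loop (\yield done -> yield a done)
      Loop lf <*> Loop la = Loop (\yield -> lf (\f -> la (yield . f)))

    instance Monad (Loop m) where
      -- For each outer element a, run the inner loop (k a); the outer loop's
      -- remaining iterations become the inner loop's `done`.
      Loop l >>= k = Loop (\yield -> l (\a -> runLoop (k a) yield))

    -- Counts from lo to hi - 1.
    range :: Int -> Int -> Loop m Int
    range lo hi = Loop (\yield done ->
      let go i
            | i >= hi = done
            | otherwise = yield i (go (i + 1))
      in go lo)

    -- Runs a monadic body for every element, in order.
    forEach :: Monad m => Loop m a -> (a -> m ()) -> m ()
    forEach (Loop l) body = l (\a next -> body a >> next) (return ())

    sumGridCPS :: Int -> Int -> Int
    sumGridCPS rows cols = runST $ do
      ref <- newSTRef 0
      forEach (do y <- range 0 rows; x <- range 0 cols; return (y * cols + x)) $ \i ->
        modifySTRef' ref (+ i)
      readSTRef ref

    main :: IO ()
    main = print (sumGridCPS 10 100)
    ```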

    Now, the performance benefit of running in a single stack frame is small. We've seen that test_b is only slightly faster than test_a. I include this detour in the answer because I found it edifying.

    The state hack and strict bindings

    The so-called state hack makes GHC aggressive about inlining into IO and ST actions. I think I should mention it here, because besides let floating, this is the other thing that can thoroughly ruin performance.

    The state hack is enabled by -O and can, in some cases, slow programs down asymptotically. A simple example from Reid Barton:

    import Control.Monad
    import Debug.Trace
    
    expensive :: String -> String
    expensive x = trace "$$$" x
    
    main :: IO ()
    main = do
      str <- fmap expensive getLine
      replicateM_ 3 $ print str
    

    With GHC-7.10.2, this prints "$$$" once without optimizations but three times with -O2. And it seems that with GHC-7.10, we can't get rid of this behavior with -fno-state-hack (which is the subject of the linked ticket from Reid Barton).

    Strict monadic bindings reliably get rid of this problem:

    main :: IO ()
    main = do
      !str <- fmap expensive getLine
      replicateM_ 3 $ print str
    

    I think it's a good habit to use strict bindings in IO and ST. In my experience (which is not definitive; I'm far from being a GHC expert), strict bindings are especially needed if we use -fno-full-laziness. Apparently full laziness can help get rid of some of the work duplication introduced by the inlining caused by the state hack; with test_b and no full laziness, omitting the strict binding on !mvec <- V.unsafeThaw vec caused a slight slowdown and extremely ugly Core output.

  • 2021-02-02 12:32

    In my experience forM_ [0..n-1] can perform well, but unfortunately it's not reliable. Just adding an INLINE pragma to test_a and using -O2 makes it run much faster (4s to 1s for me), but manually inlining it (copy paste) slows it down again.

    A more reliable option is the for function from the statistics package, which is implemented as

    {-# LANGUAGE BangPatterns #-} -- needed for the !n pattern
    
    -- | Simple for loop.  Counts from /start/ to /end/-1.
    for :: Monad m => Int -> Int -> (Int -> m ()) -> m ()
    for n0 !n f = loop n0
      where
        loop i | i == n    = return ()
               | otherwise = f i >> loop (i+1)
    {-# INLINE for #-}
    

    Using it looks similar to forM_ with lists:

    test_d :: MV.IOVector Int -> IO ()
    test_d mv =
      for 0 times $ \_ ->
        for 0 side $ \i ->
          for 0 side $ \j ->
            MV.unsafeWrite mv (i*side + j) 1
    

    but performs reliably well (0.85s for me) without any risk of allocating a list.
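
    One caveat about this for as written: the loop stops only when i == n exactly, so a call whose start is already past the end (say for 5 3 ...) keeps counting upward until the Int wraps around, which is practically forever. A defensive variant is sketched below; forRange and sumIndices are hypothetical names, and an STRef accumulator replaces the mutable vector so that the example runs with base alone.

    ```haskell
    {-# LANGUAGE BangPatterns #-}
    import Control.Monad.ST (runST)
    import Data.STRef (modifySTRef', newSTRef, readSTRef)

    -- Same shape as the for above, but it stops as soon as i >= n, so an empty
    -- or "backwards" range (start >= end) performs no iterations.
    forRange :: Monad m => Int -> Int -> (Int -> m ()) -> m ()
    forRange n0 !n f = loop n0
      where
        loop !i
          | i >= n = return ()
          | otherwise = f i >> loop (i + 1)
    {-# INLINE forRange #-}

    -- test_d's loop nest, summing indices into an STRef instead of writing a vector.
    sumIndices :: Int -> Int -> Int
    sumIndices times side = runST $ do
      ref <- newSTRef 0
      forRange 0 times $ \_ ->
        forRange 0 side $ \i ->
          forRange 0 side $ \j ->
            modifySTRef' ref (+ (i * side + j))
      readSTRef ref

    main :: IO ()
    main = do
      print (sumIndices 10 100)
      print (sumIndices 3 (-1)) -- backwards inner range: no iterations, prints 0
    ```

    With a start at or past the end, forRange simply does nothing, which matches what forM_ [n0 .. n - 1] does for an empty range.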
