Why does this program seem not to be fusing properly?

小鲜肉 2021-01-13 04:26

I suspected that a given program wasn't fusing as it should, so I made this test to confirm:

module Main where

import qualified Data.Vector.Unboxed as         


        
1 Answer
    星月不相逢
    2021-01-13 05:00

    Note: I'm using GHC 7.10.3 and stack 1.1.2 on Windows (x64), so your times might differ.

    TL;DR

    Make sure to inline your functions if you want to use stream fusion.

    How to fuse a stream

    Stream fusion relies heavily on the optimizer and on rewrite rules, at least in the vector package. So let's check which versions of your program get optimized well.
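
    To make the role of the rewrite rules concrete, here is a toy model of stream fusion in plain Haskell. It is a sketch of the general technique, not vector's actual implementation: Vec, Stream, stream, unstream, mapS, sumS, mapV and sumV are all hypothetical names, and Vec is just a list wrapper to keep the example self-contained. Each "vector" operation is written as unstream . (stream operation) . stream and marked INLINE. After inlining, a composition such as sumV . mapV f leaves stream (unstream s) in the middle, and the single RULES pragma removes it.

    ```haskell
    {-# LANGUAGE ExistentialQuantification #-}
    module Main where

    -- Toy model of stream fusion. All names are hypothetical; this is not
    -- the vector package's API, just the same idea in miniature.

    -- One step of a stream: either finished, or an element plus new state.
    data Step s a = Done | Yield a s

    -- A stream is a step function together with a hidden seed state.
    data Stream a = forall s. Stream (s -> Step s a) s

    -- Stand-in for a materialized vector (a list wrapper, for simplicity).
    newtype Vec a = Vec [a] deriving Show

    stream :: Vec a -> Stream a
    stream (Vec xs0) = Stream next xs0
      where
        next []       = Done
        next (x : xs) = Yield x xs
    {-# NOINLINE stream #-}  -- keep it visible so the rule below can match

    unstream :: Stream a -> Vec a
    unstream (Stream next s0) = Vec (go s0)
      where
        go s = case next s of
          Done       -> []
          Yield x s' -> x : go s'
    {-# NOINLINE unstream #-}

    -- Stream-level map: wraps the step function, builds no data structure.
    mapS :: (a -> b) -> Stream a -> Stream b
    mapS f (Stream next s0) = Stream next' s0
      where
        next' s = case next s of
          Done       -> Done
          Yield x s' -> Yield (f x) s'

    -- Stream-level sum with a strict accumulator.
    sumS :: Num a => Stream a -> a
    sumS (Stream next s0) = go 0 s0
      where
        go acc s = acc `seq` case next s of
          Done       -> acc
          Yield x s' -> go (acc + x) s'

    -- "Vector" operations, defined via stream/unstream and marked INLINE.
    mapV :: (a -> b) -> Vec a -> Vec b
    mapV f = unstream . mapS f . stream
    {-# INLINE mapV #-}

    sumV :: Num a => Vec a -> a
    sumV = sumS . stream
    {-# INLINE sumV #-}

    -- Converting a stream to a vector and straight back is redundant,
    -- so the intermediate Vec between two operations can be dropped.
    {-# RULES
    "stream/unstream" forall s. stream (unstream s) = s
      #-}

    main :: IO ()
    main = print (sumV (mapV (+ 1) (mapV (+ 1) (Vec [1 .. 10 :: Int]))))
    ```

    The result is the same whether or not the rule fires. The rule only changes how much intermediate data is built, and it can only fire once mapV and sumV have both been inlined into the same expression. If the composition is hidden behind a binding that GHC decides not to inline, the intermediate Vec gets built, which is the situation examined below with incAll.
    
    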

    Minimal version (1 incAll)

    Let's start simple by reducing the program to the minimum:

    -- SOBase.hs
    module Main where
    
    import qualified Data.Vector.Unboxed as V
    
    main :: IO ()
    main = do
    
      let size = 100000000 :: Int
      let array = V.replicate size 0 :: V.Vector Int
      let incAll = V.map (+ 1)
    
      print 
        . V.sum     
        . incAll    
        $ array
    

    Let's compile it and dump GHC's generated core:

    $ stack ghc --package vector -- -O2 SOBase.hs -ddump-simpl -dsuppress-all
    
    main2
    main2 =
      case (runSTRep main3) `cast` ...
      of _ { Vector ipv_s6b2 ipv1_s6b3 ipv2_s6b4 ->
      letrec {
        $s$wfoldlM'_loop_s9wM
        $s$wfoldlM'_loop_s9wM =
          \ sc_s9wK sc1_s9wL ->
            case tagToEnum# (>=# sc1_s9wL ipv1_s6b3) of _ {
              False ->
                case indexIntArray# ipv2_s6b4 (+# ipv_s6b2 sc1_s9wL)
                of wild_a5ju { __DEFAULT ->
                $s$wfoldlM'_loop_s9wM (+# sc_s9wK (+# wild_a5ju 1)) (+# sc1_s9wL 1)
                };
              True -> sc_s9wK
            }; } in
      case $s$wfoldlM'_loop_s9wM 0 0 of ww_s94k { __DEFAULT ->
      case $wshowSignedInt 0 ww_s94k ([])
      of _ { (# ww5_a5fH, ww6_a5fI #) ->
      : ww5_a5fH ww6_a5fI
      }
      }
      }
    

    Let's make that a little prettier by writing it as roughly equivalent Haskell:

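    main2 =
      let size  = 100000000 :: Int
          array = V.replicate size 0 :: V.Vector Int
          foldLoop :: Int -> Int -> Int
          foldLoop s n
            | n < size  = foldLoop (s + (array V.! n + 1)) (n + 1)
            | otherwise = s
      in print (foldLoop 0 0)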
    

    incAll has been inlined into the loop:

    case indexIntArray# ipv2_s6b4 (+# ipv_s6b2 sc1_s9wL)
                    of wild_a5ju { __DEFAULT ->
                    $s$wfoldlM'_loop_s9wM (+# sc_s9wK (+# wild_a5ju 1)) (+# sc1_s9wL 1)
                                                      ^^^^^^^^^^^^^^^^
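
    In list terms, fusion has turned sum . map (+ 1) into a single loop whose step adds x + 1 straight into the accumulator. The following is a minimal sketch of that equivalence using plain lists instead of vector. twoPass and onePass are hypothetical names chosen for illustration.

    ```haskell
    module Main where

    import Data.List (foldl')

    -- Single-incAll pipeline written as two conceptual passes:
    -- map (+ 1) produces a list that sum then consumes.
    twoPass :: [Int] -> Int
    twoPass = sum . map (+ 1)

    -- Hand-fused equivalent: one strict left fold whose step adds (x + 1)
    -- to the accumulator, like the (+# acc (+# x 1)) step in the Core above.
    onePass :: [Int] -> Int
    onePass = foldl' (\acc x -> acc + (x + 1)) 0

    main :: IO ()
    main = print (twoPass xs, onePass xs)  -- both are 1000
      where
        xs = replicate 1000 0
    ```
    
    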
    

    More calls (3 incAlls)

    Let's use incAll more often:

    -- SO3.hs
    module Main where
    
    import qualified Data.Vector.Unboxed as V
    
    main :: IO ()
    main = do
    
      let size = 100000000 :: Int
      let array = V.replicate size 0 :: V.Vector Int
      let incAll = V.map (+ 1)
    
      print
        . V.sum
    
        . incAll
        . incAll
        . incAll
    
        $ array
    

    What does our core contain now?

    $wincAll
    $wincAll =
      \ ww_s999 ww1_s99a ww2_s99b ->
        runSTRep
          (\ @ s_a4Rs s1_a4Rt ->
             case tagToEnum# (<# ww1_s99a 0) of _ {
               False ->
                 case divInt# 9223372036854775807 8 of ww4_a5fa { __DEFAULT ->
                 case tagToEnum# (># ww1_s99a ww4_a5fa) of _ {
                   False ->
                     case newByteArray# (*# ww1_s99a 8) (s1_a4Rt `cast` ...)
                     of _ { (# ipv_a5dy, ipv1_a5dz #) ->
                     letrec {
                       $s$wa_s9DR
                       $s$wa_s9DR =
                         \ sc_s9DN sc1_s9DO sc2_s9DQ ->
                           case tagToEnum# (>=# sc1_s9DO ww1_s99a) of _ {
                             False ->
                               case indexIntArray# ww2_s99b (+# ww_s999 sc1_s9DO)
                               of wild_a5jF { __DEFAULT ->
                               case writeIntArray#
                                      ipv1_a5dz sc_s9DN (+# wild_a5jF 1) (sc2_s9DQ `cast` ...)
                               of s'#_a6Cg { __DEFAULT ->
                               $s$wa_s9DR (+# sc_s9DN 1) (+# sc1_s9DO 1) (s'#_a6Cg `cast` ...)
                               }
                               };
                             True -> (# sc2_s9DQ, I# sc_s9DN #)
                           }; } in
                     case $s$wa_s9DR 0 0 (ipv_a5dy `cast` ...)
                     of _ { (# ipv6_a4Nw, ipv7_a4Nx #) ->
                     case ipv7_a4Nx of _ { I# dt4_a5gC ->
                     case unsafeFreezeByteArray# ipv1_a5dz (ipv6_a4Nw `cast` ...)
                     of _ { (# ipv2_a52B, ipv3_a52C #) ->
                     (# ipv2_a52B `cast` ...,
                        (Vector 0 dt4_a5gC ipv3_a52C) `cast` ... #)
                     }
                     }
                     }
                     };
                   True -> case main4 ww1_s99a of wild_00 { }
                 }
                 };
               True -> case main3 ww1_s99a of wild_00 { }
             })
    
    ....
    
    main2
    main2 =
      case (runSTRep main5) `cast` ...
      of _ { Vector ww1_s991 ww2_s992 ww3_s993 ->
      case ($wincAll ww1_s991 ww2_s992 ww3_s993) `cast` ...
    --      ^^^^^^^^ oh
      of _ { Vector ww5_X99T ww6_X99V ww7_X99X ->
      case ($wincAll ww5_X99T ww6_X99V ww7_X99X) `cast` ...
    --      ^^^^^^^^ oh
      of _ { Vector ww9_X99Y ww10_X9a0 ww11_X9a2 ->
      case ($wincAll ww9_X99Y ww10_X9a0 ww11_X9a2) `cast` ...
    --      ^^^^^^^^ oh
      of _ { Vector ipv_s6cG ipv1_s6cH ipv2_s6cI ->
      letrec {
        $s$wfoldlM'_loop_s9Du
        $s$wfoldlM'_loop_s9Du =
          \ sc_s9Ds sc1_s9Dt ->
            case tagToEnum# (>=# sc1_s9Dt ipv1_s6cH) of _ {
              False ->
                case indexIntArray# ipv2_s6cI (+# ipv_s6cG sc1_s9Dt)
                of wild_a5jx { __DEFAULT ->
                $s$wfoldlM'_loop_s9Du (+# sc_s9Ds wild_a5jx) (+# sc1_s9Dt 1)
                };
              True -> sc_s9Ds
            }; } in
      case $s$wfoldlM'_loop_s9Du 0 0 of ww12_s99s { __DEFAULT ->
      case $wshowSignedInt 0 ww12_s99s ([])
      of _ { (# ww14_a5fK, ww15_a5fL #) ->
      : ww14_a5fK ww15_a5fL
      }
      }
      }
      }
      }
      }
    

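    incAll is not inlined anymore! main2 now calls $wincAll three times, and each call allocates a fresh array with newByteArray#. Since incAll isn't inlined, vector's rewrite rules never see the whole pipeline as one expression, so stream fusion cannot kick in.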
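
    Besides reading the Core, you can ask GHC which rewrite rules fired. The command below is a sketch that reuses the stack setup from above. Comparing its output for SO3.hs and SO3I.hs should show whether vector's fusion rules were applied across the incAll calls.

    ```shell
    # Ask GHC to report each rewrite rule that fires during optimization.
    # -fforce-recomp makes GHC recompile even if SO3.hs was built before.
    stack ghc --package vector -- -O2 -fforce-recomp -ddump-rule-firings SO3.hs
    ```
    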

    Inlining the function (3 incAlls)

    Let's add an INLINE pragma:

    -- SO3I.hs
    module Main where
    
    import qualified Data.Vector.Unboxed as V
    
    main :: IO ()
    main = do
    
      let size = 100000000 :: Int
      let array = V.replicate size 0 :: V.Vector Int
      let {-# INLINE incAll #-}
          incAll = V.map (+1)
      print 
        . V.sum 
    
        . incAll 
        . incAll 
        . incAll 
    
        $ array
    
    $ stack ghc --package vector -- -O2 -ddump-simpl -dsuppress-all SO3I.hs
    

    What does main look like now?

    main2                                                                         
    main2 =                                                                       
      case (runSTRep main3) `cast` ...                                            
      of _ { Vector ipv_s6bG ipv1_s6bH ipv2_s6bI ->                               
      letrec {                                                                    
        $s$wfoldlM'_loop_s9z7                                                     
        $s$wfoldlM'_loop_s9z7 =                                                   
          \ sc_s9z5 sc1_s9z6 ->                                                   
            case tagToEnum# (>=# sc1_s9z6 ipv1_s6bH) of _ {                       
              False ->                                                            
                case indexIntArray# ipv2_s6bI (+# ipv_s6bG sc1_s9z6)              
                of wild_a5jC { __DEFAULT ->                                       
                $s$wfoldlM'_loop_s9z7                                             
                  (+# sc_s9z5 (+# (+# (+# wild_a5jC 1) 1) 1)) (+# sc1_s9z6 1)     
                };                                                                
              True -> sc_s9z5                                                     
            }; } in                                                               
      case $s$wfoldlM'_loop_s9z7 0 0 of ww_s96F { __DEFAULT ->                    
      case $wshowSignedInt 0 ww_s96F ([])                                         
      of _ { (# ww5_a5fP, ww6_a5fQ #) ->                                          
      : ww5_a5fP ww6_a5fQ                                                         
      }                                                                           
      }                                                                           
      }                                                                           
    

    Great. incAll has been inlined, as can be seen here:

    (+# sc_s9z5 (+# (+# (+# wild_a5jC 1) 1) 1)) (+# sc1_s9z6 1)     
                                      ^  ^  ^
    

    So the problem was that incAll wasn't inlined, and therefore you didn't end up with

    V.sum . V.map (+1) . V.map (+1) . V.map (+1)
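
    The same pragma should also work when incAll is a top-level definition instead of a local let binding. The following is a sketch under that assumption. Its Core is not shown here, so it is worth checking with -ddump-simpl as above.

    ```haskell
    module Main where

    import qualified Data.Vector.Unboxed as V

    -- Top-level variant of the fix: with a type signature and an INLINE
    -- pragma, GHC inlines incAll at each use site, so vector's rewrite
    -- rules can see V.sum . V.map (+1) . V.map (+1) . V.map (+1).
    incAll :: V.Vector Int -> V.Vector Int
    incAll = V.map (+ 1)
    {-# INLINE incAll #-}

    main :: IO ()
    main = do
      let size  = 100000000 :: Int
          array = V.replicate size 0 :: V.Vector Int
      print . V.sum . incAll . incAll . incAll $ array
    ```
    
    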
    

    Your original program (now inlined, 32 incAlls)

    Last but not least, let's try your original program again, this time with the INLINE pragma. Is everything fixed? Let's have a look at the Core:

    main2
    main2 =
      case (runSTRep main3) `cast` ...
      of _ { Vector ipv_s6xF ipv1_s6xG ipv2_s6xH ->
      letrec {
        $s$wfoldlM'_loop_sajT
        $s$wfoldlM'_loop_sajT =
          \ sc_sajR sc1_sajS ->
            case tagToEnum# (>=# sc1_sajS ipv1_s6xG) of _ {
              False ->
                case indexIntArray# ipv2_s6xH (+# ipv_s6xF sc1_sajS)
                of wild_a5mq { __DEFAULT ->
                $s$wfoldlM'_loop_sajT
                  (+#
                     sc_sajR
                     (+#
                        (+#
                           (+#
                              (+#
                                 (+#
                                    (+#
                                       (+#
                                          (+#
                                             (+#
                                                (+#
                                                   (+#
                                                      (+#
                                                         (+#
                                                            (+#
                                                               (+#
                                                                  (+#
                                                                     (+#
                                                                        (+#
                                                                           (+#
                                                                              (+#
                                                                                 (+#
                                                                                    (+#
                                                                                       (+#
                                                                                          (+#
                                                                                             (+#
                                                                                                (+#
                                                                                                   (+#
                                                                                                      (+#
                                                                                                         (+#
                                                                                                            (+#
                                                                                                               (+#
                                                                                                                  (+#
                                                                                                                     wild_a5mq
                                                                                                                     1)
                                                                                                                  1)
                                                                                                               1)
                                                                                                            1)
                                                                                                         1)
                                                                                                      1)
                                                                                                   1)
                                                                                                1)
                                                                                             1)
                                                                                          1)
                                                                                       1)
                                                                                    1)
                                                                                 1)
                                                                              1)
                                                                           1)
                                                                        1)
                                                                     1)
                                                                  1)
                                                               1)
                                                            1)
                                                         1)
                                                      1)
                                                   1)
                                                1)
                                             1)
                                          1)
                                       1)
                                    1)
                                 1)
                              1)
                           1)
                        1))
                  (+# sc1_sajS 1)
                };
              True -> sc_sajR
            }; } in
      case $s$wfoldlM'_loop_sajT 0 0 of ww_s9Rr { __DEFAULT ->
      case $wshowSignedInt 0 ww_s9Rr ([])
      of _ { (# ww5_a5iD, ww6_a5iE #) ->
      : ww5_a5iD ww6_a5iE
      }
      }
      }
    

    Well, yes. GHC doesn't fold (+1) . (+1) into (+2) and so on at this stage, but is it actually faster?

    $ stack ghc --package vector -- -O2 SO.hs && SO.exe +RTS -s
      26,400,052,464 bytes allocated in the heap                                             
               9,736 bytes copied during GC                                                  
         800,026,736 bytes maximum residency (2 sample(s))                                   
              61,328 bytes maximum slop                                                      
                1527 MB total memory in use (0 MB lost due to fragmentation)                 
    
                                         Tot time (elapsed)  Avg pause  Max pause            
      Gen  0        32 colls,     0 par    0.000s   0.000s     0.0000s    0.0000s            
      Gen  1         2 colls,     0 par    0.000s   0.089s     0.0446s    0.0890s            
    
      INIT    time    0.000s  (  0.000s elapsed)                                             
      MUT     time    4.453s  (  4.616s elapsed)                                             
      GC      time    0.000s  (  0.090s elapsed)                                             
      EXIT    time    0.000s  (  0.089s elapsed)                                             
      Total   time    4.453s  (  4.795s elapsed)                                             
    
      %GC     time       0.0%  (1.9% elapsed)                                                
    
      Alloc rate    5,928,432,834 bytes per MUT second                                       
    
      Productivity 100.0% of total user, 92.9% of total elapsed                              
    

    About 4.5 seconds for your original program. And for the inlined one?

    $ stack ghc --package vector -- -O2 SOFixed.hs && SOFixed.exe +RTS -s
    3200000000
         800,048,112 bytes allocated in the heap
               4,352 bytes copied during GC
              42,664 bytes maximum residency (1 sample(s))
              18,776 bytes maximum slop
                 764 MB total memory in use (0 MB lost due to fragmentation)
    
                                         Tot time (elapsed)  Avg pause  Max pause
      Gen  0         1 colls,     0 par    0.000s   0.000s     0.0000s    0.0000s
      Gen  1         1 colls,     0 par    0.000s   0.045s     0.0452s    0.0452s
    
      INIT    time    0.000s  (  0.000s elapsed)
      MUT     time    0.188s  (  0.224s elapsed)
      GC      time    0.000s  (  0.045s elapsed)
      EXIT    time    0.000s  (  0.045s elapsed)
      Total   time    0.188s  (  0.315s elapsed)
    
      %GC     time       0.0%  (14.4% elapsed)
    
      Alloc rate    4,266,923,264 bytes per MUT second
    
      Productivity 100.0% of total user, 59.6% of total elapsed
    

    About 0.2 seconds. Great! By the way, all the (+1) calls get optimized into a single addq $32,... further down the line, in the generated assembly.
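
    The allocation figures fit this explanation. The following is a back-of-the-envelope check. It assumes 8 bytes per unboxed Int on x64 and one freshly allocated array per unfused V.map, as in $wincAll above.

    ```haskell
    module Main where

    -- Assumptions: 8-byte unboxed Ints (x64) and one new array per
    -- materialized vector of 'size' elements.
    size, bytesPerInt :: Integer
    size        = 100000000
    bytesPerInt = 8

    -- Bytes needed for 'arrays' materialized arrays of 'size' Ints.
    arrayBytes :: Integer -> Integer
    arrayBytes arrays = arrays * size * bytesPerInt

    main :: IO ()
    main = do
      print (arrayBytes 33)  -- replicate + 32 unfused maps: 26400000000
      print (arrayBytes 1)   -- only the replicated input: 800000000
      print (32 * size)      -- sum of 32 (+1)s over zeros: 3200000000
    ```

    33 × 800,000,000 bytes is very close to the 26,400,052,464 bytes allocated by the unfused program. The fused program allocates roughly one array (800,048,112 bytes), presumably just the replicated input. Its printed result, 3200000000, matches 32 × 100,000,000.
    
    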
