How to optimize this Haskell code summing up the primes in sublinear time?

迷失自我 2021-02-18 23:11

Problem 10 from Project Euler is to find the sum of all the primes below a given n.

I solved it simply by summing up the primes generated by the sieve of Eratosthenes.
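The question's own code is cut off above, so purely as a point of reference, here is a hypothetical sketch of the straightforward approach it describes: sieve with an unboxed mutable array and sum the survivors. The names `primeFlags` and `sumPrimesBelow` are mine, not the asker's. It needs memory proportional to n and does O(n log log n) work, which is why the answers below use a sublinear method for n = 2*10^9.

```haskell
import Control.Monad (forM_, when)
import Data.Array.ST (newArray, readArray, writeArray, runSTUArray)
import Data.Array.Unboxed (UArray, assocs)

-- Sieve of Eratosthenes over [0 .. n-1]: True marks a prime.
-- Assumes n >= 2.
primeFlags :: Int -> UArray Int Bool
primeFlags n = runSTUArray $ do
  arr <- newArray (0, n - 1) True
  writeArray arr 0 False
  writeArray arr 1 False
  forM_ (takeWhile (\p -> p * p < n) [2 ..]) $ \p -> do
    isPrime <- readArray arr p
    -- Cross out multiples of p, starting at p*p (smaller ones are already done).
    when isPrime $
      forM_ [p * p, p * p + p .. n - 1] $ \m -> writeArray arr m False
  return arr

-- Sum of all primes strictly below n.
sumPrimesBelow :: Int -> Int
sumPrimesBelow n
  | n <= 2    = 0
  | otherwise = sum [i | (i, True) <- assocs (primeFlags n)]

main :: IO ()
main = print (sumPrimesBelow 2000000)
```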

4 answers
  •  你的背包
    2021-02-18 23:45

    First, as a baseline, here are the timings of the existing approaches on my machine:

    1. Original program posted in the question:

      time stack exec primorig
      95673602693282040
      
      real    0m4.601s
      user    0m4.387s
      sys     0m0.251s
      
    2. The version using Data.IntMap.Strict from here:

      time stack exec primIntMapStrict
      95673602693282040
      
      real    0m2.775s
      user    0m2.753s
      sys     0m0.052s
      
    3. Shersh's code with Data.Judy dropped in, from here:

      time stack exec prim-hash2
      95673602693282040
      
      real    0m0.945s
      user    0m0.955s
      sys     0m0.028s
      
    4. Your Python solution:

      I compiled it with

      python -O -m py_compile problem10.py
      

      and the timing:

      time python __pycache__/problem10.cpython-36.opt-1.pyc
      95673602693282040
      
      real    0m1.163s
      user    0m1.160s
      sys     0m0.003s
      
    5. Your C++ version:

      $ g++ -O2 --std=c++11 p10.cpp -o p10
      $ time ./p10
      sum(2000000000) = 95673602693282040
      
      real    0m0.314s
      user    0m0.310s
      sys     0m0.003s
      

    I didn't bother to provide a baseline for slow.hs, as I didn't want to wait for it to complete when run with an argument of 2*10^9.

    Subsecond performance

    The following program runs in under a second on my machine.

    It uses a hand-rolled hash map with closed hashing (open addressing) and linear probing, and a variant of Knuth's hash function (see here).

    It is certainly somewhat tailored to this case: the lookup function, for example, expects the searched key to be present, and would loop forever on a missing one.
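    To illustrate that limitation: a general-purpose lookup in a linear-probing table has to stop at the first empty slot and report the key as missing, whereas loadKey below skips that check. A minimal sketch, assuming (as in the program) that a stored 0 marks an empty slot and that the table never fills up; `lookupIndex`, `insertKey` and `demo` are hypothetical names, not part of the answer.

```haskell
import Data.Array.IO (IOUArray, newArray, readArray, writeArray)

-- Insert a key by linear probing from slot (key * f) `mod` s.
-- A stored 0 marks an empty slot, so keys must be non-zero.
insertKey :: IOUArray Int Int -> Int -> Int -> Int -> IO Int
insertKey arr s f key = go ((key * f) `mod` s)
  where
    go ix = do
      v <- readArray arr ix
      if v == 0 || v == key
        then writeArray arr ix key >> return ix
        else go ((ix + 1) `mod` s)

-- General lookup: Just the slot index, or Nothing once an empty slot
-- shows the key is absent. Assumes at least one slot stays empty.
lookupIndex :: IOUArray Int Int -> Int -> Int -> Int -> IO (Maybe Int)
lookupIndex arr s f key = go ((key * f) `mod` s)
  where
    go ix = do
      v <- readArray arr ix
      if v == key
        then return (Just ix)
        else if v == 0
          then return Nothing          -- empty slot: the key is not stored
          else go ((ix + 1) `mod` s)   -- another key lives here: keep probing

-- Tiny table of 11 slots; 3, 14 and 25 all hash to slot 10, so the
-- later two wrap around to slots 0 and 1.
demo :: IO (Maybe Int, Maybe Int)
demo = do
  let s = 11
      f = s + 7
  arr <- newArray (0, s - 1) 0
  mapM_ (insertKey arr s f) [3, 14, 25]
  found   <- lookupIndex arr s f 14
  missing <- lookupIndex arr s f 5
  return (found, missing)

main :: IO ()
main = demo >>= print   -- prints (Just 0,Nothing)
```

    The extra `v == 0` comparison is the price of generality; dropping it, as loadKey does, is only safe because every key the program looks up is known to be stored.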

    Timings:

    time stack exec prim
    95673602693282040
    
    real    0m0.725s
    user    0m0.714s
    sys     0m0.047s
    

    At first, my hand-rolled hash map simply hashed the keys with

    key `mod` size
    

    and I chose a table size several times larger than the expected input, but the program took 22 s or more to complete.

    In the end, it was a matter of choosing a hash function that suits this workload.
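    One detail worth spelling out: problem10 below uses a table size s = 19 * vss and a multiplier f = s + 7, and since (key * (s + 7)) `mod` s = (7 * key) `mod` s, the hash is effectively 7 * key modulo the table size (with 64-bit Int the products stay far below overflow here). The following hypothetical sketch (`keysFor` and `totalProbes` are my names; `keysFor` repeats the key list built by vs') checks that identity on the actual key set and counts how many slots linear probing inspects when inserting all keys with each hash. It only counts probes, not other costs such as cache behaviour, so it is not guaranteed to reproduce the 22 s versus sub-second difference.

```haskell
import Control.Monad (foldM)
import Control.Monad.ST (ST, runST)
import Data.Array.ST (STUArray, newArray, readArray, writeArray)

-- Same key list as vs' in the answer: n `div` i for i <= r,
-- followed by (n `div` r - 1) down to 1, where r = floor (sqrt n).
keysFor :: Int -> [Int]
keysFor n = [n `div` i | i <- [1 .. r]] ++ reverse [1 .. n `div` r - 1]
  where
    r = floor (sqrt (fromIntegral n :: Double))

-- Insert every key into a linear-probing table with `size` slots (0 = empty)
-- and return the total number of slots inspected.
-- Assumes keys are distinct and non-zero and size > length keys.
totalProbes :: (Int -> Int) -> Int -> [Int] -> Int
totalProbes hash size ks = runST $ do
  arr <- newArray (0, size - 1) 0 :: ST s (STUArray s Int Int)
  let insert acc key = go (hash key `mod` size) (acc + 1)
        where
          go ix count = do
            v <- readArray arr ix
            if v == 0
              then writeArray arr ix key >> return count
              else go ((ix + 1) `mod` size) (count + 1)
  foldM insert 0 ks

main :: IO ()
main = do
  let ks   = keysFor (2 * 10 ^ (9 :: Int))
      size = 19 * length ks   -- 19 * vss, as in problem10
      f    = size + 7         -- the multiplier passed to list' in problem10
  -- The multiplicative hash reduces to 7 * key for this choice of f:
  print (all (\k -> (k * f) `mod` size == (7 * k) `mod` size) ks)
  print (totalProbes id size ks)       -- plain key `mod` size
  print (totalProbes (* f) size ks)    -- (key * f) `mod` size
```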

    Here is the program:

    import Control.Monad
    import Data.Array.IO
    import Data.Array.Base (unsafeRead)
    
    type Number = Int
    
    data Map = Map { keys :: IOUArray Int Number
                   , values :: IOUArray Int Number
                   , size :: !Int 
                   , factor :: !Int
                   }
    
    newMap :: Int -> Int -> IO Map
    newMap s f = do
      k <- newArray (0, s-1) 0
      v <- newArray (0, s-1) 0
      return $ Map k v s f 
    
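    -- Insert key by linear probing, starting at slot (key * f) `mod` s;
    -- a stored 0 marks an empty slot, so 0 cannot be used as a key.
    storeKey :: IOUArray Int Number -> Int -> Int -> Number -> IO Int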
    storeKey arr s f key = go ((key * f) `mod` s)
      where
        go :: Int -> IO Int
        go ind = do
          v <- readArray arr ind
          go2 v ind
        go2 v ind
          | v == 0    = do { writeArray arr ind key; return ind; }
          | v == key  = return ind
          | otherwise = go ((ind + 1) `mod` s)
    
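    -- Find the slot holding key. There is no empty-slot check, so the key
    -- must be present; otherwise this loops forever.
    loadKey :: IOUArray Int Number -> Int -> Int -> Number -> IO Int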
    loadKey arr s f key = s `seq` key `seq` go ((key *f) `mod` s)
      where
        go :: Int -> IO Int
        go ix = do
          v <- unsafeRead arr ix
          if v == key then return ix else go ((ix + 1) `mod` s)
    
    insertIntoMap :: Map -> (Number, Number) -> IO Map
    insertIntoMap m@(Map ks vs s f) (k, v) = do
      ix <- storeKey ks s f k
      writeArray vs ix v
      return m
    
    fromList :: Int -> Int -> [(Number, Number)] -> IO Map
    fromList s f xs = do
      m <- newMap s f
      foldM insertIntoMap m xs
    
    (!) :: Map -> Number -> IO Number
    (!) (Map ks vs s f) k = do
      ix <- loadKey ks s f k
      readArray vs ix
    
    mupdate :: Map -> Number -> (Number -> Number) -> IO ()
    mupdate (Map ks vs s fac) i f = do
      ix <- loadKey ks s fac i
      old <- readArray vs ix
      let x' = f old
      x' `seq` writeArray vs ix x'
    
    r' :: Number -> Number
    r'  = floor . sqrt . fromIntegral
    
    vs' :: Integral a => a -> a -> [a]
    vs' n r = [n `div` i | i <- [1..r]] ++ reverse [1..n `div` r - 1]  
    
    vss' n r = r + n `div` r -1
    
    list' :: Int -> Int -> [Number] -> IO Map
    list' s f vs = fromList s f [(i, i * (i + 1) `div` 2 - 1) | i <- vs]
    
    problem10 :: Number -> IO Number
    problem10 n = do
          m <- list' (19*vss) (19*vss+7) vs
          nm <- sieve m 2 r vs
          nm ! n
        where vs = vs' n r
              vss = vss' n r
              r  = r' n
    
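    -- p is prime iff the entry for p exceeds the entry for p - 1 (p has not
    -- been sieved out); only then do the keys >= p*p need updating.
    sieve :: Map -> Number -> Number -> [Number] -> IO Map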
    sieve m p r vs | p > r     = return m
                   | otherwise = do
                       v1 <- m ! p
                       v2 <- m ! (p - 1)
                       nm <- if v1 > v2 then update m vs p else return m
                       sieve nm (p + 1) r vs
    
    update :: Map -> [Number] -> Number -> IO Map
    update m vs p = foldM (decrease p) m $ takeWhile (>= p*p) vs
    
    decrease :: Number -> Map -> Number -> IO Map
    decrease p m k = do
      v <- sumOfSieved m k p
      mupdate m k (subtract v)
      return m
    
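    -- p times the sum of the numbers in [p, v `div` p] that have no prime
    -- factor below p: the amount removed from key v when sieving by p.
    sumOfSieved :: Map -> Number -> Number -> IO Number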
    sumOfSieved m v p = do
      v1 <- m ! (v `div` p)
      v2 <- m ! (p - 1)
      return $ p * (v1 - v2)
    
    main = do { n <- problem10 (2*10^9) ; print n; } -- 2*10^9
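    The answer does not spell out the mathematics, so here is my reading of the recurrence the program appears to implement; it matches what is commonly called Lucy_Hedgehog's method from the Project Euler forum (that name is my addition, not the answer's). Write S(v, p) for the sum of all m with 2 <= m <= v that are prime or have no prime factor <= p. list' stores S(v, 1) for every key v, sieve treats p as prime exactly when S(p, p-1) > S(p-1, p-1) (the `v1 > v2` test), and decrease/sumOfSieved apply the update below to every key v >= p^2, largest first, so that S(floor(v/p), p-1) is read before it is overwritten:

```latex
\begin{aligned}
S(v,1) &= \frac{v(v+1)}{2} - 1,\\
S(v,p) &= S(v,p-1) - p\,\bigl(S(\lfloor v/p \rfloor,\, p-1) - S(p-1,\, p-1)\bigr)
  && \text{if } p \text{ is prime and } p^2 \le v,\\
S(v,p) &= S(v,p-1) && \text{otherwise},\\
\sum_{q \le n,\ q\ \text{prime}} q &= S\bigl(n, \lfloor \sqrt{n} \rfloor\bigr).
\end{aligned}
```

    Because floor(floor(n/i)/p) = floor(n/(i*p)), every argument that appears is again of the form floor(n/k), so only the roughly 2*sqrt(n) keys produced by vs' are ever needed; the work is usually quoted as about O(n^(3/4)), which is what makes the approach sublinear. Strictly, this sums the primes <= n rather than < n; the two agree for n = 2*10^9, which is not prime.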
    

    I'm not an expert in hashing and that sort of thing, so this can certainly be improved a lot. Maybe we Haskellers should improve the off-the-shelf hash maps, or provide some simpler ones.

    My hash map in Shersh's code

    If I plug my hash map into Shersh's code (see the answer below; the modified version is here), we are even down to

    time stack exec prim-hash2
    95673602693282040
    
    real    0m0.601s
    user    0m0.604s
    sys     0m0.034s
    

    Why is slow.hs slow?

    If you read through the source of insert in Data.HashTable.ST.Basic, you will see that it deletes the old key-value pair and then inserts a new one. It doesn't look up the "place" of the value and mutate it, as one might expect after reading that it is a "mutable" hash table. The hash table itself is mutable, so inserting a new key-value pair doesn't copy the whole table, but the slots holding the values are not. I don't know whether that is the whole reason slow.hs is slow, but my guess is that it is a large part of it.

    A few minor improvements

    That's the idea I followed when I first tried to improve your program.

    You don't need a mutable mapping from keys to values: your key set is fixed. What you want is a mapping from keys to mutable places (which, by the way, is what you get by default in C++).
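    A minimal sketch of that idea, assuming the key set is known up front: build an IntMap once, store an IORef in each entry, and do every later update through the IORef, so the map itself is never modified again. This is my own illustration of the Data.IntMap.Strict plus Data.IORef variant described next, not the code behind the timings; `Places`, `fromPairs`, `readPlace`, `modifyPlace` and `demo` are hypothetical names.

```haskell
import qualified Data.IntMap.Strict as M
import Data.IORef (IORef, modifyIORef', newIORef, readIORef)

-- Fixed key set, one mutable cell per key.
type Places = M.IntMap (IORef Int)

-- Build the map once; traverse over a pair acts on its second component.
fromPairs :: [(Int, Int)] -> IO Places
fromPairs kvs = M.fromList <$> traverse (traverse newIORef) kvs

readPlace :: Places -> Int -> IO Int
readPlace m k = readIORef (m M.! k)

-- Update the cell for key k in place (strictly); the map is untouched.
modifyPlace :: Places -> Int -> (Int -> Int) -> IO ()
modifyPlace m k f = modifyIORef' (m M.! k) f

demo :: IO Int
demo = do
  m <- fromPairs [(1, 10), (5, 50)]
  modifyPlace m 5 (subtract 7)
  readPlace m 5

main :: IO ()
main = demo >>= print   -- prints 43
```

    Lookups still walk the IntMap on every access; only the writes avoid rebuilding it. The fastest IntMap program in this answer has the same shape, with a one-element IOUArray in place of each IORef.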

    So I tried to build exactly that. First I used an IntMap of IORefs (from Data.IntMap.Strict and Data.IORef) and got a timing of

    time stack exec prim
    95673602693282040
    
    real    0m2.134s
    user    0m2.141s
    sys     0m0.028s
    

    I thought it might help to work with unboxed values, so instead of IORef I used an IOUArray Int Int with one element per key and got these timings:

    time stack exec prim
    95673602693282040
    
    real    0m2.015s
    user    0m2.018s
    sys     0m0.038s
    

    Not much of a difference, so I tried to get rid of bounds checking on the one-element arrays by using unsafeRead and unsafeWrite, and got a timing of

    time stack exec prim
    95673602693282040
    
    real    0m1.845s
    user    0m1.850s
    sys     0m0.030s
    

    which was the best I got using Data.IntMap.Strict.

    Of course, I ran each program several times to make sure the timings are stable and the differences aren't just noise.

    It looks like these are all just micro-optimizations.

    And here is the program that ran fastest for me without using a hand rolled data structure:

    import qualified Data.IntMap.Strict as M
    import Control.Monad
    import Data.Array.IO
    import Data.Array.Base (unsafeRead, unsafeWrite)
    
    type Number = Int
    type Place = IOUArray Number Number
    type Map = M.IntMap Place
    
    -- Give each value its own one-element unboxed array, i.e. a mutable place.
    tupleToRef :: (Number, Number) -> IO (Number, Place)
    tupleToRef = traverse (newArray (0,0))
    
    insertRefs :: [(Number, Number)] -> IO [(Number, Place)]
    insertRefs = traverse tupleToRef
    
    fromList :: [(Number, Number)] -> IO Map 
    fromList xs = M.fromList <$> insertRefs xs
    
    (!) :: Map -> Number -> IO Number
    (!) m i = unsafeRead (m M.! i) 0
    
    mupdate :: Map -> Number -> (Number -> Number) -> IO ()
    mupdate m i f = do
      let place = m M.! i
      old <- unsafeRead place 0
      let x' = f old
      -- make the application of f strict
      x' `seq` unsafeWrite place 0 x'
    
    r' :: Number -> Number
    r'  = floor . sqrt . fromIntegral
    
    vs' :: Integral a => a -> a -> [a]
    vs' n r = [n `div` i | i <- [1..r]] ++ reverse [1..n `div` r - 1]  
    
    list' :: [Number] -> IO Map
    list' vs = fromList [(i, i * (i + 1) `div` 2 - 1) | i <- vs]
    
    problem10 :: Number -> IO Number
    problem10 n = do
          m <- list' vs
          nm <- sieve m 2 r vs
          nm ! n
        where vs = vs' n r
              r  = r' n
    
    sieve :: Map -> Number -> Number -> [Number] -> IO Map
    sieve m p r vs | p > r     = return m
                   | otherwise = do
                       v1 <- m ! p
                       v2 <- m ! (p - 1)
                       nm <- if v1 > v2 then update m vs p else return m
                       sieve nm (p + 1) r vs
    
    update :: Map -> [Number] -> Number -> IO Map
    update m vs p = foldM (decrease p) m $ takeWhile (>= p*p) vs
    
    decrease :: Number -> Map -> Number -> IO Map
    decrease p m k = do
      v <- sumOfSieved m k p
      mupdate m k (subtract v)
      return m
    
    sumOfSieved :: Map -> Number -> Number -> IO Number
    sumOfSieved m v p = do
      v1 <- m ! (v `div` p)
      v2 <- m ! (p - 1)
      return $ p * (v1 - v2)
    
    main = do { n <- problem10 (2*10^9) ; print n; } -- 2*10^9
    

    If you profile it, you'll see that most of the time is spent in the custom lookup function (!), and I don't know how to improve that further. Adding {-# INLINE (!) #-} didn't give better results; maybe GHC already inlines it.
