Trying to create an efficient algorithm for a function in Haskell

-上瘾入骨i 2021-01-16 12:59

I'm looking for an efficient polynomial-time solution to the following problem:

Implement a recursive function node x y for calculating the (x,y)-th number in a num

3 Answers
  • 2021-01-16 13:16

    You're thinking in terms of outgoing paths, when you should be thinking in terms of incoming paths. Your recursive step is currently looking for nodes from below, instead of above.

  • 2021-01-16 13:18

    First of all, sorry if this is long. I wanted to explain the step-by-step thought process.

    To start off with, you need one crucial fact: you can represent the "answer" at each "index" by a list of paths. For all the zeros this is [], for your base case it is [[1]], and, for example, for 0,2 it is [[6,1,1],[6,1,1],[6,1,1]]. This may seem redundant, but it simplifies things down the road. Extracting the answer is then head . head if the list is non-empty, or const 0 if it is empty.

    This is very useful because you can store the answer as a list of rows (the first row would be [[1]], [], [], ...) and the results of any given row depend only on the previous row.
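
    As a small illustration of that extraction step, here is a sketch that uses pattern matching instead of head . head, so the empty case is covered by the same function. The name cellValue is made up for this example and does not appear in the rest of the answer.

    ```haskell
    -- Hypothetical helper: read the number stored at a cell from its list of paths.
    -- Every path starts with the cell's own value, so the first element of the
    -- first path is enough; a cell with no paths counts as 0.
    cellValue :: [[Integer]] -> Integer
    cellValue ((v : _) : _) = v
    cellValue _             = 0
    ```

    For the cell at 0,2 above, cellValue [[6,1,1],[6,1,1],[6,1,1]] gives 6, and cellValue [] gives 0.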

    Secondly, this problem is symmetrical: the value at (x, y) equals the value at (-x, y), so we only need to compute the columns with x >= 0.

    The first thing we will do will mirror the definition of fib very closely:

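    import Text.Printf (printf)  -- used by printTri at the end

    -- Despite the name, a Path value holds all paths ending at one cell.
    type Path = [[Integer]]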
    
    triangle' :: [[Path]]
    triangle' = ([[1]] : repeat []) : map f triangle' 
    

    We know this must be close to correct, since the 2nd row will depend on the first row only, the third on the 2nd only, etc. So the result will be

    ([[1]] : repeat []) : f ([[1]] : repeat []) : f ....
    

    Now we just need to know what f is. Firstly, its type: [Path] -> [Path]. Quite simply, given the previous row, return the next row.

    Now you may see another problem arising. Each invocation of f needs to know how many columns there are in the current row. We could count the non-null elements in the previous row, but it is simpler to pass the parameter directly, so we change map f triangle' to zipWith f [1..] triangle', giving f the type Int -> [Path] -> [Path].

    f needs to handle one special case and one general case. The special case is x=0: by symmetry, the (x+1, y-1) and (x-1, y-1) recursions give the same result, so we use the x+1 element twice; otherwise it is identical to the general case. Let's make two functions, g0 and gn, which handle these two cases.

    The actual computation of gn is easy. We know that for some x we need the elements x-1, x, x+1 of the previous row. So if we drop x-1 elements before giving the previous row to the x-th invocation of gn, gn can just take the first 3 elements and it will have what it needs. We write this as follows:

    f :: Int -> [Path] -> [Path]
    f n ps = g0 ps : map (gn . flip drop ps) [0..n-1] ++ repeat []
    

    The repeat [] at the end should be obvious: for indices outside the triangle, the result is 0.
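
    The drop-then-take-three idea can be seen in isolation with a tiny sketch. The helper name parents is an assumption made for this illustration; f above does the same thing through flip drop and the a:b:c:_ pattern in gn.

    ```haskell
    -- Hypothetical helper: the three elements of the previous row that feed
    -- column x (for x >= 1), i.e. the elements at indices x-1, x and x+1.
    parents :: Int -> [a] -> [a]
    parents x = take 3 . drop (x - 1)
    ```

    For example, parents 3 [10,20,30,40,50] yields [30,40,50]. Because the rows are padded with repeat [], the window never runs off the end of the list.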

    Now writing g0 and gn is really quite simple:

    g0 :: [Path] -> Path 
    g0 (a:b:_) =  map (s:) q 
      where 
        s = sum . concat $ q
        q = b ++ a ++ b 
    
    gn :: [Path] -> Path 
    gn (a:b:c:_) = map (s:) q 
      where 
        s = sum . concat $ q
        q = a ++ b ++ c
    

    On my machine this version is about 3-4 times faster than the fastest version I could write with normal recursion and memoization.

    The rest is just printing or pulling out the number you want.

    triangle :: Int -> Int -> Integer
    triangle x y = case (triangle' !! y) !! (abs x) of 
                     [] -> 0
                     xs -> head $ head xs 
    
    triList :: Int -> Int -> Path
    triList x y = (triangle' !! y) !! (abs x) 
    
    printTri :: Int -> Int -> IO ()
    printTri width height = 
      putStrLn $ unlines $ map unwords 
       [[ p $ triangle x y | x <- [-x0..x0]] | y <- [0..height]]
          where maxLen = length $ show $ triangle 0 height 
                x0 = width `div` 2
                p = printf $ "%" ++ show maxLen ++ "d " 
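
    Putting the pieces together, with the zipWith change described earlier applied to triangle', gives roughly the following. This is a sketch assembled from the fragments above, not a separately benchmarked version; the catch-all clauses in g0 and gn are additions that only keep the pattern matches total, and triangle uses a pattern match in place of head $ head.

    ```haskell
    -- All paths ending at one cell; each path lists its node values, newest first.
    type Path = [[Integer]]

    -- Row y of the triangle for columns x = 0, 1, 2, ... (negative x by symmetry).
    triangle' :: [[Path]]
    triangle' = ([[1]] : repeat []) : zipWith f [1 ..] triangle'

    -- Build row n from the previous row ps.
    f :: Int -> [Path] -> [Path]
    f n ps = g0 ps : map (gn . flip drop ps) [0 .. n - 1] ++ repeat []

    -- Column 0: the left parent mirrors the right parent.
    g0 :: [Path] -> Path
    g0 (a : b : _) = map (s :) q
      where
        s = sum (concat q)
        q = b ++ a ++ b
    g0 _ = []  -- not reached: rows are infinite

    -- Columns x >= 1: parents are the first three elements of the dropped row.
    gn :: [Path] -> Path
    gn (a : b : c : _) = map (s :) q
      where
        s = sum (concat q)
        q = a ++ b ++ c
    gn _ = []  -- not reached: rows are infinite

    -- The number at (x, y); 0 outside the triangle.
    triangle :: Int -> Int -> Integer
    triangle x y = case triangle' !! y !! abs x of
      ((v : _) : _) -> v
      _             -> 0
    ```

    For instance, triangle 0 2 is 6 and triangle 1 3 is 40.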
    
  • 2021-01-16 13:28

    I believe your problem is a bit more complicated than your example code suggests. First, let's be clear about some definitions here:

    Let pathCount x y be the number of paths that end at (x, y). We have

    pathCount :: Int -> Int -> Integer
    pathCount x y
      | y == 0 = if x == 0 then 1 else 0
      | otherwise = sum [ pathCount (x + d) (y - 1) | d <- [-1..1]]
    

    Now let pathSum x y be the sum of all paths that end in (x, y). We have:

    pathSum :: Int -> Int -> Integer
    pathSum x y
      | y == 0 = if x == 0 then 1 else 0
      | otherwise = sum [ pathSum (x + d) (y - 1) + node x y * pathCount (x + d) (y - 1)
                         | d <- [-1..1] ]
    

    With this helper, we can finally define node x y properly:

    node :: Int -> Int -> Integer
    node x y
      | y == 0 = if x == 0 then 1 else 0
      | otherwise = sum [ pathSum (x + d) (y - 1) | d <- [-1..1]]
    

    As written, this algorithm takes exponential time. We can, however, add memoization to make the number of additions quadratic. The memoize package on Hackage makes this easy as pie. Full example:

    import Control.Monad
    import Data.List (intercalate)
    import Data.Function.Memoize (memoize2)
    
    node' :: Int -> Int -> Integer
    node' x y
      | y == 0 = if x == 0 then 1 else 0
      | otherwise = sum [ pathSum (x + d) (y - 1) | d <- [-1..1]]
    node = memoize2 node'
    
    pathCount' :: Int -> Int -> Integer
    pathCount' x y
      | y == 0 = if x == 0 then 1 else 0
      | otherwise = sum [ pathCount (x + d) (y - 1) | d <- [-1..1]]
    pathCount = memoize2 pathCount'
    
    pathSum' :: Int -> Int -> Integer
    pathSum' x y
      | y == 0 = if x == 0 then 1 else 0
      | otherwise = sum [ pathSum (x + d) (y - 1) + node x y * pathCount (x + d) (y - 1)
                         | d <- [-1..1] ]
    pathSum = memoize2 pathSum'
    
    main :: IO ()
    main =
      forM_ [0..n] $ \y ->
         putStrLn $ intercalate " " $ map (show . flip node y) [-n..n]
      where n = 5
    

    Output:

    0 0 0 0 0 1 0 0 0 0 0
    0 0 0 0 1 1 1 0 0 0 0
    0 0 0 2 4 6 4 2 0 0 0
    0 0 4 16 40 48 40 16 4 0 0
    0 8 72 352 728 944 728 352 72 8 0
    16 376 4248 16608 35128 43632 35128 16608 4248 376 16
    

    As you can see, the numbers get out of hand rather quickly. So the runtime is not O(n^2), even though the number of arithmetic operations is, because each Integer operation gets slower as the numbers grow.
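
    For comparison, here is a minimal sketch that avoids the memoize dependency by building the triangle row by row. It relies on the definitions above, which for y > 0 imply pathSum x y = node x y * (1 + pathCount x y), since the three pathSum terms add up to node x y and the three pathCount terms add up to pathCount x y. The names Cell, nextRow, rows and nodeDP are made up for this illustration.

    ```haskell
    -- One cell of the triangle: (pathCount, pathSum, node).
    type Cell = (Integer, Integer, Integer)

    -- Build row y+1 (columns -(y+1) .. y+1) from row y (columns -y .. y).
    nextRow :: [Cell] -> [Cell]
    nextRow row = zipWith3 step padded (drop 1 padded) (drop 2 padded)
      where
        zero   = (0, 0, 0)
        -- Two zero cells on each side stand for the positions outside the triangle.
        padded = [zero, zero] ++ row ++ [zero, zero]
        step (c1, s1, _) (c2, s2, _) (c3, s3, _) =
          let c = c1 + c2 + c3  -- paths arriving from the three parents
              v = s1 + s2 + s3  -- node value: sum of the parents' path sums
          in (c, v + v * c, v)  -- every arriving path gains the new value v

    -- All rows, starting from the single base cell at (0, 0).
    rows :: [[Cell]]
    rows = iterate nextRow [(1, 1, 1)]

    -- Look up the number at (x, y); 0 outside the triangle.
    nodeDP :: Int -> Int -> Integer
    nodeDP x y
      | y < 0 || abs x > y = 0
      | otherwise          = let (_, _, v) = rows !! y !! (x + y) in v
    ```

    With this sketch, map (\x -> nodeDP x 2) [-2..2] evaluates to [2,4,6,4,2], matching the third row of the output above. The number of Integer additions and multiplications is still quadratic in the row index, so the same caveat about growing numbers applies.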
