Relation between `DList` and `[]` with Codensity

梦毁少年i 2021-02-05 06:41

I've been experimenting with Codensity lately, which is supposed to relate DList with [] among other things. Anyway, I've never found cod

2 answers
  •  灰色年华
    2021-02-05 07:25

    TL;DR: DList for (++) serves the same purpose as Codensity for (>>=): reassociating the operators to the right.

    This is beneficial because, for both (++) and (>>=), left-associated computations can exhibit quadratic runtime behaviour.

    1. The full story

    The plan is as follows:

    • We go step by step through an example for (++) and (>>=), demonstrating the problem with associativity.
    • We use CPS to avoid quadratic complexity with DList and Codensity.
    • Aftermath and Bonus (Generalize from (++) to (<>))

    2. The problem: quadratic runtime behaviour

    2a. List (++)

    Keep in mind that while I am using (++) as an example, this is valid for other functions as well, if they work analogously to (++).

    So let's first look at the problem with lists. The concat operation for lists is commonly defined as:

    (++) []     ys = ys
    (++) (x:xs) ys = x : xs ++ ys
    

    which means that (++) will always walk the first argument from start to end. To see when this is a problem consider the following two computations:

    as, bs, cs:: [Int]
    
    rightAssoc :: [Int]
    rightAssoc = (as ++ (bs ++ cs))
    
    leftAssoc :: [Int]
    leftAssoc = ((as ++ bs) ++ cs)
    

    Let's start with rightAssoc and walk through the evaluation.

    as = [1,2]
    bs = [3,4]
    cs = [5,6]
    rightAssoc = ([1,2] ++ ([3,4] ++ [5,6]))
               -- pattern match gives (1:[2]) for first arg
               = 1 : ([2] ++ ([3,4] ++ [5,6]))
               -- pattern match gives (2:[]) for first arg
               = 1 : 2 : ([] ++ ([3,4] ++ [5,6]))
               -- first case of (++)
               = 1 : 2 : ([3,4] ++ [5,6])
               = 1 : 2 : 3 : ([4] ++ [5,6])
               = 1 : 2 : 3 : 4 : ([] ++ [5,6])
               = 1 : 2 : 3 : 4 : [5,6]
               = [1,2,3,4,5,6]
    

    So we have to walk over as and bs.

    Okay that was not too bad, let's continue to leftAssoc:

    as = [1,2]
    bs = [3,4]
    cs = [5,6]
    leftAssoc = (([1,2] ++ [3,4]) ++ [5,6])
              = ((1 : ([2] ++ [3,4])) ++ [5,6])
              = ((1 : 2 : ([] ++ [3,4])) ++ [5,6])
              = ((1 : 2 : [3,4]) ++ [5,6])
              = ([1,2,3,4] ++ [5,6])
              -- uh oh
              = 1 : ([2,3,4] ++ [5,6])
              = 1 : 2 : ([3,4] ++ [5,6])
              = 1 : 2 : 3 : ([4] ++ [5,6])
              = 1 : 2 : 3 : 4 : ([] ++ [5,6])
              = 1 : 2 : 3 : 4 : [5,6]
              = [1,2,3,4,5,6]
    

    Uh oh, did you see that we had to walk over as twice? Once as [1,2] and then again inside as ++ bs = [1,2,3,4]. Every further wrongly associated operand makes this worse: the list on the left of (++), which we have to traverse completely each time, grows with every step, leading to quadratic runtime behaviour.
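    To put numbers on this, here is a minimal sketch that counts how many cells of the left operand each (++) has to walk. The helpers leftNested and rightNested are names made up for this illustration, and the count relies on the fact that xs ++ ys walks exactly length xs cells.

```haskell
-- Hypothetical helpers (not from any library): they build the same list
-- with left- or right-nested (++) and count the cells of the left operands
-- that (++) walks, using the fact that xs ++ ys walks `length xs` cells.

-- ((c1 ++ c2) ++ c3) ++ ...
leftNested :: [[a]] -> ([a], Int)
leftNested []       = ([], 0)
leftNested (c : cs) = foldl step (c, 0) cs
  where
    step (acc, n) ys = (acc ++ ys, n + length acc)

-- c1 ++ (c2 ++ (c3 ++ ...))
rightNested :: [[a]] -> ([a], Int)
rightNested [] = ([], 0)
rightNested cs = foldr step (last cs, 0) (init cs)
  where
    step xs (acc, n) = (xs ++ acc, n + length xs)

main :: IO ()
main = do
  let small = [[1, 2], [3, 4], [5, 6]] :: [[Int]]
  print (leftNested small)   -- ([1,2,3,4,5,6],6)
  print (rightNested small)  -- ([1,2,3,4,5,6],4)
  let singles = map (: []) [1 .. 100] :: [[Int]]
  print (snd (leftNested singles))   -- 4950
  print (snd (rightNested singles))  -- 99
```

    For the three lists from the traces this gives 6 versus 4 walked cells. For 100 one-element lists the left-nested version already walks 4950 cells, the right-nested one only 99.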

    So, as you can see above, left-associated (++) will destroy performance. Which leads us to:

    2b. Free monad (>>=)

    Keep in mind that while I am using Free as an example, this is also the case for other monads; e.g. the instance for Tree behaves like this, too.

    First, we use the naive Free type:

    data Free f a = Pure a | Free (f (Free f a))
    

    Instead of (++), we now look at (>>=), which is defined as follows (written in prefix form):

    -- (the Functor and Applicative instances that current GHC requires are omitted)
    instance Functor f => Monad (Free f) where
      return = Pure
      (>>=) (Pure a) f = f a
      (>>=) (Free m) f = Free ((>>= f) <$> m)
    

    If you compare this with the definition of (++) from 2a above, you can see that the definition of (>>=) again looks at the first argument. That raises a first concern: will this perform as badly as in the (++) case when associated wrongly? Well, let's see. I use Identity here for simplicity, but the choice of functor is not the important point:

    -- specialized to 'Free'
    liftF :: Functor f => f a -> Free f a
    liftF fa = Free (Pure <$> fa)
    
    x :: Free Identity Int
    x = liftF (Identity 20) = Free (Identity (Pure 20))
    
    f :: Int -> Free Identity Int
    f x = liftF (Identity (x+1)) = Free (Identity (Pure (x+1)))
    
    g :: Int -> Free Identity Int
    g x = liftF (Identity (x*2)) = Free (Identity (Pure (x*2)))
    
    rightAssoc :: Free Identity Int
    rightAssoc = (x >>= \x -> (f x >>= g))
    
    leftAssoc :: Free Identity Int
    leftAssoc = ((x >>= f) >>= g)
    

    We again start with the rightAssoc variant first:

    rightAssoc = (x >>= \x -> (f x >>= g))
                        ~~~
               -- definition of x
               = ((Free (Identity (Pure 20))) >>= \x -> (f x >>= g))
                  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
               -- second case of definition for 'Free's (>>=)
               = Free ((>>= \x -> (f x >>= g)) <$> Identity (Pure 20))
                       ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
               -- (<$>) for Identity
               = Free (Identity ((Pure 20) >>= \x -> (f x >>= g)))
                                 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
               -- first case of the definition for 'Free's (>>=)
               = Free (Identity (f 20 >>= g))
                                 ~~~~
               = Free (Identity ((Free (Identity (Pure 21))) >>= g))
                                 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
               -- second case of definition for 'Free's (>>=)
               = Free (Identity (Free ((>>= g) <$> Identity (Pure 21))))
                                       ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
               = Free (Identity (Free (Identity ((Pure 21) >>= g))))
                                                 ~~~~~~~~~~~~~~~
               = Free (Identity (Free (Identity (g 21))))
                                                 ~~~~
               = Free (Identity (Free (Identity (Free (Identity (Pure 42))))))
    

    Phew, okay. For clarity, I added ~~~~ under the expression that is reduced in the next step. If you squint hard enough, you may see some similarity to 2a's case for rightAssoc: we walk the first two arguments (now x and f instead of as and bs) once. Without wasting further time, here is leftAssoc:

    leftAssoc = ((x >>= f) >>= g)
                 ~~~
              = ((Free (Identity (Pure 20)) >>= f) >>= g)
                 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
              = (Free ((>>= f) <$> Identity (Pure 20)) >>= g)
                       ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
              = (Free (Identity ((Pure 20) >>= f)) >>= g)
                                 ~~~~~~~~~~~~~~~
              = (Free (Identity (f 20)) >>= g)
                                 ~~~~
              = (Free (Identity (Free (Identity (Pure 21)))) >>= g)
                ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
              = Free ((>>= g) <$> (Identity (Free (Identity (Pure 21)))))
                      ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
              -- uh oh
              = Free (Identity (Free (Identity (Pure 21)) >>= g))
                                ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
              = Free (Identity (Free ((>>= g) <$> Identity (Pure 21))))
                                      ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
              = Free (Identity (Free (Identity ((Pure 21) >>= g))))
                                                ~~~~~~~~~~~~~~~~
              = Free (Identity (Free (Identity (g 21))))
                                                ~~~~
              = Free (Identity (Free (Identity (Free (Identity (Pure 42))))))
    

    If you look closely, after the uh oh we have to tear down the intermediate structure again, just like in the (++) case (also marked with uh oh).
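    If you want to run the two traces yourself, here is a minimal sketch. A current GHC rejects a Monad instance without Functor and Applicative instances, so they are added here. runCount is a helper made up for this illustration; it unwraps the result and counts the Free layers.

```haskell
import Data.Functor.Identity (Identity (..))

data Free f a = Pure a | Free (f (Free f a))

-- Functor and Applicative are superclasses of Monad in current GHC,
-- so they have to be defined as well.
instance Functor f => Functor (Free f) where
  fmap h (Pure a) = Pure (h a)
  fmap h (Free m) = Free (fmap h <$> m)

instance Functor f => Applicative (Free f) where
  pure = Pure
  Pure h <*> v = h <$> v
  Free m <*> v = Free ((<*> v) <$> m)

instance Functor f => Monad (Free f) where
  Pure a >>= k = k a
  Free m >>= k = Free ((>>= k) <$> m)

liftF :: Functor f => f a -> Free f a
liftF fa = Free (Pure <$> fa)

x :: Free Identity Int
x = liftF (Identity 20)

f, g :: Int -> Free Identity Int
f n = liftF (Identity (n + 1))
g n = liftF (Identity (n * 2))

-- Hypothetical helper: unwrap the result and count the Free layers.
runCount :: Free Identity a -> (a, Int)
runCount (Pure a)            = (a, 0)
runCount (Free (Identity m)) = let (a, n) = runCount m in (a, n + 1)

main :: IO ()
main = do
  print (runCount (x >>= \n -> (f n >>= g)))  -- (42,3)
  print (runCount ((x >>= f) >>= g))          -- (42,3)
```

    Both associations produce the same value and the same three layers. Only the amount of work needed to get there differs.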

    2c. Result so far

    In both cases, leftAssoc leads to quadratic runtime behaviour, because we rebuild the first argument several times and tear it down again right away for the next operation. This means that at each step in the evaluation we have to build and tear down a growing intermediate structure, which is bad.

    3. The relation between DList and Codensity

    This is where we will discover the relation between DList and Codensity. Each one solves the problem of wrongly associated computations seen above by using CPS to effectively reassociate to the right.

    3a. DList

    First we introduce the definition of DList and append:

    newtype DList a = DL { unDL :: [a] -> [a] }
    
    append :: DList a -> DList a -> DList a
    append xs ys = DL (unDL xs . unDL ys)
    
    fromList :: [a] -> DList a
    fromList = DL . (++)
    
    toList :: DList a -> [a]
    toList = ($[]) . unDL
    

    and now our old friends:

    as,bs,cs :: DList Int
    as = fromList [1,2] = DL ([1,2] ++)
    bs = fromList [3,4] = DL ([3,4] ++)
    cs = fromList [5,6] = DL ([5,6] ++)
    
    rightAssoc :: [Int]
    rightAssoc = toList $ as `append` (bs `append` cs)
    
    leftAssoc :: [Int]
    leftAssoc = toList $ ((as `append` bs) `append` cs)
    

    Evaluation is roughly as follows:

    rightAssoc = toList $ (DL ([1,2] ++)) `append` (bs `append` cs)
               = toList $ DL $ unDL (DL ([1,2] ++)) . unDL (bs `append` cs)
                               ~~~~~~~~~~~~~~~~~~~~
               = toList $ DL $ ([1,2] ++) . unDL (bs `append` cs)
                                                  ~~
               = toList $ DL $ ([1,2] ++) . unDL ((DL ([3,4] ++)) `append` cs)
                                                  ~~~~~~~~~~~~~~~~~~~~~~~~~~~
               = toList $ DL $ ([1,2] ++) . unDL (DL $ unDL (DL ([3,4] ++)) . unDL cs)
                                                       ~~~~~~~~~~~~~~~~~~~~
               = toList $ DL $ ([1,2] ++) . unDL (DL $ ([3,4] ++) . unDL cs)
                                                                         ~~
               = toList $ DL $ ([1,2] ++) . unDL (DL $ ([3,4] ++) . unDL (DL ([5,6] ++)))
               = toList $ DL $ ([1,2] ++) . unDL (DL $ ([3,4] ++) . ([5,6] ++))
                                            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
               = toList $ DL $ ([1,2] ++) . (([3,4] ++) . ([5,6] ++))
                 ~~~~~~
               -- definition of toList
               = ($[]) . unDL $ DL $ ([1,2] ++) . (([3,4] ++) . ([5,6] ++))
                         ~~~~~~~~~
               -- unDL . DL == id
               = ($[]) $ (([1,2] ++) . (([3,4] ++) . ([5,6] ++)))
               -- move ($[]) to end
               = (([1,2] ++) . (([3,4] ++) . ([5,6] ++))) []
               -- def: (.) g f x = g (f x)
               = (([1,2] ++) ((([3,4] ++) . ([5,6] ++)) []))
               = (([1,2] ++) (([3,4] ++) (([5,6] ++) [])))
               -- drop unnecessary parens
               = (([1,2] ++) (([3,4] ++) ([5,6] ++ [])))
               = ([1,2] ++ ([3,4] ++ ([5,6] ++ [])))
                                      ~~~~~~~~~~~
               -- (xs ++ []) == xs
               = ([1,2] ++ ([3,4] ++ ([5,6])))
               = (as ++ (bs ++ cs))
    

    Hah! The result is exactly the same as rightAssoc from 2a. All right, with tension building up, we move on to leftAssoc:

    leftAssoc = toList $ ((as `append` bs) `append` cs)
              = toList $ (((DL ([1,2]++)) `append` bs) `append` cs)
              = toList $ ((DL (unDL (DL ([1,2]++)) . unDL bs)) `append` cs)
              = toList $ ((DL (unDL (DL ([1,2]++)) . unDL (DL ([3,4]++)))) `append` cs)
              = toList $ ((DL (([1,2]++) . ([3,4]++))) `append` cs)
              = toList $ (DL (unDL (DL (([1,2]++) . ([3,4]++))) . unDL cs))
              = toList $ (DL (unDL (DL (([1,2]++) . ([3,4]++))) . unDL (DL ([5,6]++))))
              = toList $ (DL ((([1,2]++) . ([3,4]++)) . ([5,6]++)))
              = ($[]) . unDL $ (DL ((([1,2]++) . ([3,4]++)) . ([5,6]++)))
              = ($[]) ((([1,2]++) . ([3,4]++)) . ([5,6]++))
              = ((([1,2]++) . ([3,4]++)) . ([5,6]++)) []
              -- expand (f . g) to \x -> f (g x)
              = ((\x -> ([1,2]++) (([3,4]++) x)) . ([5,6]++)) []
              = ((\x -> ([1,2]++) (([3,4]++) x)) (([5,6]++) []))
              -- apply lambda
              = ((([1,2]++) (([3,4]++) (([5,6]++) []))))
              = ([1,2] ++ ([3,4] ++ [5,6]))
              -- as', bs', cs' are the plain lists as, bs, cs from 2a
              = (as' ++ (bs' ++ cs'))
    

    Eureka! The result is associated correctly (to the right), no quadratic slowdown.
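    The same check can be run instead of traced by hand. This is a minimal sketch that reuses the definitions from above; only main and the name leftNested are added for illustration.

```haskell
newtype DList a = DL { unDL :: [a] -> [a] }

append :: DList a -> DList a -> DList a
append xs ys = DL (unDL xs . unDL ys)

fromList :: [a] -> DList a
fromList = DL . (++)

toList :: DList a -> [a]
toList = ($ []) . unDL

main :: IO ()
main = do
  let chunks = [[1, 2], [3, 4], [5, 6]] :: [[Int]]
      -- foldl1 nests the appends to the left: (c1 `append` c2) `append` c3
      leftNested = foldl1 append (map fromList chunks)
  print (toList leftNested)                   -- [1,2,3,4,5,6]
  print (toList leftNested == concat chunks)  -- True
```

    Even though the append calls are nested to the left, the underlying (++) calls only run once toList supplies the final [], and by then they are nested to the right.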

    3b. Codensity

    Okay, if you've come this far you must be seriously interested. That's good, because so am I :). We start with the definition and Monad instance of Codensity (with abbreviated names):

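    -- needs the RankNTypes extension because of the inner forall
    newtype Codensity m a = C { run :: forall b. (a -> m b) -> m b }
    
    -- (Functor and Applicative instances omitted)
    instance Monad (Codensity f) where
      return x = C (\k -> k x)
      m >>= k = C (\c -> run m (\a -> run (k a) c))
    
    -- hidden as an instance for `MonadTrans`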
    liftCodensity :: Monad m => m a -> Codensity m a
    liftCodensity m = C (m >>=)
    
    lowerCodensity :: Monad m => Codensity m a -> m a
    lowerCodensity a = run a return
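
    To load these definitions in a current GHC you need the RankNTypes extension plus Functor and Applicative instances. Here is a minimal sketch under those assumptions. To keep it short it uses Maybe as the base monad rather than Free Identity, which is a substitution made only for this illustration.

```haskell
{-# LANGUAGE RankNTypes #-}

newtype Codensity m a = C { run :: forall b. (a -> m b) -> m b }

-- Superclass instances required by current GHC.
instance Functor (Codensity m) where
  fmap h m = C (\c -> run m (c . h))

instance Applicative (Codensity m) where
  pure a = C (\k -> k a)
  mf <*> ma = C (\c -> run mf (\h -> run ma (c . h)))

instance Monad (Codensity m) where
  m >>= k = C (\c -> run m (\a -> run (k a) c))

liftCodensity :: Monad m => m a -> Codensity m a
liftCodensity m = C (m >>=)

lowerCodensity :: Monad m => Codensity m a -> m a
lowerCodensity a = run a return

-- Maybe stands in for the base monad here (an assumption of this sketch).
x :: Codensity Maybe Int
x = liftCodensity (Just 20)

f, g :: Int -> Codensity Maybe Int
f n = liftCodensity (Just (n + 1))
g n = liftCodensity (Just (n * 2))

main :: IO ()
main = do
  print (lowerCodensity (x >>= \n -> (f n >>= g)))  -- Just 42
  print (lowerCodensity ((x >>= f) >>= g))          -- Just 42
```

    Note that none of the Codensity instances needs any constraint on m. The base monad's (>>=) only shows up in liftCodensity and lowerCodensity.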
    

    I guess you know what comes next:

    x :: Codensity (Free Identity) Int
    x = liftCodensity (Free (Identity (Pure 20)))
      = C (Free (Identity (Pure 20)) >>=)
      -- note the similarity to (DL (as ++))
      -- with DL ~ Codensity and (++) ~ (>>=) !
    
    f :: Int -> Codensity (Free Identity) Int
    f x = liftCodensity (Free (Identity (Pure (x+1))))
        = C (Free (Identity (Pure (x+1))) >>=)
    
    g :: Int -> Codensity (Free Identity) Int
    g x = liftCodensity (Free (Identity (Pure (x*2))))
        = C (Free (Identity (Pure (x*2))) >>=)
    
    rightAssoc :: Free Identity Int
    rightAssoc = lowerCodensity (x >>= \x -> (f x >>= g))
    
    leftAssoc :: Free Identity Int
    leftAssoc = lowerCodensity ((x >>= f) >>= g)
    

    Before we go through the evaluation once again, you might want to compare append from DList with (>>=) from Codensity (unDL ~ run). Go ahead and do that if you want, I'll wait for you.

    Okay we start with rightAssoc:

    rightAssoc = lowerCodensity (x >>= \x -> (f x >>= g))
                                ~~~
               -- def of x
               = lowerCodensity ((C (Free (Identity (Pure 20)) >>=)) >>= \x -> (f x >>= g))
               -- (>>=) of codensity
               = lowerCodensity (C (\c -> run (C (Free (Identity (Pure 20)) >>=)) (\a -> run ((\x -> (f x >>= g)) a) c)))
               -- run . C == id
               = lowerCodensity (C (\c -> Free (Identity (Pure 20)) >>= \a -> run ((\x -> (f x >>= g)) a) c))
               -- substitute x' for 'Free (Identity (Pure 20))' (this is the plain x from 2b)
               = lowerCodensity (C (\c -> x' >>= \a -> run ((\x -> (f x >>= g)) a) c))
                                                                    ~~~
               = lowerCodensity (C (\c -> x' >>= \a -> run ((\x -> (C (Free (Identity (Pure (x+1))) >>=)) >>= g) a) c))
                                                                    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
               = lowerCodensity (C (\c -> x' >>= \a -> run ((\x -> (C (\c2 -> run (C (Free (Identity (Pure (x+1))) >>=)) (\a2 -> run (g a2) c2)))) a) c))
                                                                               ~~~~~~
               = lowerCodensity (C (\c -> x' >>= \a -> run ((\x -> (C (\c2 -> (Free (Identity (Pure (x+1))) >>=) (\a2 -> run (g a2) c2)))) a) c))
               -- again, substitute f' for '\x -> Free (Identity (Pure (x+1)))' (this is the plain f from 2b)
               = lowerCodensity (C (\c -> x' >>= \a -> run ((\x -> (C (\c2 -> (f' x >>=) (\a2 -> run (g a2) c2)))) a) c))
                                                                                                       ~~~~
               = lowerCodensity (C (\c -> x' >>= \a -> run ((\x -> (C (\c2 -> (f' x >>=) (\a2 -> run (C (Free (Identity (Pure (a2*2))) >>=)) c2)))) a) c))
                                                                                                  ~~~~~~
               = lowerCodensity (C (\c -> x' >>= \a -> run ((\x -> (C (\c2 -> (f' x >>=) (\a2 -> (Free (Identity (Pure (a2*2))) >>=) c2)))) a) c))
               -- one last time, substitute g' (g from 2b)
               = lowerCodensity (C (\c -> x' >>= \a -> run ((\x -> (C (\c2 -> (f' x >>=) (\a2 -> (g' a2 >>=) c2)))) a) c))
               -- def of lowerCodensity
               = run (C (\c -> x' >>= \a -> run ((\x -> (C (\c2 -> (f' x >>=) (\a2 -> (g' a2 >>=) c2)))) a) c)) return
               = (\c -> x' >>= \a -> run ((\x -> (C (\c2 -> (f' x >>=) (\a2 -> (g' a2 >>=) c2)))) a) c) return
               = (x' >>= \a -> run ((\x -> (C (\c2 -> (f' x >>=) (\a2 -> (g' a2 >>=) c2)))) a) return)
                                     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
               = (x' >>= \a -> run (C (\c2 -> (f' a >>=) (\a2 -> (g' a2 >>=) c2))) return)
                               ~~~~~~
               = (x' >>= \a -> (\c2 -> (f' a >>=) (\a2 -> (g' a2 >>=) c2)) return)
                                ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
               = (x' >>= \a -> (f' a >>=) (\a2 -> g' a2 >>= return))
               -- m >>= return ~ m
               = (x' >>= \a -> (f' a >>=) (\a2 -> g' a2))
               -- m >>= (\x -> f x) ~ m >>= f
               = (x' >>= \a -> (f' a >>= g'))
               -- rename a to x
               = (x' >>= \x -> (f' x >>= g'))
    

    And we can now see that the (>>=)s are associated to the right. This is not yet particularly astonishing, given that this was also the case at the start. So, full of anticipation, we turn our attention to our last evaluation trace, leftAssoc:

    leftAssoc = lowerCodensity ((x >>= f) >>= g)
              -- def of x
              = lowerCodensity ((C (Free (Identity (Pure 20)) >>=) >>= f) >>= g)
              -- (>>=) from Codensity
              = lowerCodensity ((C (\c -> run (C (Free (Identity (Pure 20)) >>=)) (\a -> run (f a) c))) >>= g)
                                          ~~~~~~
              = lowerCodensity ((C (\c -> (Free (Identity (Pure 20)) >>=) (\a -> run (f a) c))) >>= g)
              -- subst x'
              = lowerCodensity ((C (\c -> (x' >>=) (\a -> run (f a) c))) >>= g)
              -- def of f
              = lowerCodensity ((C (\c -> (x' >>=) (\a -> run (C (Free (Identity (Pure (a+1))) >>=)) c))) >>= g)
                                                          ~~~~~~
              = lowerCodensity ((C (\c -> (x' >>=) (\a -> (Free (Identity (Pure (a+1))) >>=) c))) >>= g)
              -- subst f'
              = lowerCodensity ((C (\c -> (x' >>=) (\a -> (f' a >>=) c))) >>= g)
                                ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
              = lowerCodensity (C (\c2 -> run (C (\c -> (x' >>=) (\a -> (f' a >>=) c))) (\a2 -> run (g a2) c2)))
                                          ~~~~~~
              = lowerCodensity (C (\c2 -> (\c -> (x' >>=) (\a -> (f' a >>=) c)) (\a2 -> run (g a2) c2)))
              -- def of g
              = lowerCodensity (C (\c2 -> (\c -> (x' >>=) (\a -> (f' a >>=) c)) (\a2 -> run (C (Free (Identity (Pure (a2*2))) >>=)) c2)))
                                                                                        ~~~~~~
              = lowerCodensity (C (\c2 -> (\c -> (x' >>=) (\a -> (f' a >>=) c)) (\a2 -> (Free (Identity (Pure (a2*2))) >>=) c2)))
              -- subst g'
              = lowerCodensity (C (\c2 -> (\c -> (x' >>=) (\a -> (f' a >>=) c)) (\a2 -> (g' a2 >>=) c2)))
              -- def lowerCodensity
              = run (C (\c2 -> (\c -> (x' >>=) (\a -> (f' a >>=) c)) (\a2 -> (g' a2 >>=) c2))) return
              = (\c2 -> (\c -> (x' >>=) (\a -> (f' a >>=) c)) (\a2 -> (g' a2 >>=) c2)) return
              = ((\c -> (x' >>=) (\a -> (f' a >>=) c)) (\a2 -> g' a2 >>= return))
              = ((\c -> (x' >>=) (\a -> (f' a >>=) c)) (\a2 -> g' a2))
              = ((\c -> (x' >>=) (\a -> (f' a >>=) c)) g')
              = (x' >>=) (\a -> (f' a >>=) g')
              = (x' >>=) (\a -> (f' a >>= g'))
              = (x' >>= (\a -> (f' a >>= g')))
              = (x' >>= (\x -> (f' x >>= g')))
    

    Finally there we have it, all binds associated to the right, just how we like them!

    4. Aftermath

    If you made it this far, congratulations. Let's summarize what we did:

    1. We demonstrated the problem with wrongly associated (++) in 2a and (>>=) in 2b
    2. We've shown the solution using DList in 3a and Codensity in 3b.
    3. We demonstrated the power of equational reasoning in Haskell :)

    5. Bonus

    Actually, we can generalize DList from (++) to (<>) and get DMonoid, which reassociates (<>).

    newtype DMonoid m = DM { unDM :: m -> m }
    
    -- Since base 4.11, Monoid requires a Semigroup instance,
    -- so the composition lives in (<>).
    instance Semigroup (DMonoid m) where
      x <> y = DM (unDM x . unDM y)
    
    instance Monoid m => Monoid (DMonoid m) where
      mempty = DM (mempty <>)
    
    liftDM :: Monoid m => m -> DMonoid m
    liftDM = DM . (<>)
    
    lowerDM :: Monoid m => DMonoid m -> m
    lowerDM = ($ mempty) . unDM
    

    Then the comparison goes as follows:

    • DMonoid is a "monoid transformer" (my invention), reassociating (<>) to the right
    • Codensity is a monad transformer, reassociating (>>=) to the right
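
    The function-composition part of DMonoid is what base already ships as Endo in Data.Monoid, so a minimal sketch of the same idea can reuse it. liftE and lowerE are names made up here to mirror liftDM and lowerDM.

```haskell
import Data.Monoid (Endo (..))

-- Hypothetical counterparts of liftDM and lowerDM, built on Endo from base.
liftE :: Monoid m => m -> Endo m
liftE = Endo . (<>)

lowerE :: Monoid m => Endo m -> m
lowerE e = appEndo e mempty

main :: IO ()
main = do
  let parts = ["ab", "cd", "ef"]
      -- foldl nests (<>) on Endo to the left: ((mempty <> p1) <> p2) <> p3
      leftNested = foldl (<>) mempty (map liftE parts)
  putStrLn (lowerE leftNested)  -- abcdef
```

    The (<>) on the underlying strings only runs when lowerE applies the composed function to mempty, and at that point it is nested to the right.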
