Constructing efficient monad instances on `Set` (and other containers with constraints) using the continuation monad

别那么骄傲 2020-12-08 09:46

Set, similarly to [], has perfectly well-defined monadic operations. The problem is that they require the values to satisfy an Ord constraint...

4 answers
  • 2020-12-08 10:27

    Monads are one particular way of structuring and sequencing computations. The bind of a monad cannot magically restructure your computation so that it runs more efficiently. There are two problems with the way you structure your computation.

    1. When evaluating stepN 20 0, the result of step 0 will be computed 20 times. This is because each step of the computation produces 0 as one alternative, which is then fed to the next step, which also produces 0 as an alternative, and so on...

      Perhaps a bit of memoization here can help.

    2. A much bigger problem is the effect of ContT on the structure of your computation. With a bit of equational reasoning, expanding out the result of replicate 20 step, the definition of foldrM and simplifying as many times as necessary, we can see that stepN 20 0 is equivalent to:

      (...(return 0 >>= step) >>= step) >>= step) >>= ...)
      

      All parentheses of this expression associate to the left. That's great, because it means that the RHS of each occurrence of (>>=) is an elementary computation, namely step, rather than a composed one. However, zooming in on the definition of (>>=) for ContT,

      m >>= k = ContT $ \c -> runContT m (\a -> runContT (k a) c)
      

      we see that when evaluating a chain of (>>=) associating to the left, each bind will push a new computation onto the current continuation c. To illustrate what is going on, we can again use a bit of equational reasoning, expanding out this definition for (>>=) and the definition for runContT, and simplifying, which yields:

      setReturn 0 `setBind`
          (\x1 -> step x1 `setBind`
              (\x2 -> step x2 `setBind` (\x3 -> ...)...)
      

      Now, for each occurrence of setBind, let's ask ourselves what the RHS argument is. For the leftmost occurrence, the RHS argument is the whole rest of the computation after setReturn 0. For the second occurrence, it's everything after step x1, etc. Let's zoom in to the definition of setBind:

      setBind set f = foldl' (\s -> union s . f) empty set
      

      Here f represents all the rest of the computation, everything on the right hand side of an occurrence of setBind. That means that at each step, we are capturing the rest of the computation as f, and applying f as many times as there are elements in set. The computations are not elementary as before, but rather composed, and these computations will be duplicated many times.

    The crux of the problem is that the ContT monad transformer is transforming the initial structure of the computation, which you meant as a left associative chain of setBind's, into a computation with a different structure, i.e. a right associative chain. This is after all perfectly fine, because one of the monad laws says that, for every m, f and g we have

    (m >>= f) >>= g = m >>= (\x -> f x >>= g)
    

    However, the monad laws do not impose that the complexity remain the same on each side of the equations of each law. And indeed, in this case, the left associative way of structuring this computation is a lot more efficient. The left associative chain of setBind's evaluates in no time, because only elementary subcomputations are duplicated.
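
    The difference between the two shapes can be made concrete by counting calls to step. The following is only a sketch: Counted, countedReturn, countedBind, countedStep, leftChain and rightChain are hypothetical names introduced here, and countedBind merely mirrors the shape of setBind while summing the calls made inside f.

```haskell
import qualified Data.Set as S

-- A set paired with the number of step calls spent building it (hypothetical type).
type Counted a = (Int, S.Set a)

countedReturn :: a -> Counted a
countedReturn x = (0, S.singleton x)

-- Same shape as setBind, but also adds up the calls made inside f.
countedBind :: Ord b => Counted a -> (a -> Counted b) -> Counted b
countedBind (c, s) f = S.foldl' combine (c, S.empty) s
  where
    combine (n, acc) x = let (m, t) = f x in (n + m, S.union acc t)

-- One call to step costs 1.
countedStep :: Int -> Counted Int
countedStep i = (1, S.fromList [i, i + 1])

-- ((return 0 >>= step) >>= step) >>= ...
leftChain :: Int -> Counted Int
leftChain n = foldl (\m _ -> m `countedBind` countedStep) (countedReturn 0) [1 .. n]

-- step x >>= (\x1 -> step x1 >>= (\x2 -> ...))
rightChain :: Int -> Int -> Counted Int
rightChain 0 x = countedReturn x
rightChain n x = countedStep x `countedBind` rightChain (n - 1)
```

    Both chains produce the same set, but for ten steps leftChain 10 makes 55 calls to countedStep (1 + 2 + ... + 10), while rightChain 10 0 makes 1023 (2^10 - 1), because every bind re-runs the whole rest of the chain for each element.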

    It turns out that other solutions shoehorning Set into a monad also suffer from the same problem. In particular, the set-monad package yields similar runtimes. The reason is that it, too, rewrites left associative expressions into right associative ones.

    I think you have put your finger on a very important yet rather subtle problem with insisting that Set obey a Monad interface. And I don't think it can be solved. The problem is that the type of the bind of a monad needs to be

    (>>=) :: m a -> (a -> m b) -> m b
    

    i.e. no class constraint is allowed on either a or b. That means that we cannot nest binds on the left without first invoking the monad laws to rewrite the expression into a right associative chain. Here's why: given (m >>= f) >>= g, the type of the computation (m >>= f) is of the form m b. A value of the computation (m >>= f) is of type b. But because we can't hang any class constraint onto the type variable b, we can't know that the value we got satisfies an Ord constraint, and therefore cannot use this value as the element of a set on which we want to be able to compute unions.

  • 2020-12-08 10:31

    I found another possibility, based on GHC's ConstraintKinds extension. The idea is to redefine Monad so that it includes a parametric constraint on allowed values:

    {-# LANGUAGE ConstraintKinds #-}
    {-# LANGUAGE TypeFamilies #-}
    {-# LANGUAGE RebindableSyntax #-}
    
    import qualified Data.Foldable as F
    import qualified Data.Kind as K
    import qualified Data.Set as S
    import Prelude hiding (Monad(..), Functor(..))
    
    class CFunctor m where
        -- Each instance defines a constraint its values must satisfy.
        -- The result kind is given explicitly so it can be used as a constraint.
        type Constraint m a :: K.Constraint
        -- The default is no constraints.
        type Constraint m a = ()
        fmap   :: (Constraint m a, Constraint m b) => (a -> b) -> (m a -> m b)
    class CFunctor m => CMonad (m :: * -> *) where
        return :: (Constraint m a) => a -> m a
        (>>=)  :: (Constraint m a, Constraint m b) => m a -> (a -> m b) -> m b
        fail   :: String -> m a
        fail   = error
    
    -- [] instance
    instance CFunctor [] where
        fmap = map
    instance CMonad [] where
        return  = (: [])
        (>>=)   = flip concatMap
    
    -- Set instance
    instance CFunctor S.Set where
        -- Sets need Ord.
        type Constraint S.Set a = Ord a
        fmap = S.map
    instance CMonad S.Set where
        return  = S.singleton
        (>>=)   = flip F.foldMap
    
    -- Example:
    
    -- prints fromList [3,4,5]
    main = print $ do
        x <- S.fromList [1,2]
        y <- S.fromList [2,3]
        return $ x + y
    

    (The problem with this approach arises when the monadic values are functions, as in m (a -> b), because functions can't satisfy constraints like Ord (a -> b). So one can't use combinators like <*> (or ap) with this constrained Set monad.)
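
    One way to do without <*> is to never store functions in the set and to combine two sets directly. The following is only a sketch; liftSet2 is a hypothetical helper, not something provided by Data.Set or by the classes above.

```haskell
import qualified Data.Set as S

-- Hypothetical constrained analogue of liftA2 for sets: only the result
-- type needs Ord, and no set of functions is ever built.
liftSet2 :: Ord c => (a -> b -> c) -> S.Set a -> S.Set b -> S.Set c
liftSet2 f xs ys = S.fromList [f x y | x <- S.toList xs, y <- S.toList ys]
```

    For example, liftSet2 (+) (S.fromList [1,2]) (S.fromList [2,3]) gives the same fromList [3,4,5] as the do block above.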

  • 2020-12-08 10:32

    Recently on Haskell Cafe, Oleg gave an example of how to implement the Set monad efficiently. Quoting:

    ... And yet, the efficient genuine Set monad is possible.

    ... Enclosed is the efficient genuine Set monad. I wrote it in direct style (it seems to be faster, anyway). The key is to use the optimized choose function when we can.

      {-# LANGUAGE GADTs, TypeSynonymInstances, FlexibleInstances #-}
    
      module SetMonadOpt where
    
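      import qualified Data.Set as S
      import Control.Applicative (Alternative(..))
      import Control.Monad
    
      data SetMonad a where
          SMOrd :: Ord a => S.Set a -> SetMonad a
          SMAny :: [a] -> SetMonad a
    
      -- The Functor, Applicative and Alternative instances are not part of
      -- the quoted post; GHC 7.10 and later require them as superclasses
      -- of Monad and MonadPlus.
      instance Functor SetMonad where
          fmap = liftM
    
      instance Applicative SetMonad where
          pure x = SMAny [x]
          (<*>)  = ap
    
      instance Alternative SetMonad where
          empty = mzero
          (<|>) = mplus
    
      instance Monad SetMonad where
          return = pure
    
          m >>= f = collect . map f $ toList m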
    
      toList :: SetMonad a -> [a]
      toList (SMOrd x) = S.toList x
      toList (SMAny x) = x
    
      collect :: [SetMonad a] -> SetMonad a
      collect []  = SMAny []
      collect [x] = x
      collect ((SMOrd x):t) = case collect t of
                               SMOrd y -> SMOrd (S.union x y)
                               SMAny y -> SMOrd (S.union x (S.fromList y))
      collect ((SMAny x):t) = case collect t of
                               SMOrd y -> SMOrd (S.union y (S.fromList x))
                               SMAny y -> SMAny (x ++ y)
    
      runSet :: Ord a => SetMonad a -> S.Set a
      runSet (SMOrd x) = x
      runSet (SMAny x) = S.fromList x
    
      instance MonadPlus SetMonad where
          mzero = SMAny []
          mplus (SMAny x) (SMAny y) = SMAny (x ++ y)
          mplus (SMAny x) (SMOrd y) = SMOrd (S.union y (S.fromList x))
          mplus (SMOrd x) (SMAny y) = SMOrd (S.union x (S.fromList y))
          mplus (SMOrd x) (SMOrd y) = SMOrd (S.union x y)
    
      choose :: MonadPlus m => [a] -> m a
      choose = msum . map return
    
    
      test1 = runSet (do
        n1 <- choose [1..5]
        n2 <- choose [1..5]
        let n = n1 + n2
        guard $ n < 7
        return n)
      -- fromList [2,3,4,5,6]
    
      -- Values to choose from might be higher-order or actions
      test1' = runSet (do
        n1 <- choose . map return $ [1..5]
        n2 <- choose . map return $ [1..5]
        n  <- liftM2 (+) n1 n2
        guard $ n < 7
        return n)
      -- fromList [2,3,4,5,6]
    
      test2 = runSet (do
        i <- choose [1..10]
        j <- choose [1..10]
        k <- choose [1..10]
        guard $ i*i + j*j == k * k
        return (i,j,k))
      -- fromList [(3,4,5),(4,3,5),(6,8,10),(8,6,10)]
    
      test3 = runSet (do
        i <- choose [1..10]
        j <- choose [1..10]
        k <- choose [1..10]
        guard $ i*i + j*j == k * k
        return k)
      -- fromList [5,10]
    
      -- Test by Petr Pudlak
    
      -- First, general, unoptimal case
      step :: (MonadPlus m) => Int -> m Int
      step i = choose [i, i + 1]
    
      -- repeated application of step on 0:
      stepN :: Int -> S.Set Int
      stepN = runSet . f
        where
        f 0 = return 0
        f n = f (n-1) >>= step
    
      -- it works, but clearly exponential
      {-
      *SetMonad> stepN 14
      fromList [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14]
      (0.09 secs, 31465384 bytes)
      *SetMonad> stepN 15
      fromList [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]
      (0.18 secs, 62421208 bytes)
      *SetMonad> stepN 16
      fromList [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]
      (0.35 secs, 124876704 bytes)
      -}
    
      -- And now the optimization
      chooseOrd :: Ord a => [a] -> SetMonad a
      chooseOrd x = SMOrd (S.fromList x)
    
      stepOpt :: Int -> SetMonad Int
      stepOpt i = chooseOrd [i, i + 1]
    
      -- repeated application of step on 0:
      stepNOpt :: Int -> S.Set Int
      stepNOpt = runSet . f
        where
        f 0 = return 0
        f n = f (n-1) >>= stepOpt
    
      {-
      stepNOpt 14
      fromList [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14]
      (0.00 secs, 515792 bytes)
      stepNOpt 15
      fromList [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]
      (0.00 secs, 515680 bytes)
      stepNOpt 16
      fromList [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]
      (0.00 secs, 515656 bytes)
    
      stepNOpt 30
      fromList [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]
      (0.00 secs, 1068856 bytes)
      -}
    
  • 2020-12-08 10:42

    I don't think your performance problems in this case are due to the use of Cont:

    step' :: Int -> Set Int
    step' i = fromList [i,i + 1]
    
    foldrM' f z0 xs = Prelude.foldl f' setReturn xs z0
      where f' k x z = f x z `setBind` k
    
    stepN' :: Int -> Int -> Set Int
    stepN' times start = foldrM' ($) start (replicate times step')
    

    This gets similar performance to the Cont-based implementation, but occurs entirely in the Set "restricted monad".

    I am not sure if I believe your claim about Glivenko's theorem leading to exponential increase in (normalized) proof size--at least in the Call-By-Need context. That is because we can arbitrarily reuse subproofs (and our logic is second order, we need only a single proof of forall a. ~~(a \/ ~a)). Proofs are not trees, they are graphs (sharing).

    In general, you are likely to see performance costs from Cont wrapping Set but they can usually be avoided via

    smash :: (Ord r, Ord k) => SetM r r -> SetM k r
    smash = fromSet . toSet
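
    To show how smash might be used, here is a self-contained sketch. It assumes SetM is a continuation monad whose answer type is Set r, written out as a newtype so that no transformer library is needed. The definitions of SetM, fromSet, toSet and step below are assumptions (the question's own definitions are not reproduced on this page), and stepNSmashed is a hypothetical name.

```haskell
import qualified Data.Set as S

-- Assumed definition: continuations that always answer with a Set r.
newtype SetM r a = SetM { runSetM :: (a -> S.Set r) -> S.Set r }

instance Functor (SetM r) where
  fmap f (SetM m) = SetM (\k -> m (k . f))

instance Applicative (SetM r) where
  pure x = SetM (\k -> k x)
  SetM mf <*> SetM mx = SetM (\k -> mf (\f -> mx (k . f)))

instance Monad (SetM r) where
  SetM m >>= f = SetM (\k -> m (\a -> runSetM (f a) k))

-- Feed every element to the continuation and union the answers.
fromSet :: Ord r => S.Set a -> SetM r a
fromSet s = SetM (\k -> S.unions (map k (S.toList s)))

toSet :: SetM r r -> S.Set r
toSet (SetM m) = m S.singleton

-- Collapse a computation to an ordinary set, merging duplicates, then re-embed it.
smash :: (Ord r, Ord k) => SetM r r -> SetM k r
smash = fromSet . toSet

step :: Ord r => Int -> SetM r Int
step i = fromSet (S.fromList [i, i + 1])

-- Smash after every step so that each step only sees distinct values.
stepNSmashed :: Int -> Int -> S.Set Int
stepNSmashed times start = toSet (go times)
  where
    go :: Int -> SetM Int Int
    go 0 = return start
    go n = smash (go (n - 1)) >>= step
```

    With this shape, stepNSmashed 20 0 evaluates to the set of integers from 0 to 20, and each step is applied only once per distinct element of the previous set.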
    