Understanding filterM

情歌与酒 2021-02-18 23:57

Consider

filterM (\x -> [True, False]) [1, 2, 3]

I just cannot understand the magic that Haskell does with this filterM use case.

  •  情书的邮戳
    2021-02-19 00:55

    The list monad [] models non-determinism: a list of values [a] represents a number of different possibilities for the value of a.

    When you see a statement like flg <- p x in the list monad, flg will take on each value of p x in turn, i.e. True and then False in this case. The rest of the body of filterM is then executed twice, once for each value of flg.
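    A small illustration of this branching, using only the Prelude. The names choices and pairs are made up for the example. Each <- draw splits the computation, and the rest of the do block runs once per drawn value:

```haskell
-- In the list monad, every value drawn with <- starts a separate branch;
-- the rest of the do block runs once per branch and the results are
-- concatenated in order.
choices :: [String]
choices = do
  flg <- [True, False]   -- stands in for p x in filterM
  return (if flg then "kept" else "dropped")

-- Two draws nest: the second draw runs once for each value of the first.
pairs :: [(Int, Char)]
pairs = do
  n <- [1, 2]
  c <- "xy"
  return (n, c)
```

    Evaluating choices gives ["kept","dropped"], and pairs gives [(1,'x'),(1,'y'),(2,'x'),(2,'y')].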

    To see how this happens in more detail, you need to understand the desugaring of do notation and the implementation of the (>>=) operator for the list monad.

    do notation gets desugared line by line into calls to the (>>=) operator. For example, the body of the non-empty case of filterM turns into

    p x >>= \flg -> (filterM p xs >>= \ys -> return (if flg then x:ys else ys))
    

    This is completely mechanical: in essence it just replaces flg <- before the expression with >>= \flg -> after the expression. In reality, pattern matching makes this a little more complicated, but not much.
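    The question does not include the filterM source it refers to. As a sketch, assuming the classic do-notation definition that older versions of Control.Monad used (current base writes filterM with Applicative, which gives the same results for lists), here it is next to the hand-desugared body shown above. The names filterMDo and filterMBind are hypothetical, chosen so they do not clash with the library function:

```haskell
-- filterM written with do notation (the classic definition, renamed).
filterMDo :: Monad m => (a -> m Bool) -> [a] -> m [a]
filterMDo _ []     = return []
filterMDo p (x:xs) = do
  flg <- p x
  ys  <- filterMDo p xs
  return (if flg then x : ys else ys)

-- The same function after replacing each "v <- e" with "e >>= \v ->".
filterMBind :: Monad m => (a -> m Bool) -> [a] -> m [a]
filterMBind _ []     = return []
filterMBind p (x:xs) =
  p x >>= \flg -> (filterMBind p xs >>= \ys -> return (if flg then x : ys else ys))
```

    Both versions work in any monad, so filterMDo (\x -> Just (even x)) [1 .. 6] gives Just [2,4,6], while with the list predicate from the question both produce the same eight lists.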

    Next is the actual implementation of (>>=), which is a member of the Monad type class and has a different implementation for each instance. For [], the type is:

    (>>=) :: [a] -> (a -> [b]) -> [b]
    

    and the implementation is something like

    [] >>= f = []
    (x:xs) >>= f = f x ++ (xs >>= f)
    

    So the loop happens in the body of (>>=). All of this is ordinary library code; there is no compiler magic beyond the desugaring of do notation.
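    To make the two equations concrete, here is a stand-alone copy of the list instance's (>>=). The name bindList is hypothetical; for lists it should agree with the real (>>=):

```haskell
-- A hand-written copy of (>>=) for lists, following the two equations above.
bindList :: [a] -> (a -> [b]) -> [b]
bindList []     _ = []
bindList (x:xs) f = f x ++ bindList xs f   -- run f on each element and append
```

    For example, bindList [1, 2, 3] (\n -> replicate n n) gives [1,2,2,3,3,3], the same as [1, 2, 3] >>= \n -> replicate n n.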

    An equivalent definition for (>>=) is

    xs >>= f = concat (map f xs)
    

    which may also help you see what's happening.
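    The concat/map form can be used to replay the last step of the question's example by hand. This is a sketch: bindConcatMap, rest and lastStep are hypothetical names, and rest is simply typed in as the result of the recursive call filterM p [2, 3], with p = \x -> [True, False]:

```haskell
-- (>>=) for lists, written as concat (map f xs).
bindConcatMap :: [a] -> (a -> [b]) -> [b]
bindConcatMap xs f = concat (map f xs)

-- Typed in by hand: the results of the recursive call filterM p [2, 3].
rest :: [[Int]]
rest = [[2, 3], [2], [3], []]

-- The final step for x = 1: for each flg, either put 1 in front of
-- every ys in rest or leave ys unchanged. ([v] is return v for lists.)
lastStep :: [[Int]]
lastStep =
  bindConcatMap [True, False] (\flg ->
    bindConcatMap rest (\ys -> [if flg then 1 : ys else ys]))
```

    lastStep evaluates to [[1,2,3],[1,2],[1,3],[1],[2,3],[2],[3],[]]: the True branch contributes the first four lists and the False branch the last four.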

    The same thing then happens for the recursive call to filterM: for each value of flg, the recursive call is made and produces a list of results, and the final return statement is executed for each element ys in that list of results.
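    The recursion can be observed from the inside out by calling the library filterM (from Control.Monad) on each suffix of [1, 2, 3]. suffixResults is a hypothetical name for this sketch:

```haskell
import Control.Monad (filterM)

-- Results of filterM on each suffix, from the base case outwards.
-- Every element doubles the count, because flg takes two values and the
-- whole rest of the computation runs once for each.
suffixResults :: [([Int], [[Int]])]
suffixResults =
  [ (xs, filterM (\_ -> [True, False]) xs) | xs <- [[], [3], [2, 3], [1, 2, 3]] ]
```

    The counts come out as 1, 2, 4 and 8; for instance, the entry for [3] is [[3],[]] and the one for [2, 3] is [[2,3],[2],[3],[]].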

    This "fan-out" at each recursive call doubles the number of results, so filterM (\x -> [True, False]) [1, 2, 3] returns 2^3 = 8 lists, one for each subset of [1, 2, 3].
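    Seen this way, a predicate that always answers both True and False makes filterM enumerate every sublist, i.e. the power set. A minimal sketch (powerset is a hypothetical name):

```haskell
import Control.Monad (filterM)

-- Every element is both kept and dropped, so all 2^n sublists are produced,
-- with "keep" branches listed before "drop" branches at each position.
powerset :: [a] -> [[a]]
powerset = filterM (const [True, False])
```

    powerset [1, 2, 3] gives [[1,2,3],[1,2],[1,3],[1],[2,3],[2],[3],[]], and the power set of a 10-element list has 1024 elements.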
