State Monad, sequences of random numbers and monadic code

Submitted by 孤者浪人 on 2019-11-30 01:55:19

The State monad does look kind of confusing at first; let's do as Norman Ramsey suggested, and walk through how to implement it from scratch. Warning, this is pretty lengthy!

First, State has two type parameters: the type of the contained state data and the type of the final result of the computation. We'll use stateData and result respectively as type variables for them here. This makes sense if you think about it; the defining characteristic of a State-based computation is that it modifies a state while producing an output.

Less obvious is that the data constructor takes a function from a state to a result and a modified state, like so:

newtype State stateData result = State (stateData -> (result, stateData))

So while the monad is called "State", the actual value wrapped by the monad is that of a State-based computation, not the actual value of the contained state.

Keeping that in mind, we shouldn't be surprised to find that the function runState used to execute a computation in the State monad is actually nothing more than an accessor for the wrapped function itself, and could be defined like this:

runState (State f) = f
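To make that concrete, here is a minimal sketch that defines the wrapper by hand and runs one tiny computation through it. The name tick is made up for illustration and is not part of any library.

```haskell
-- A hand-rolled State wrapper, matching the definition given in the text.
newtype State stateData result = State (stateData -> (result, stateData))

-- runState only unwraps the function so it can be applied to an initial state.
runState :: State stateData result -> stateData -> (result, stateData)
runState (State f) = f

-- Hypothetical example: return the current counter, then increment it.
tick :: State Int Int
tick = State (\n -> (n, n + 1))
```

In GHCi, runState tick 5 evaluates to (5,6): the result is the old counter and the new state is the incremented one.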

So what does it mean when you define a function that returns a State value? Let's ignore for a moment the fact that State is a monad, and just look at the underlying types. First, consider this function (which doesn't actually do anything with the state):

len2State :: String -> State Int Bool
len2State s = return ((length s) == 2)

If you look at the definition of State, we can see that here the stateData type is Int, and the result type is Bool, so the function wrapped by the data constructor must have the type Int -> (Bool, Int). Now, imagine a State-less version of len2State--obviously, it would have type String -> Bool. So how would you go about converting such a function into one returning a value that fits into a State wrapper?

Well, obviously, the converted function will need to take a second parameter, an Int representing the state value. It also needs to return a state value, another Int. Since we're not actually doing anything with the state in this function, let's just do the obvious thing--pass that int right on through. Here's a State-shaped function, defined in terms of the State-less version:

len2 :: String -> Bool
len2 s = ((length s) == 2)

len2State :: String -> (Int -> (Bool, Int))
len2State s i = (len2 s, i)

But that's kind of silly and redundant. Let's generalize the conversion so that we can pass in the result value, and turn anything into a State-like function.

convert :: Bool -> (Int -> (Bool, Int))
convert r d = (r, d)

len2 s = ((length s) == 2)

len2State :: String -> (Int -> (Bool, Int))
len2State s = convert (len2 s)

What if we want a function that changes the state? Obviously we can't build one with convert, since we wrote that to pass the state through. Let's keep it simple, and write a function to overwrite the state with a new value. What kind of type would it need? It'll need an Int for the new state value, and of course will have to return a function stateData -> (result, stateData), because that's what our State wrapper needs. Overwriting the state value doesn't really have a sensible result value outside the State computation, so our result here will just be (), the zero-element tuple that represents "no value" in Haskell.

overwriteState :: Int -> (Int -> ((), Int))
overwriteState newState _ = ((), newState)

That was easy! Now, let's actually do something with that state data. Let's rewrite len2State from above into something more sensible: we'll compare the string length to the current state value.

lenState :: String -> (Int -> (Bool, Int))
lenState s i = ((length s) == i, i)

Can we generalize this into a converter and a State-less function, like we did before? Not quite as easily. Our len function will need to take the state as an argument, but we don't want it to "know about" state. Awkward, indeed. However, we can write a quick helper function that handles everything for us: we'll give it a function that needs to use the state value, and it'll pass the value in and then package everything back up into a State-shaped function leaving len none the wiser.

useState :: (Int -> Bool) -> Int -> (Bool, Int)
useState f d = (f d, d)

len :: String -> Int -> Bool
len s i = (length s) == i

lenState :: String -> (Int -> (Bool, Int))
lenState s = useState (len s)

Now, the tricky part--what if we want to string these functions together? Let's say we want to use lenState on a string, then double the state value if the result is false, then check the string again, and finally return true if either check did. We have all the parts we need for this task, but writing it all out would be a pain. Can we make a function that automatically chains together two functions that each return State-like functions? Sure thing! We just need to make sure it takes as arguments two things: the State function returned by the first function, and a function that takes the prior function's result type as an argument. Let's see how it turns out:

chainStates :: (Int -> (result1, Int)) -> (result1 -> (Int -> (result2, Int))) -> (Int -> (result2, Int))
chainStates prev f d = let (r, d') = prev d
                       in f r d'

All this is doing is applying the first state function to some state data, then applying the second function to the result and the modified state data. Simple, right?
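As a small illustration of the chaining just described, the sketch below repeats the relevant definitions so it compiles on its own, then chains two steps. checkThenReset is a hypothetical name invented for this example.

```haskell
-- Definitions repeated from the text so the block is self-contained.
chainStates :: (Int -> (r1, Int)) -> (r1 -> (Int -> (r2, Int))) -> (Int -> (r2, Int))
chainStates prev f d =
  let (r, d') = prev d   -- run the first step on the incoming state
  in  f r d'             -- feed its result and its new state to the second step

lenState :: String -> (Int -> (Bool, Int))
lenState s i = (length s == i, i)

overwriteState :: Int -> (Int -> ((), Int))
overwriteState newState _ = ((), newState)

-- Hypothetical example: check the length against the state, then reset
-- the state to 0 only when the check failed.
checkThenReset :: String -> (Int -> ((), Int))
checkThenReset s =
  chainStates (lenState s) $ \matched ->
    if matched
      then \d -> ((), d)       -- leave the state alone
      else overwriteState 0    -- replace the state
```

With these definitions, checkThenReset "abc" 3 gives ((),3) and checkThenReset "abc" 5 gives ((),0).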

Now, the interesting part: Between chainStates and convert, we should almost be able to turn any combination of State-less functions into a State-enabled function! The only thing we need now is a replacement for useState that returns the state data as its result, so that chainStates can pass it along to the functions that don't know anything about the trick we're pulling on them. Also, we'll use lambdas to accept the result from the previous functions and give them temporary names. Okay, let's make this happen:

extractState :: Int -> (Int, Int)
extractState d = (d, d)

chained :: String -> (Int -> (Bool, Int))
chained str = chainStates  extractState         $ \state1 ->
              let check1 = (len str state1) in
              chainStates (overwriteState (
                  if check1 
                  then state1 
                  else state1 * 2))             $ \ _ ->
              chainStates  extractState         $ \state2 ->
              let check2 = (len str state2) in
              convert (check1 || check2)

And try it out:

> chained "abcd" 2
(True, 4)
> chained "abcd" 3
(False, 6)
> chained "abcd" 4
(True, 4)
> chained "abcdef" 5
(False, 10)

Of course, we can't forget that State is actually a monad that wraps the State-like functions and keeps us away from them, so none of our nifty functions that we've built will help us with the real thing. Or will they? In a shocking twist, it turns out that the real State monad provides all the same functions, under different names:

runState (State s) = s
return r = State (convert r)
(>>=) s f = State (\d -> let (r, d') = (runState s) d in
                         runState (f r) d')
get = State extractState
put d = State (overwriteState d)

Note that >>= is almost identical to chainStates, but there was no good way to define it using chainStates. So, to wrap things up, we can rewrite the final example using the real State:

chained str = get                               >>= \state1 ->
              let check1 = (len str state1) in
              put (if check1 
                  then state1 else state1 * 2)  >>= \ _ ->
              get                               >>= \state2 ->
              let check2 = (len str state2) in
              return (check1 || check2)

Or, all candied up with the equivalent do notation:

chained str = do
        state1 <- get
        let check1 = len str state1
        _ <- put (if check1 then state1 else state1 * 2)
        state2 <- get
        let check2 = (len str state2)
        return (check1 || check2)
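The snippets above leave out the instance declarations. Since GHC 7.10 a Monad instance also needs Functor and Applicative instances, so here is a self-contained sketch that compiles without any extra packages. It defines its own State rather than importing a library one, and the instance bodies are one straightforward way to write them, not necessarily how any particular library does it.

```haskell
newtype State s a = State (s -> (a, s))

runState :: State s a -> s -> (a, s)
runState (State f) = f

instance Functor (State s) where
  -- Apply g to the result, leaving the state threading untouched.
  fmap g (State f) = State $ \d -> let (r, d') = f d in (g r, d')

instance Applicative (State s) where
  pure r = State $ \d -> (r, d)            -- same as convert in the text
  State ff <*> State fx = State $ \d ->
    let (g, d')  = ff d
        (x, d'') = fx d'
    in  (g x, d'')

instance Monad (State s) where
  -- Same shape as chainStates, plus wrapping and unwrapping.
  State f >>= k = State $ \d ->
    let (r, d') = f d
    in  runState (k r) d'

get :: State s s
get = State $ \d -> (d, d)                 -- extractState

put :: s -> State s ()
put d = State $ \_ -> ((), d)              -- overwriteState

len :: String -> Int -> Bool
len s i = length s == i

chained :: String -> State Int Bool
chained str = do
  state1 <- get
  let check1 = len str state1
  put (if check1 then state1 else state1 * 2)
  state2 <- get
  let check2 = len str state2
  return (check1 || check2)
```

Running runState (chained "abcd") 2 gives (True,4), matching the hand-built version earlier.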

First of all, your example is overly complicated because it doesn't need to store the val in the state monad; only the seed is the persistent state. Second, I think you will have better luck if instead of using the standard state monad, you re-implement all of the state monad and its operations yourself, with their types. I think you will learn more this way. Here are a couple of declarations to get you started:

data MyState s a = MyState (s -> (s, a))

get :: MyState s s
put :: s -> MyState s ()

Then you can write your own connectives:

unit :: a -> MyState s a
bind :: MyState s a -> (a -> MyState s b) -> MyState s b

Finally:

data Seed = Seed Int
nextVal :: MyState Seed Bool
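If you get stuck on the exercise, here is one possible way to fill in those declarations. Treat it as a sketch: runMyState is a helper I'm adding for convenience, and updateSeed is a hypothetical stand-in with arbitrary constants, since the original update function isn't shown here.

```haskell
data MyState s a = MyState (s -> (s, a))

-- Hypothetical helper: unwrap the function so it can be applied.
runMyState :: MyState s a -> s -> (s, a)
runMyState (MyState f) = f

get :: MyState s s
get = MyState (\s -> (s, s))

put :: s -> MyState s ()
put s = MyState (\_ -> (s, ()))

unit :: a -> MyState s a
unit a = MyState (\s -> (s, a))

bind :: MyState s a -> (a -> MyState s b) -> MyState s b
bind m k = MyState $ \s ->
  let (s', a) = runMyState m s   -- run the first computation
  in  runMyState (k a) s'        -- pass its result and state to the next one

data Seed = Seed Int

-- Hypothetical seed update; the constants are arbitrary, for illustration only.
updateSeed :: Seed -> Seed
updateSeed (Seed n) = Seed ((n * 7 + 3) `mod` 101)

-- Only the seed lives in the state; the Bool is just the result.
nextVal :: MyState Seed Bool
nextVal =
  get `bind` \seed ->
  let seed'@(Seed n) = updateSeed seed
  in  put seed' `bind` \_ -> unit (even n)
```

Starting from Seed 1, runMyState nextVal yields the new seed 10 and the value True; running it again from that seed yields 73 and False.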

As for your trouble desugaring, the do notation you are using is pretty sophisticated. But desugaring is a line-at-a-time mechanical procedure. As near as I can make out, your code should desugar like this (going back to your original types and code, which I disagree with):

 nextVal = get >>= \ (Random seed val) ->
                      let seed' = updateSeed seed
                          val'  = even seed'
                      in  put (Random seed' val') >>= \ _ -> return val'

In order to make the nesting structure a bit clearer, I've taken major liberties with the indentation.

You've got a couple of great responses. When working with the State monad, I mentally replace State s a with s -> (s,a) (after all, that's really what it is).

You then get a type for bind that looks like:

(>>=) :: (s -> (s,a)) ->
         (a -> s -> (s,b)) ->
         (s -> (s,b))

and you see that bind is just a specialized kind of function composition operator, like (.)
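To see the resemblance to (.), here is a rough sketch on plain functions, with no wrapper at all. The alias St and the names bindSt, composeSt, addCount and double are all made up for this example.

```haskell
-- State s a viewed as a plain function, state first as in the type above.
type St s a = s -> (s, a)

bindSt :: St s a -> (a -> St s b) -> St s b
bindSt m k = \s -> let (s', a) = m s in k a s'

-- Composition of state-passing functions built from bindSt; compare with
-- (.) :: (b -> c) -> (a -> b) -> (a -> c).
composeSt :: (b -> St s c) -> (a -> St s b) -> (a -> St s c)
composeSt g f = \a -> f a `bindSt` g

-- Hypothetical steps: each one bumps a counter held in the state.
addCount :: Int -> St Int Int
addCount x = \n -> (n + 1, x + n)

double :: Int -> St Int Int
double x = \n -> (n + 1, x * 2)
```

For example, (double `composeSt` addCount) 10 0 evaluates to (2,20): addCount runs first with counter 0, then double sees its result with counter 1.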

I wrote a blog/tutorial on the state monad here. It's probably not particularly good, but helped me grok things a little better by writing it.
