I have been trying to understand the State Monad. Not so much how it is used, though that is not always easy to find, either. But every discussion I find of the State Monad ha
I have accepted @AlexyRaga's answer to my question. I think @Filippo's answer was very good as well and, in fact, gave me some additional food for thought. Thanks to both of you.
I think the conceptual difficulty I was having was mostly to do with 'what does the `run` method mean'. That is, what is its purpose, and what does it return? I was looking at it as a 'transition' function (from one state to the next). And, after a fashion, that is what it does. However, it doesn't transition from a given (`this`) state to the next state. Instead, it takes an initial state and returns the (`this`) state's value and a new 'current' state (not the next state in the state-transition sequence).
That is why the `flatMap` method is implemented the way it is. When you generate a new `State`, you need the current value/state pair from it, based on the passed-in initial state. That pair can then be wrapped, as a function, in a new `State` object. You are not really transitioning to a new state, just re-wrapping the generated state in a new `State` object.
I was too steeped in traditional state machines to see what was going on here.
Thanks again, everyone.
The state monad boils down to this function from one `state` to another `state` (plus `A`):
```scala
type StatefulComputation[S, +A] = S => (A, S)
```
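A minimal sketch of what such a function looks like on its own, using a hypothetical counter `tick` (not from either answer): the value `A` is the current count and the new state `S` is the count plus one. Chaining two calls by hand means threading the state through explicitly, which is the plumbing that `State` and its `flatMap` are meant to hide.

```scala
// The alias from the answer: a stateful computation is just a function.
type StatefulComputation[S, +A] = S => (A, S)

// Hypothetical example: return the current counter value and increment the state.
val tick: StatefulComputation[Int, Int] = n => (n, n + 1)

// Without a monad, each step's output state must be passed to the next step by hand.
def twoTicks(start: Int): ((Int, Int), Int) = {
  val (first, s1)  = tick(start)
  val (second, s2) = tick(s1)
  ((first, second), s2)
}
```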
The implementation mentioned by Tony in that blog post "captures" that function in the `run` field of the `case class`:
```scala
case class State[S, A](run: S => (A, S))
```
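A minimal sketch of using that case class directly, again with a hypothetical `tick` counter: wrapping the function adds no behaviour by itself, and calling `run` with an initial state simply applies the captured function, returning the value together with the new state.

```scala
// The case class from the answer: it only stores the function `S => (A, S)`.
case class State[S, A](run: S => (A, S))

// Hypothetical counter: the value is the current count, the new state is count + 1.
val tick: State[Int, Int] = State(n => (n, n + 1))

// `run` needs an initial state; it returns the value and the state after the step.
val result = tick.run(10)
```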
The `flatMap` implementation that binds one `state` to another `state` calls two different `run`s:
```scala
// the `run` on the actual `state`
val (a: A, nextState: S) = run(s)
// the `run` on the bound `state`
f(a).run(nextState)
```
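For context, here is a minimal sketch of a complete `flatMap` built from those two `run`s, plus the `map` that a for-comprehension also needs. It restates the case class so the block compiles on its own; it is not necessarily how the blog post writes it, and the `tick` counter is hypothetical.

```scala
case class State[S, A](run: S => (A, S)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State { s =>
      // first run: this state's function, applied to the incoming state `s`
      val (a, nextState) = run(s)
      // second run: the State produced by `f`, applied to the updated state
      f(a).run(nextState)
    }

  // map only transforms the value: one run, and the state is passed through
  def map[B](f: A => B): State[S, B] =
    State { s =>
      val (a, nextState) = run(s)
      (f(a), nextState)
    }
}

// Hypothetical counter used only to exercise flatMap.
val tick: State[Int, Int] = State(n => (n, n + 1))

val pair: State[Int, (Int, Int)] = tick.flatMap(x => tick.map(y => (x, y)))
```

Running `pair.run(0)` gives `((0, 1), 2)`: the first `tick` sees `0`, and the second sees the `1` left behind by the first run.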
EDIT: Example of `flatMap` between two `State`s
Consider a function that simply calls `.head` on a `List` to get `A`, and `.tail` for the next state `S`:
```scala
// stateful computation: `S => (A, S)` where `S` is `List[A]`
def head[A](xs: List[A]): (A, List[A]) = (xs.head, xs.tail)
```
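Before any `State` is involved, `head` can be called by hand; here is a quick sketch (the names `step1` and `step2` are illustrative). As written, `head` throws on an empty list, because `List.head` does.

```scala
// stateful computation: `S => (A, S)` where `S` is `List[A]`
def head[A](xs: List[A]): (A, List[A]) = (xs.head, xs.tail)

val step1 = head(List(1, 2, 3)) // value 1, remaining state List(2, 3)
val step2 = head(step1._2)      // value 2, remaining state List(3)
```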
A simple binding of two `State(head[Int])` computations:
```scala
// flatMap example
val result = for {
  a <- State(head[Int])
  b <- State(head[Int])
} yield Map('a' -> a, 'b' -> b)
```
The expected behaviour of the for-comprehension is to "extract" the first element of a list into `a` and the second one into `b`. The resulting state `S` would be the remaining tail of the list passed to `run`:
```scala
scala> result.run(List(1, 2, 3, 4, 5))
(Map(a -> 1, b -> 2),List(3, 4, 5))
```
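The snippets so far can be assembled into one self-contained sketch that reproduces the REPL output above. The `flatMap` body follows the two-`run` description; `map` is an assumption written in the same style, since the answer does not show it.

```scala
case class State[S, A](run: S => (A, S)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State { s =>
      val (a, t) = run(s) // first run: this computation
      f(a).run(t)         // second run: the computation chosen by `f`
    }

  def map[B](f: A => B): State[S, B] =
    State { s =>
      val (a, t) = run(s)
      (f(a), t)           // transform the value, keep the state
    }
}

// stateful computation: `S => (A, S)` where `S` is `List[A]`
def head[A](xs: List[A]): (A, List[A]) = (xs.head, xs.tail)

val result = for {
  a <- State(head[Int])
  b <- State(head[Int])
} yield Map('a' -> a, 'b' -> b)

println(result.run(List(1, 2, 3, 4, 5))) // prints (Map(a -> 1, b -> 2),List(3, 4, 5))
```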
How? By calling the "stateful computation" `head[Int]` that is in `run` on some state `s`:
```scala
s => run(s)
```
That gives the head (`A`) and the tail (`S`) of the list. Now we need to pass the tail `t` to the next `State(head[Int])`:
```scala
f(a).run(t)
```
Here `f` is the function from the `flatMap` signature:
```scala
def flatMap[B](f: A => State[S, B]): State[S, B]
```
To better understand what `f` is in this example, we can de-sugar the for-comprehension to:
```scala
val result = State(head[Int]).flatMap {
  a => State(head[Int]).map {
    b => Map('a' -> a, 'b' -> b)
  }
}
```
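Inlining the `flatMap` and `map` calls of the de-sugared version by hand (a sketch for illustration, not code from either answer) shows where `a` and the tail `t` end up; `u` is a hypothetical name for the final state.

```scala
// stateful computation from the answer
def head[A](xs: List[A]): (A, List[A]) = (xs.head, xs.tail)

// Hand-inlined equivalent of:
//   State(head[Int]).flatMap(a => State(head[Int]).map(b => Map('a' -> a, 'b' -> b)))
val inlined: List[Int] => (Map[Char, Int], List[Int]) = s => {
  val (a, t) = head[Int](s)    // outer State's run on the initial state
  val (b, u) = head[Int](t)    // f(a).run(t): the inner State runs on the tail `t`
  (Map('a' -> a, 'b' -> b), u) // map applies the yield body and passes `u` through
}
```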
With `f(a)` we pass `a` into the function, and with `run(t)` we pass the modified state.
To understand the "second run", let's analyse it "backwards".
The signature `def flatMap[B](f: A => State[S, B]): State[S, B]` suggests that we need to run a function `f` and return its result.
To execute the function `f`, we need to give it an `A`. Where do we get one? Well, we have `run`, which can give us an `A` out of an `S`, so we need an `S`.
Because of that, we do: `s => val (a, t) = run(s) ...`. We read it as "given an `S`, execute the `run` function, which produces an `A` and a new `S`". And this is our "first" run.
Now we have an `A` and we can execute `f`. That's what we wanted, and `f(a)` gives us a new `State[S, B]`.
If we do that, then we have a function which takes an `S` and returns a `State[S, B]`:
```scala
(s: S) =>
  val (a, t) = run(s)
  f(a) // State[S, B]
```
But a function `S => State[S, B]` isn't what we want to return! We want to return just a `State[S, B]`.
How do we do that? We can wrap this function in `State`:
```scala
State(s => ... f(a))
```
But that doesn't work, because `State` takes an `S => (B, S)`, not an `S => State[S, B]`.
So we need to get a `(B, S)` out of the `State[S, B]`.
We do it by calling its `run` method and providing it with the state we just produced in the previous step! And this is our "second" run.
So, as a result, we have the following transformation performed by `flatMap`:
```scala
s => // when a state is provided
  val (a, t) = run(s) // produce an `A` and a new state value
  val resState = f(a) // produce a new `State[S, B]`
  resState.run(t)     // return `(B, S)`
```
This gives us an `S => (B, S)`, and we just wrap it with the `State` constructor.
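The backwards derivation can be written out as a standalone function. `bind` is a hypothetical name for what `flatMap` does, and the type ascriptions mirror each step: the unwrapped function has type `S => (B, S)`, and only at the end is it wrapped in `State`.

```scala
case class State[S, A](run: S => (A, S))

// Hypothetical standalone `bind`, mirroring the derivation step by step.
def bind[S, A, B](sa: State[S, A])(f: A => State[S, B]): State[S, B] = {
  val step: S => (B, S) = s => {
    val (a, t) = sa.run(s)           // first run: produce an `A` and a new state
    val resState: State[S, B] = f(a) // a new `State[S, B]`
    resState.run(t)                  // second run: returns `(B, S)`
  }
  State(step)                        // wrap `S => (B, S)` with the constructor
}

def head[A](xs: List[A]): (A, List[A]) = (xs.head, xs.tail)

// Hypothetical usage: take the first two elements as a pair.
val firstTwo: State[List[Int], (Int, Int)] =
  bind(State(head[Int])) { a =>
    bind(State(head[Int])) { b =>
      State[List[Int], (Int, Int)](s => ((a, b), s))
    }
  }
```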
Another way of looking at these "two runs" is:
- first, we transform the state ourselves with "our" `run` function;
- second, we pass that transformed state to the function `f` and let it do its own transformation.
So we are, in effect, "chaining" state transformations one after another. And that's exactly what monads do: they let us schedule computations sequentially.
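To make the "sequential scheduling" point concrete, here is a sketch with a hypothetical helper `replicateM` (the name is borrowed from Haskell's `Control.Monad`, but this implementation is only an illustration). It chains the same step `n` times with `flatMap`, so each step receives the state left by the previous one.

```scala
case class State[S, A](run: S => (A, S)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State { s =>
      val (a, t) = run(s)
      f(a).run(t)
    }

  def map[B](f: A => B): State[S, B] =
    State { s =>
      val (a, t) = run(s)
      (f(a), t)
    }
}

// Hypothetical helper: run `step` n times in sequence, collecting the values.
def replicateM[S, A](n: Int, step: State[S, A]): State[S, List[A]] =
  if (n <= 0) State[S, List[A]](s => (Nil, s)) // no steps: empty result, state unchanged
  else step.flatMap(a => replicateM(n - 1, step).map(as => a :: as))

def head[A](xs: List[A]): (A, List[A]) = (xs.head, xs.tail)

val takeThree = replicateM(3, State(head[Int]))
```

`takeThree.run(List(1, 2, 3, 4, 5))` yields `(List(1, 2, 3), List(4, 5))`. This simple recursion is not stack-safe for very large `n`.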