I have been trying to understand the State Monad. Not so much how it is used, though that is not always easy to find, either. But every discussion I find of the State Monad ha
The state monad boils down to this function from one state to another state (plus A):
type StatefulComputation[S, +A] = S => (A, S)
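To make that signature concrete, here is a minimal sketch of a stateful computation written as a plain function. It uses a hypothetical counter where the state S is an Int. Threading the state by hand shows the plumbing that the State monad later hides.

```scala
// A stateful computation: S => (A, S), as in the alias above
type StatefulComputation[S, +A] = S => (A, S)

// Hypothetical example: return the current counter as a String
// and produce the incremented counter as the next state.
val next: StatefulComputation[Int, String] =
  n => (s"value $n", n + 1)

// Threading the state manually: each call receives the state
// returned by the previous one.
val (first, s1) = next(0)
val (second, s2) = next(s1)
println((first, second, s2)) // prints (value 0,value 1,2)
```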
The implementation mentioned by Tony in that blog post "captures" that function in the run field of the case class:
case class State[S, A](run: S => (A, S))
The flatMap implementation, which binds one state to another, calls two different runs:
// the `run` on the actual `state`
val (a: A, nextState: S) = run(s)
// the `run` on the bound `state`
f(a).run(nextState)
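Putting the two runs together, the full case class could look like the sketch below. The map method is an assumption, not shown in the snippet above; it is included because a for-comprehension desugars its last generator into map, so the examples that follow need it. The doubling counter in the usage lines is a hypothetical example.

```scala
// A sketch of State with flatMap and map.
// map is added here as an assumption, since a for-comprehension's
// final generator desugars to map.
case class State[S, A](run: S => (A, S)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] = State { (s: S) =>
    // the `run` on the actual `state`
    val (a, nextState) = run(s)
    // the `run` on the bound `state`, fed with the state produced above
    f(a).run(nextState)
  }

  def map[B](f: A => B): State[S, B] = State { (s: S) =>
    // run this state, transform only the value, keep the new state
    val (a, nextState) = run(s)
    (f(a), nextState)
  }
}

// Hypothetical usage: return the current counter and double it as the next state
val double = State[Int, Int](n => (n, n * 2))
val twice = double.flatMap(x => double.map(y => x + y))
println(twice.run(3)) // prints (9,12)
```

Running twice on 3 first yields (3, 6); the bound State then runs on 6 and yields (6, 12), so the final value is 3 + 6 and the final state is 12.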
EDIT: Example of flatMap between two States
Consider a function that simply calls .head on a List to get A, and .tail to get the next state S:
// stateful computation: `S => (A, S)` where `S` is `List[A]`
def head[A](xs: List[A]): (A, List[A]) = (xs.head, xs.tail)
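Called on its own, head already has the S => (A, S) shape. A quick sketch of a single step, where the list literal is just sample data:

```scala
// stateful computation: `S => (A, S)` where `S` is `List[A]`
def head[A](xs: List[A]): (A, List[A]) = (xs.head, xs.tail)

// One step: the first element is the value A, the tail is the next state S.
val (value, nextState) = head(List(1, 2, 3, 4, 5))
println(value)     // prints 1
println(nextState) // prints List(2, 3, 4, 5)
```

Because List.head throws a NoSuchElementException on an empty list, this sketch assumes the state always has at least one element left.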
A simple binding of two State(head[Int]):
// flatMap example
val result = for {
  a <- State(head[Int])
  b <- State(head[Int])
} yield Map('a' -> a, 'b' -> b)
The expected behaviour of the for-comprehension is to "extract" the first element of the list into a and the second one into b. The resulting state S is the remaining tail of the list passed to run:
scala> result.run(List(1, 2, 3, 4, 5))
(Map(a -> 1, b -> 2),List(3, 4, 5))
How? By calling the "stateful computation" head[Int] stored in run on some state s:
s => run(s)
That gives the head (A) and the tail of the list, which is the next state S (here a List[Int]), not a B. Now we need to pass that tail t to the next State(head[Int]):
f(a).run(t)
where f is the function taken by the flatMap signature:
def flatMap[B](f: A => State[S, B]): State[S, B]
To better understand what f is in this example, it may help to desugar the for-comprehension:
val result = State(head[Int]).flatMap { a =>
  State(head[Int]).map { b =>
    Map('a' -> a, 'b' -> b)
  }
}
With f(a) we pass a into the function, and with run(t) we pass the modified state.
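Substituting the flatMap and map bodies by hand removes the State wrapper entirely and leaves one function that threads the list through both calls. A sketch of that expansion, written as a plain function; the names t and t2 for the intermediate states are illustrative only:

```scala
// stateful computation: `S => (A, S)` where `S` is `List[A]`
def head[A](xs: List[A]): (A, List[A]) = (xs.head, xs.tail)

// The for-comprehension with flatMap and map inlined by hand:
// each step runs head on the state left over by the previous step.
val result: List[Int] => (Map[Char, Int], List[Int]) = s => {
  val (a, t) = head(s)          // State(head[Int]).run(s)
  val (b, t2) = head(t)         // f(a).run(t): the inner State(head[Int])
  (Map('a' -> a, 'b' -> b), t2) // map applies the yield to b and keeps t2
}

println(result(List(1, 2, 3, 4, 5))) // prints (Map(a -> 1, b -> 2),List(3, 4, 5))
```

The output matches the REPL session above, which is a quick way to check that the desugaring and the two runs line up.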