Scala and State Monad

陌清茗 2021-02-20 13:36

I have been trying to understand the State Monad. Not so much how it is used, though that is not always easy to find, either. But every discussion I find of the State Monad ha

3 answers
  •  挽巷 (original poster)
    2021-02-20 13:58

    The state monad boils down to a function from one state to another state, plus a result of type A:

    type StatefulComputation[S, +A] = S => (A, S)
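    Before any monad is involved, computations of this shape can already be chained by hand. The sketch below uses a hypothetical counter, tick, whose state is an Int (tick and threeTicks are illustrative names, not from the answer). Every step has to pass the previous state along explicitly, which is the plumbing that State's flatMap is meant to hide:

```scala
// The shape from the answer: a function from a state S to a result A plus a new state.
type StatefulComputation[S, +A] = S => (A, S)

// Hypothetical counter: returns the current value and increments the state.
val tick: StatefulComputation[Int, Int] = n => (n, n + 1)

// Threading the state by hand: each call receives the state produced by the previous one.
def threeTicks(start: Int): ((Int, Int, Int), Int) = {
  val (a, s1) = tick(start)
  val (b, s2) = tick(s1)
  val (c, s3) = tick(s2)
  ((a, b, c), s3)
}

// threeTicks(10) == ((10, 11, 12), 13)
```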
    

    The implementation mentioned by Tony in that blog post "captures" that function in the run field of a case class:

    case class State[S, A](run: S => (A, S))
    

    The flatMap implementation, which binds one State to another, calls two different runs:

    def flatMap[B](f: A => State[S, B]): State[S, B] = State { s =>
      // the `run` on the actual `state`
      val (a, nextState) = run(s)

      // the `run` on the bound `state`
      f(a).run(nextState)
    }
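    Put into context, flatMap is a method of the State case class. A minimal self-contained sketch, assuming the case class from the answer plus a map method (which a for-comprehension ending in yield also needs); it follows the same two-run structure but is not necessarily identical to the code in the blog post. The tick counter and twoTicks are hypothetical:

```scala
// Minimal State sketch: wraps a function S => (A, S).
final case class State[S, A](run: S => (A, S)) {

  // Apply f to the result; the state passes through unchanged.
  def map[B](f: A => B): State[S, B] = State { s =>
    val (a, nextState) = run(s)
    (f(a), nextState)
  }

  // First run this computation, then run the State produced by f(a) on the new state.
  def flatMap[B](f: A => State[S, B]): State[S, B] = State { s =>
    val (a, nextState) = run(s) // the `run` on the actual state
    f(a).run(nextState)         // the `run` on the bound state
  }
}

// Hypothetical usage with an Int counter as the state.
val tick: State[Int, Int] = State[Int, Int](n => (n, n + 1))
val twoTicks: State[Int, (Int, Int)] = tick.flatMap(a => tick.map(b => (a, b)))
// twoTicks.run(0) == ((0, 1), 2)
```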
    

    EDIT: an example of flatMap between two States

    Consider a function that simply calls .head on a List to get A, and .tail to get the next state S:

    // stateful computation: `S => (A, S)` where `S` is `List[A]`
    def head[A](xs: List[A]): (A, List[A]) = (xs.head, xs.tail)
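    Called directly, head already has the S => (A, S) shape, and wrapping it in State does not change the result. A short sketch, assuming only the bare case class from the answer (no methods are needed here). Note that xs.head throws on an empty list, so the example sticks to non-empty lists; direct and wrapped are illustrative names:

```scala
final case class State[S, A](run: S => (A, S))

// Stateful computation: `S => (A, S)` where `S` is `List[A]`.
// Note: xs.head throws NoSuchElementException on an empty list.
def head[A](xs: List[A]): (A, List[A]) = (xs.head, xs.tail)

// Calling the function directly...
val direct = head(List(1, 2, 3))                  // (1, List(2, 3))

// ...gives the same pair as running it through the State wrapper.
val wrapped = State(head[Int]).run(List(1, 2, 3)) // (1, List(2, 3))
```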
    

    A simple binding of two State(head[Int]) values:

    // flatmap example
    val result = for {
      a <- State(head[Int])
      b <- State(head[Int])
    } yield Map('a' -> a,
                'b' -> b)
    

    The expected behaviour of the for-comprehension is to "extract" the first element of the list into a and the second into b. The resulting state S is the remaining tail of the list passed to run:

    scala> result.run(List(1, 2, 3, 4, 5))
    (Map(a -> 1, b -> 2),List(3, 4, 5))
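    For the for-comprehension to compile, State needs both flatMap and map, since a for-comprehension ending in yield desugars to flatMap followed by map. This self-contained sketch reproduces the output shown above, assuming a map implementation that the answer does not spell out (output is an illustrative name):

```scala
final case class State[S, A](run: S => (A, S)) {
  // Transform the result; the state is passed through as-is.
  def map[B](f: A => B): State[S, B] = State { s =>
    val (a, next) = run(s)
    (f(a), next)
  }
  // Run this step, then run f(a) on the state it leaves behind.
  def flatMap[B](f: A => State[S, B]): State[S, B] = State { s =>
    val (a, next) = run(s)
    f(a).run(next)
  }
}

// Stateful computation: pop the head of the list, keep the tail as the new state.
def head[A](xs: List[A]): (A, List[A]) = (xs.head, xs.tail)

// Each generator consumes one element; the remaining list flows to the next step.
val result: State[List[Int], Map[Char, Int]] = for {
  a <- State(head[Int])
  b <- State(head[Int])
} yield Map('a' -> a, 'b' -> b)

// Evaluates to (Map(a -> 1, b -> 2), List(3, 4, 5)).
val output = result.run(List(1, 2, 3, 4, 5))
```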
    

    How? First, the "stateful computation" head[Int] stored in run is called on some state s:

    s => run(s)
    

    That gives the head (A) and the tail (S) of the list. Now we need to pass the tail, t, to the next State(head[Int]):

    f(a).run(t)
    

    where f is the function in the flatMap signature:

    def flatMap[B](f: A => State[S, B]): State[S, B]
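    Unrolling the two steps by hand for the list in the example makes the hand-off of the state concrete. This sketch uses the plain head function only (no State wrapper). It illustrates what flatMap and map do for us, not code from the answer; s, t, rest and unrolled are just local labels:

```scala
def head[A](xs: List[A]): (A, List[A]) = (xs.head, xs.tail)

val s = List(1, 2, 3, 4, 5)

// First run: the stateful computation on the initial state s.
val (a, t) = head(s)    // a = 1, t = List(2, 3, 4, 5)

// Second run: what f(a).run(t) does here, i.e. the next head on the tail t...
val (b, rest) = head(t) // b = 2, rest = List(3, 4, 5)

// ...followed by the final map, which builds the result; `rest` is the state handed back.
val unrolled = (Map('a' -> a, 'b' -> b), rest)
```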
    

    To better see what f is in this example, it helps to desugar the for-comprehension into:

    val result = State(head[Int]).flatMap {
      a => State(head[Int]).map {
        b => Map('a' -> a, 'b' -> b)
      }
    }
    

    With f(a) we pass a into the function, and with run(t) we pass along the modified state.
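    To see f as an ordinary value, it can be pulled out of the desugared expression and given a name. A sketch that repeats State (with an assumed map) and head so it compiles on its own; the name f mirrors the flatMap signature, and step and viaFlatMap are illustrative:

```scala
final case class State[S, A](run: S => (A, S)) {
  // Transform the result; the state is passed through as-is.
  def map[B](g: A => B): State[S, B] = State { s =>
    val (a, next) = run(s)
    (g(a), next)
  }
  // Run this step, then run f(a) on the state it leaves behind.
  def flatMap[B](f: A => State[S, B]): State[S, B] = State { s =>
    val (a, next) = run(s)
    f(a).run(next)
  }
}

def head[A](xs: List[A]): (A, List[A]) = (xs.head, xs.tail)

// f receives the first element `a` and returns the next State to run.
val f: Int => State[List[Int], Map[Char, Int]] =
  a => State(head[Int]).map(b => Map('a' -> a, 'b' -> b))

// f(1) is a State that still needs a list; running it on the tail reads the second element.
val step = f(1).run(List(2, 3, 4, 5)) // (Map(a -> 1, b -> 2), List(3, 4, 5))

// flatMap supplies both pieces automatically: a from the first run, t as the new state.
val viaFlatMap = State(head[Int]).flatMap(f).run(List(1, 2, 3, 4, 5))
```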
