Monadic fold with State monad in constant space (heap and stack)?

小鲜肉 2021-02-20 04:35

Is it possible to perform a fold in the State monad in constant stack and heap space? Or is a different functional technique a better fit to my problem?


2 answers
  • 2021-02-20 05:13

    "Our real issue is the heap used by the unexecuted State mobits."

    No, it is not. The real issue is that the collection doesn't fit in memory and that foldLeftM and foldRightM force the entire collection. A side effect of the impure solution is that you are freeing memory as you go. In the "purely functional" solution, you're not doing that anywhere.

    Your use of Iterable ignores a crucial detail: what kind of collection col actually is, how its elements are created and how they are expected to be discarded. And so, necessarily, does foldLeftM on Iterable. It is likely too strict, and you are forcing the entire collection into memory. For example, if it is a Stream, then as long as you are holding on to col all the elements forced so far will be in memory. If it's some other kind of lazy Iterable that doesn't memoize its elements, then the fold is still too strict.

    I tried your first example with an EphemeralStream and did not see any significant heap pressure, even though it will clearly have the same "unexecuted State mobits". The difference is that an EphemeralStream's elements are weakly referenced and its foldRight doesn't force the entire stream.

    I suspect that if you used Foldable.foldr, then you would not see the problematic behaviour since it folds with a function that is lazy in its second argument. When you call the fold, you want it to return a suspension that looks something like this immediately:

    Suspend(() => head |+| tail.foldRightM(...))
    

    When the trampoline resumes the first suspension and runs up to the next suspension, all of the allocations between suspensions will become available to be freed by the garbage collector.
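
    As a rough illustration of that suspend-and-resume behaviour, here is a minimal sketch that uses the standard library's scala.util.control.TailCalls as a stand-in for Scalaz's Trampoline. The names TrampolineSketch and foldT are hypothetical and not part of the code in this answer. Each step returns a suspension immediately, and the `result` loop resumes them one at a time.

    ```scala
    import scala.util.control.TailCalls._

    object TrampolineSketch {
      // Hypothetical helper (an assumption for illustration): a left fold whose
      // recursive step is wrapped in `tailcall`, so every step is a suspension
      // that the `result` loop resumes instead of a nested stack frame.
      def foldT[A, B](a: A, bs: Iterator[B])(f: (A, B) => TailRec[A]): TailRec[A] =
        if (!bs.hasNext) done(a)
        else {
          val b = bs.next() // consumed here; the iterator keeps no earlier elements
          f(a, b).flatMap(a2 => tailcall(foldT(a2, bs)(f)))
        }

      // Sums one million integers without deep recursion.
      val total: Long =
        foldT(0L, Iterator.range(0, 1000000))((acc, x) => done(acc + x)).result
    }
    ```

    Because the source is an Iterator, nothing holds on to elements that have already been consumed, so whatever is allocated between two suspensions can be collected.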

    Try the following:

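    // Monadic left fold. The recursive call sits inside the continuation passed
    // to bind, so with a trampolined M it is deferred until the previous step has run.
    def foldM[M[_]:Monad,A,B](a: A, bs: Iterable[B])(f: (A, B) => M[A]): M[A] =
      if (bs.isEmpty) Monad[M].point(a)
      else Monad[M].bind(f(a, bs.head))(fax => foldM(fax, bs.tail)(f))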
    
    val MS = StateT.stateTMonadState[Int, Trampoline]
    import MS._
    
    foldM[M,R,Int](Monoid[R].zero, col) {
      (x, r) => modify(_ + 1) map (_ => Monoid[R].append(x, r))
    } run 0 run
    

    This will run in constant heap for a trampolined monad M, but will overflow the stack for a non-trampolined monad.

    But the real problem is that Iterable is not a good abstraction for data that are too large to fit in memory. Sure, you can write an imperative side-effecty program where you explicitly discard elements after each iteration or use a lazy right fold. That works well until you want to compose that program with another one. And I'm assuming that the whole reason you're investigating doing this in a State monad to begin with is to gain compositionality.

    So what can you do? Here are some options:

    1. Make use of Reducer, Monoid, and composition thereof, then run in an imperative explicitly-freeing loop (or a trampolined lazy right fold) as the last step, after which composition is not possible or expected.
    2. Use Iteratee composition and monadic Enumerators to feed them.
    3. Write compositional stream transducers with Scalaz-Stream.

    The last of these options is the one that I would use and recommend in the general case.

  • 2021-02-20 05:15

    Using State, or any similar monad, isn't a good approach to this problem: it is bound to blow the stack or heap on large collections. Consider a value x: State[A,B] constructed from a large collection (for example, by folding over it). Then x can be evaluated on different values of the initial state A, yielding different results, so x needs to retain all the information contained in the collection. And in a pure setting, x can't discard part of that information to avoid blowing the stack or heap, so everything that is computed remains in memory until the whole monadic value is freed, which happens only after the result is evaluated. The memory consumption of x is therefore proportional to the size of the collection.
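
    To make the "evaluated on different initial states" point concrete, here is a small sketch with a hand-rolled state monad. St, pure and StateRetentionSketch are hypothetical names introduced for illustration, not Scalaz's State. The same value x is run twice, which is only possible because it still refers to every element it was built from.

    ```scala
    object StateRetentionSketch {
      // Minimal hand-rolled state monad (an assumption, not scalaz.State).
      final case class St[S, A](run: S => (S, A)) {
        def flatMap[B](f: A => St[S, B]): St[S, B] =
          St((s: S) => { val (s1, a) = run(s); f(a).run(s1) })
      }

      def pure[S, A](a: A): St[S, A] = St((s: S) => (s, a))

      val xs = List(1, 2, 3)

      // Fold the collection into one state action: each step adds the element
      // to the state and records the new state in the result vector.
      val x: St[Int, Vector[Int]] =
        xs.foldLeft(pure[Int, Vector[Int]](Vector.empty)) { (m, e) =>
          m.flatMap(acc => St((s: Int) => (s + e, acc :+ (s + e))))
        }

      // The same value gives different answers for different initial states,
      // so it cannot have thrown away any element of xs.
      val fromZero = x.run(0)  // (6, Vector(1, 3, 6))
      val fromTen  = x.run(10) // (16, Vector(11, 13, 16))
    }
    ```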

    I believe a fitting approach to this problem is to use functional iteratees/pipes/conduits. This concept (referred to under these three names) was invented to process large collections of data with constant memory consumption, and to describe such processes using simple combinators.
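
    The core idea can be sketched in plain Scala. This is only an illustration of the concept under my own simplified types (Sink, Done, Await, feed and IterateeSketch are hypothetical names), not the API of Scalaz Iteratees or of any conduit library. A consumer is either finished or waiting for the next input, and a driver feeds it one element at a time.

    ```scala
    import scala.annotation.tailrec

    object IterateeSketch {
      // A consumer is either done with a result or awaiting more input.
      sealed trait Sink[I, O]
      final case class Done[I, O](result: O) extends Sink[I, O]
      final case class Await[I, O](onInput: I => Sink[I, O], onEnd: () => O)
          extends Sink[I, O]

      // Example consumer: sums its input, carrying the running total explicitly.
      def sum(acc: Long): Sink[Int, Long] =
        Await[Int, Long](x => sum(acc + x), () => acc)

      // Driver: pulls one element at a time from the source and hands it to
      // the consumer. Only the current consumer state is kept alive.
      @tailrec
      def feed[I, O](source: Iterator[I], sink: Sink[I, O]): O = sink match {
        case Done(r) => r
        case Await(onInput, onEnd) =>
          if (source.hasNext) feed(source, onInput(source.next())) else onEnd()
      }

      val total: Long = feed(Iterator.range(0, 1000000), sum(0L))
    }
    ```

    Unlike a State value, the consumer is advanced as soon as each element arrives, so the elements already seen are not needed again.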

    I tried to use Scalaz' Iteratees, but it seems this part isn't mature yet: it suffers from stack overflows just as State does (or perhaps I'm not using it right; the code is available here, if anybody is interested).

    However, it was simple using my (still a bit experimental) scala-conduit library (disclaimer: I'm the author):

    import conduit._
    import conduit.Pipe._
    
    object Run extends App {
      // Define a sampling function as a sink: It consumes
      // data of type `A` and produces a vector of samples.
      def sampleI[A](k: Int): Sink[A, Vector[A]] =
        sampleI[A](k, 0, Vector())
    
      // Create a sampling sink with a given state. It requests
      // a value from the upstream conduit. If there is one,
      // update the state and continue (the first argument to `requestF`).
      // If not, return the current sample (the second argument).
      // The `Finalizer` part isn't important for our problem.
      private def sampleI[A](k: Int, n: Int, sample: Vector[A]):
                      Sink[A, Vector[A]] =
        requestF((x: A) => sampleI(k, n + 1, algorithmR(k, n + 1, sample, x)),
                 (_: Any) => sample)(Finalizer.empty)
    
    
      // The sampling algorithm copied from the question.
      val rand = new scala.util.Random()
    
      def algorithmR[A](k: Int, n: Int, sample: Vector[A], x: A): Vector[A] = {
        if (sample.size < k) {
          sample :+ x // must keep first k elements
        } else {
          val r = rand.nextInt(n) + 1 // for simplicity, rand is global/stateful
          if (r <= k)
            sample.updated(r - 1, x) // sample is 0-indexed
          else
            sample
        }
      }
    
      // Construct an iterable of all `Short` values, pipe it into our sampling
      // function, and run the combined pipe.
      {
        print(runPipe(Util.fromIterable(Short.MinValue to Short.MaxValue) >->
              sampleI(10)))
      }
    }
    

    Update: It'd be possible to solve the problem using State, but we need to implement a custom fold specifically for State that knows how to do it in constant space:

    import scala.collection._
    import scala.language.higherKinds
    import scalaz._
    import Scalaz._
    import scalaz.std.iterable._
    
    object Run extends App {
      // Folds in a state monad over a foldable
      def stateFold[F[_],E,S,A](xs: F[E],
                                f: (A, E) => State[S,A],
                                z: A)(implicit F: Foldable[F]): State[S,A] =
        State[S,A]((s: S) => F.foldLeft[E,(S,A)](xs, (s, z))((p, x) => f(p._2, x)(p._1)))
    
    
      // Sample a lazy collection view
      def sampleS[F[_],A](k: Int, xs: F[A])(implicit F: Foldable[F]):
                      State[Int,Vector[A]] =
        stateFold[F,A,Int,Vector[A]](xs, update(k), Vector())
    
      // update using State monad
      def update[A](k: Int) = {
        (acc: Vector[A], x: A) => State[Int, Vector[A]] {
            n => (n + 1, algorithmR(k, n + 1, acc, x)) // algR same as impure solution
        }
      }
    
      def algorithmR[A](k: Int, n: Int, sample: Vector[A], x: A): Vector[A] = ...
    
      {
        print(sampleS(10, (Short.MinValue to Short.MaxValue)).eval(0))
      }
    }
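
    For readers without Scalaz at hand, the same idea can be sketched with a hand-rolled state type. St and StateFoldSketch are hypothetical names, the seeded Random and the 1 to 100000 input are my own choices for the example, and algorithmR is the function shown earlier in this answer. The fold threads the (state, accumulator) pair through a strict foldLeft, so no chain of pending State steps is built.

    ```scala
    import scala.util.Random

    object StateFoldSketch {
      // Hand-rolled stand-in for scalaz.State (an assumption for illustration).
      final case class St[S, A](run: S => (S, A))

      // Runs each step immediately inside foldLeft, carrying the current
      // state and accumulator as a pair instead of composing State values.
      def stateFold[E, S, A](xs: Iterable[E], z: A)(f: (A, E) => St[S, A]): St[S, A] =
        St((s: S) => xs.iterator.foldLeft((s, z)) { case ((s1, a), x) => f(a, x).run(s1) })

      val rand = new Random(42L) // global and stateful, as in the answer above

      // Reservoir sampling step, same logic as algorithmR earlier in this answer.
      def algorithmR[A](k: Int, n: Int, sample: Vector[A], x: A): Vector[A] =
        if (sample.size < k) sample :+ x // must keep the first k elements
        else {
          val r = rand.nextInt(n) + 1
          if (r <= k) sample.updated(r - 1, x) else sample
        }

      // The state is the number of elements seen so far.
      def sampleS[A](k: Int, xs: Iterable[A]): St[Int, Vector[A]] =
        stateFold[A, Int, Vector[A]](xs, Vector.empty[A]) { (acc, x) =>
          St((n: Int) => (n + 1, algorithmR(k, n + 1, acc, x)))
        }

      val (count, sample) = sampleS(10, 1 to 100000).run(0)
    }
    ```

    Only the current count and the k-element sample are live at any point in the fold.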
    