Partition a collection into “k” close-to-equal pieces (Scala, but language agnostic)

南方客 2021-02-12 15:44

Defined before this block of code:

  • dataset can be a Vector or List
  • numberOfSlices is an Int
6 Answers
  • 2021-02-12 15:58

    If the behavior of xs.grouped(xs.size / n) doesn't work for you, it's pretty easy to define exactly what you want. The quotient is the size of the smaller pieces, and the remainder is the number of the bigger pieces:

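    def cut[A](xs: Seq[A], n: Int): Iterator[Seq[A]] = {
      // grouped(0) throws, so every piece must hold at least one element
      require(n > 0 && n <= xs.size, "need 1 <= n <= xs.size")
      val (quot, rem) = (xs.size / n, xs.size % n)
      // the last rem pieces get quot + 1 elements; the first n - rem get quot
      val (smaller, bigger) = xs.splitAt(xs.size - rem * (quot + 1))
      smaller.grouped(quot) ++ bigger.grouped(quot + 1)
    }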
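
    To make the motivation concrete, here is a small sketch of what plain grouped gives for 10 elements and 4 slices, depending on whether the chunk size is rounded down or up. The object name GroupedPitfall and the sample data are made up for illustration.

```scala
object GroupedPitfall {
  val xs: Vector[Int] = (1 to 10).toVector
  val n = 4

  // Rounding the chunk size down (10 / 4 = 2) yields 5 pieces instead of 4.
  val roundedDown: List[Vector[Int]] = xs.grouped(xs.size / n).toList

  // Rounding up ((10 + 3) / 4 = 3) yields 4 pieces, but with sizes 3, 3, 3, 1.
  val roundedUp: List[Vector[Int]] = xs.grouped((xs.size + n - 1) / n).toList
}
```

    With cut as defined above, the same input would instead give two pieces of 2 followed by two pieces of 3.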
    
  • 2021-02-12 16:01

    I'd approach it this way: given n elements and m partitions (n > m), either n mod m == 0, in which case each partition has n/m elements, or n mod m == y with y > 0, in which case each partition starts with n/m elements (integer division) and the y leftover elements have to be spread over y of the m partitions, one each.

    You'll have y slots with n/m+1 elements and (m-y) slots with n/m. How you distribute them is your choice.
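
    A minimal sketch of that counting argument, assuming the y larger slots are simply placed first (one choice among many). The names SliceSizes, sizes and cutBySizes are illustrative and not from the question.

```scala
object SliceSizes {
  // Sizes of m partitions of n elements:
  // y = n % m partitions get n / m + 1 elements, the other m - y get n / m.
  def sizes(n: Int, m: Int): List[Int] = {
    require(m > 0, "m must be positive")
    val base = n / m
    val y = n % m
    List.fill(y)(base + 1) ++ List.fill(m - y)(base)
  }

  // Cut xs into consecutive pieces with the given lengths, preserving order.
  def cutBySizes[A](xs: Seq[A], lens: List[Int]): List[Seq[A]] =
    lens match {
      case Nil => Nil
      case len :: rest =>
        val (piece, remaining) = xs.splitAt(len)
        piece :: cutBySizes(remaining, rest)
    }
}
```

    For example, SliceSizes.sizes(39, 7) gives List(6, 6, 6, 6, 5, 5, 5), and passing those sizes to cutBySizes yields the actual pieces.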

  • 2021-02-12 16:04

    Here's a short function that does the job for me, using the familiar Scala trick of a recursive function that returns a Stream (Stream is deprecated since Scala 2.13 in favor of LazyList). Notice the use of (x+k/2)/k, where x is the number of remaining elements, to round the chunk size to the nearest integer. This interleaves the smaller and larger chunks in the final list, and all chunk sizes differ by at most one element. If you round up instead, with (x+k-1)/k, the smaller blocks move to the end, and rounding down with x/k moves them to the beginning.

    def k_folds(k: Int, vv: Seq[Int]): Stream[Seq[Int]] =
        if (k > 1)
            vv.take((vv.size+k/2)/k) +: k_folds(k-1, vv.drop((vv.size+k/2)/k))
        else
            Stream(vv)
    

    Demo:

    scala> val indices = scala.util.Random.shuffle(1 to 39)
    
    scala> for (ff <- k_folds(7, indices)) println(ff)
    Vector(29, 8, 24, 14, 22, 2)
    Vector(28, 36, 27, 7, 25, 4)
    Vector(6, 26, 17, 13, 23)
    Vector(3, 35, 34, 9, 37, 32)
    Vector(33, 20, 31, 11, 16)
    Vector(19, 30, 21, 39, 5, 15)
    Vector(1, 38, 18, 10, 12)
    
    scala> for (ff <- k_folds(7, indices)) println(ff.size)
    6
    6
    5
    6
    5
    6
    5
    
    scala> for (ff <- indices.grouped((indices.size+7-1)/7)) println(ff)
    Vector(29, 8, 24, 14, 22, 2)
    Vector(28, 36, 27, 7, 25, 4)
    Vector(6, 26, 17, 13, 23, 3)
    Vector(35, 34, 9, 37, 32, 33)
    Vector(20, 31, 11, 16, 19, 30)
    Vector(21, 39, 5, 15, 1, 38)
    Vector(18, 10, 12)
    
    scala> for (ff <- indices.grouped((indices.size+7-1)/7)) println(ff.size)
    6
    6
    6
    6
    6
    6
    3
    

    Notice how grouped does not try to even out the size of all the sub-lists.
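
    The three rounding rules mentioned above can be compared side by side. This is only a sketch: it uses a strict List instead of a Stream, and the names FoldRounding, folds, nearest, up and down are made up for the comparison.

```scala
object FoldRounding {
  // chunk(x, k) decides how many of the x remaining elements the next fold takes
  // when k folds are still to be produced.
  def folds[A](k: Int, vv: Seq[A], chunk: (Int, Int) => Int): List[Seq[A]] =
    if (k > 1) {
      val c = chunk(vv.size, k)
      vv.take(c) :: folds(k - 1, vv.drop(c), chunk)
    } else List(vv)

  val nearest: (Int, Int) => Int = (x, k) => (x + k / 2) / k // round to nearest
  val up: (Int, Int) => Int = (x, k) => (x + k - 1) / k      // round up
  val down: (Int, Int) => Int = (x, k) => x / k              // round down
}
```

    For 39 elements and 7 folds, nearest gives sizes 6, 6, 5, 6, 5, 6, 5 (matching the demo above), up gives 6, 6, 6, 6, 5, 5, 5 and down gives 5, 5, 5, 6, 6, 6, 6.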

  • 2021-02-12 16:04

    Here is my take on the problem:

      def partition[T](items: Seq[T], partitionsCount: Int): List[Seq[T]] = {
        val minPartitionSize = items.size / partitionsCount
        val extraItemsCount = items.size % partitionsCount
    
        def loop(unpartitioned: Seq[T], acc: List[Seq[T]], extra: Int): List[Seq[T]] =
          if (unpartitioned.nonEmpty) {
            val (splitIndex, newExtra) = if (extra > 0) (minPartitionSize + 1, extra - 1) else (minPartitionSize, extra)
            val (newPartition, remaining) = unpartitioned.splitAt(splitIndex)
            loop(remaining, newPartition :: acc, newExtra)
          } else acc
    
        loop(items, List.empty, extraItemsCount).reverse
      }
    

    It's more verbose than some of the other solutions but hopefully clearer as well. Partitions are prepended to the accumulator, so the final reverse is only necessary if you want the order to be preserved.

  • 2021-02-12 16:08

    The typical "optimal" partition computes the exact fractional position of each cut and then rounds it to the nearest integer to get the actual boundary index:

    def cut[A](xs: Seq[A], n: Int):Vector[Seq[A]] = {
      val m = xs.length
      val targets = (0 to n).map{x => math.round((x.toDouble*m)/n).toInt}
      def snip(xs: Seq[A], ns: Seq[Int], got: Vector[Seq[A]]): Vector[Seq[A]] = {
        if (ns.length<2) got
        else {
          val (i,j) = (ns.head, ns.tail.head)
          snip(xs.drop(j-i), ns.tail, got :+ xs.take(j-i))
        }
      }
      snip(xs, targets, Vector.empty)
    }
    

    This way your longer and shorter blocks will be interspersed, which is often more desirable for evenness:

    scala> cut(List(1,2,3,4,5,6,7,8,9,10),4)
    res5: Vector[Seq[Int]] = 
      Vector(List(1, 2, 3), List(4, 5), List(6, 7, 8), List(9, 10))
    

    You can even cut more times than you have elements:

    scala> cut(List(1,2,3),5)
    res6: Vector[Seq[Int]] = 
      Vector(List(1), List(), List(2), List(), List(3))
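
    To see where those piece sizes come from, it can help to look at the rounded boundaries on their own. The following is a sketch of the same idea using absolute indices and slice rather than repeated take/drop. The names BoundaryCut and boundaries are hypothetical, and restricting the input to Vector is an assumption made to keep slice cheap.

```scala
object BoundaryCut {
  // Rounded cut positions 0 = b(0) <= b(1) <= ... <= b(n) = m.
  def boundaries(m: Int, n: Int): Vector[Int] = {
    require(n > 0, "n must be positive")
    (0 to n).map(x => math.round(x.toDouble * m / n).toInt).toVector
  }

  // Piece i is the slice between consecutive boundaries b(i) and b(i + 1).
  def cut[A](xs: Vector[A], n: Int): Vector[Vector[A]] = {
    val b = boundaries(xs.length, n)
    b.zip(b.tail).map { case (i, j) => xs.slice(i, j) }
  }
}
```

    For 10 elements and 4 pieces the boundaries are 0, 3, 5, 8, 10, which is why the sizes come out as 3, 2, 3, 2. For 3 elements and 5 pieces they are 0, 1, 1, 2, 2, 3, and the repeated boundaries produce the empty pieces.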
    
  • 2021-02-12 16:09

    As Kaito mentions, grouped is exactly what you are looking for. But if you just want to know how to implement such a method, there are many ways ;-). You could, for example, do it like this:

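    def grouped[A](xs: List[A], size: Int): List[List[A]] = {
      require(size > 0, "size must be positive") // size <= 0 would never consume xs
      @annotation.tailrec
      def loop(rest: List[A], result: List[List[A]]): List[List[A]] =
        if (rest.isEmpty) result.reverse
        else {
          val (slice, tail) = rest.splitAt(size)
          loop(tail, slice :: result) // prepend is O(1); reverse once at the end
        }
      loop(xs, Nil)
    }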
    