How to implement a generic average function in Scala?

逝去的感伤 2021-02-04 03:26

It seems like an easy problem for any specific kind of number, e.g. Double or Integer, but it is hard to write in the general case.

implicit def iterebleWithAvg(data:Iterable[Dou         


        
2 Answers
  •  慢半拍i (original poster)
    2021-02-04 03:53

    Here's the way I define it in my code.

    Instead of using Numeric, I use Fractional, since Fractional defines a division operation (div) and Numeric does not. This means that when you call .avg, you get back the same type you put in, instead of always getting Double.
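    To make the Numeric/Fractional difference concrete, here is a minimal sketch assuming Scala 2.13. The names AvgDemo, avgViaDouble and avgFractional are hypothetical and do not come from the answer. With only Numeric[A] there is no div, so the usual fallback is to convert the sum to Double. With Fractional[A], num.div is available and the result keeps the input type.

```scala
object AvgDemo {
  // Numeric[A] has plus/times/toDouble but no division, so the sum is
  // converted to Double before dividing by the element count.
  def avgViaDouble[A](xs: Iterable[A])(implicit num: Numeric[A]): Double =
    num.toDouble(xs.sum) / xs.size

  // Fractional[A] adds div, so the average stays an A (Double, BigDecimal, ...).
  // An empty input gives NaN for Double and throws ArithmeticException for BigDecimal.
  def avgFractional[A](xs: Iterable[A])(implicit num: Fractional[A]): A =
    num.div(xs.sum, num.fromInt(xs.size))

  def main(args: Array[String]): Unit = {
    println(avgViaDouble(List(1, 2, 3, 4))) // prints 2.5
    println(avgFractional(List(BigDecimal(1), BigDecimal(2))))
    // avgFractional(List(1, 2, 3)) does not compile: there is no Fractional[Int].
  }
}
```

    Unlike the answer's foldLeft, this sketch traverses the input twice (once for sum, once for size). That is why it takes an Iterable rather than an Iterator.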

    I also define it over all GenTraversableOnce collections so that it works on, for example, Iterator.

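    import scala.collection.GenTraversableOnce // Scala 2.12 and earlier
    import scala.language.implicitConversions

    class EnrichedAvgFractional[A](self: GenTraversableOnce[A]) {
      // One pass: sum the elements and count them. The count is kept as an A
      // (built from num.one) so num.div can divide without any conversion.
      // An empty input divides zero by zero: NaN for Double, an
      // ArithmeticException for BigDecimal.
      def avg(implicit num: Fractional[A]): A = {
        val (total, count) = self.toIterator.foldLeft((num.zero, num.zero)) {
          case ((total, count), x) => (num.plus(total, x), num.plus(count, num.one))
        }
        num.div(total, count)
      }
    }
    implicit def enrichAvgFractional[A: Fractional](self: GenTraversableOnce[A]): EnrichedAvgFractional[A] =
      new EnrichedAvgFractional(self)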
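    GenTraversableOnce only exists up to Scala 2.12. It was removed in the 2.13 collections redesign. The sketch below is one possible port to 2.13, not the answer's own code, and AvgSyntax, AvgOps and AvgSyntaxDemo are hypothetical names. It swaps GenTraversableOnce for IterableOnce, which both collections and Iterator extend, and uses an implicit class instead of a separate class plus an implicit def.

```scala
object AvgSyntax {
  // Same fold as the answer: sum and count in one pass, with the count kept
  // as an A so that num.div can divide without converting.
  implicit class AvgOps[A](private val self: scala.collection.IterableOnce[A]) extends AnyVal {
    def avg(implicit num: Fractional[A]): A = {
      val (total, count) = self.iterator.foldLeft((num.zero, num.zero)) {
        case ((t, c), x) => (num.plus(t, x), num.plus(c, num.one))
      }
      num.div(total, count)
    }
  }
}

object AvgSyntaxDemo {
  import AvgSyntax._

  def main(args: Array[String]): Unit = {
    println(Iterator(1.0, 2.0, 3.0, 4.0, 5.0).avg) // prints 3.0
    println(Vector(1.0f, 2.0f).avg)                // prints 1.5
  }
}
```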
    

    Notice that if we give it a collection of Double, we get back a Double, and if we give it BigDecimal, we get back a BigDecimal. We could even define our own Fractional number type (which I do occasionally), and it would work for that too.

    scala> Iterator(1.0, 2.0, 3.0, 4.0, 5.0).avg
    res0: Double = 3.0
    
    scala> Iterator(1.0, 2.0, 3.0, 4.0, 5.0).map(BigDecimal(_)).avg
    res1: scala.math.BigDecimal = 3.0
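    The answer mentions that a user-defined Fractional type also works. As an illustration only, here is a hypothetical Score wrapper around Double with its own Fractional[Score] instance, written against Scala 2.13, where Numeric also requires parseString. The avg function is a plain-function variant of the answer's method, not the answer's own code.

```scala
import scala.util.Try

// Hypothetical user-defined number type: a thin wrapper around Double.
final case class Score(value: Double)

object Score {
  // Fractional[Score] instance, found through Score's companion (implicit scope).
  // parseString is abstract in Scala 2.13's Numeric; on 2.12 it is just an extra method.
  implicit object ScoreIsFractional extends Fractional[Score] {
    def plus(x: Score, y: Score): Score = Score(x.value + y.value)
    def minus(x: Score, y: Score): Score = Score(x.value - y.value)
    def times(x: Score, y: Score): Score = Score(x.value * y.value)
    def div(x: Score, y: Score): Score = Score(x.value / y.value)
    def negate(x: Score): Score = Score(-x.value)
    def fromInt(x: Int): Score = Score(x.toDouble)
    def parseString(str: String): Option[Score] = Try(str.toDouble).toOption.map(Score(_))
    def toInt(x: Score): Int = x.value.toInt
    def toLong(x: Score): Long = x.value.toLong
    def toFloat(x: Score): Float = x.value.toFloat
    def toDouble(x: Score): Double = x.value
    def compare(x: Score, y: Score): Int = java.lang.Double.compare(x.value, y.value)
  }
}

object ScoreDemo {
  // Fractional-based average written as a plain function.
  def avg[A](xs: Iterable[A])(implicit num: Fractional[A]): A =
    num.div(xs.sum, num.fromInt(xs.size))

  def main(args: Array[String]): Unit =
    println(avg(List(Score(1.0), Score(2.0), Score(6.0)))) // prints Score(3.0)
}
```

    Because the instance lives in Score's companion object, the compiler finds it without any import.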
    

    However, Int is not Fractional (there is no Fractional[Int] instance), and an Int is not a sensible result for the average of Ints. So we need a special case for Int that converts the result to a Double.

    class EnrichedAvgInt(self: GenTraversableOnce[Int]) {
      // Sum in a Long so the running total cannot overflow Int.
      // An empty collection gives 0.0 / 0 = NaN.
      def avg: Double = {
        val (total, count) = self.toIterator.foldLeft((0L, 0)) {
          case ((total, count), x) => (total + x, count + 1)
        }
        total.toDouble / count
      }
    }
    implicit def enrichAvgInt(self: GenTraversableOnce[Int]): EnrichedAvgInt = new EnrichedAvgInt(self)
    

    So averaging Ints gives us a Double:

    scala> Iterator(1, 2, 3, 4, 5).avg
    res2: Double = 3.0
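    The Int special case can be generalized to other integral types (Long, Short, BigInt, ...). Below is a minimal sketch assuming Scala 2.13. It uses a hypothetical avgIntegral helper that takes an Integral[A] and always returns Double. Each element is converted with num.toDouble before summing, so the sum cannot overflow A, but very large Long or BigInt values lose precision.

```scala
object IntegralAvg {
  // One pass: accumulate a Double sum and a Long count, converting each
  // element with num.toDouble so the running total cannot overflow A.
  // An empty input gives 0.0 / 0 = NaN.
  def avgIntegral[A](xs: scala.collection.IterableOnce[A])(implicit num: Integral[A]): Double = {
    val (total, count) = xs.iterator.foldLeft((0.0, 0L)) {
      case ((t, c), x) => (t + num.toDouble(x), c + 1)
    }
    total / count
  }

  def main(args: Array[String]): Unit = {
    println(avgIntegral(Iterator(1, 2, 3, 4, 5)))       // prints 3.0
    println(avgIntegral(List(BigInt(10), BigInt(25)))) // prints 17.5
  }
}
```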
    
