Newbie Scala question about simple math array operations

自闭症患者 2021-02-14 22:17

Newbie Scala Question:

Say I want to do this [Java code] in Scala:

public static double[] abs(double[] r, double[] im) {
  double t[] = new double[r.leng         


        
    一向 (OP)
    2021-02-14 23:08

    Writing generic, performant code over primitives in Scala involves two related mechanisms that Scala uses to avoid boxing/unboxing (e.g. wrapping an int in a java.lang.Integer and vice versa):

    • @specialized type annotations
    • Using Manifest with arrays

    @specialized is an annotation that tells the Scala compiler to generate "primitive" versions of a class or method (somewhat akin to C++ templates, so I am told). Check out the type declaration of Tuple2 (which is specialized) compared with List (which isn't). It was added in 2.8 and means that, for example, code like CC[Int].map(f : Int => Int) can execute without ever boxing any ints (assuming CC is specialized, of course!).
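A minimal sketch of what @specialized does, assuming Scala 2.x (Scala 3 accepts the annotation but, as far as I know, does not generate specialized variants for user-defined classes). The class Box is hypothetical:

```scala
// Box is a hypothetical example class, not part of the standard library.
// Under Scala 2, scalac emits extra subclasses (e.g. Box$mcI$sp) whose
// `value` field is a primitive int or double instead of an Object.
class Box[@specialized(Int, Double) T](val value: T) {
  // In the specialized subclasses this accessor also returns the primitive.
  def get: T = value
}

object SpecializedDemo {
  def main(args: Array[String]): Unit = {
    val ib = new Box(42)  // Box[Int]; uses the Int variant under Scala 2
    val db = new Box(2.5) // Box[Double]; uses the Double variant under Scala 2
    println(ib.get + 1)   // 43
    println(db.get * 2)   // 5.0
    // Prints the runtime class name; under Scala 2 this names the Int variant.
    println(ib.getClass.getName)
  }
}
```

Only the listed types (Int and Double here) get variants; a Box[String] still uses the generic, Object-based class.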

    Manifests are Scala's way of providing reified types, working around the JVM's type erasure. They are particularly useful when you want a method that is generic in some type T to create an array of T (i.e. T[]). In Java this is not possible, because new T[] is illegal; in Scala it is possible using a Manifest. In this case in particular, it lets us construct a primitive array, such as double[] or int[], when T is Double or Int. (This is awesome, in case you were wondering.)
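The same idea as a small sketch. Since Scala 2.10, scala.reflect.ClassTag is the usual context bound for array creation (Manifest still works, but ClassTag carries only the runtime class, which is all array creation needs). The helper name fill is hypothetical:

```scala
import scala.reflect.ClassTag

object ArrayOfT {
  // The Scala equivalent of Java's illegal `new T[n]`: the implicit ClassTag
  // supplies the runtime element class, so for T = Double the result is a
  // real JVM double[] rather than an Object[] of boxed values.
  def fill[T: ClassTag](n: Int, elem: T): Array[T] = {
    val arr = new Array[T](n)
    var i = 0
    while (i < n) { arr(i) = elem; i += 1 }
    arr
  }

  def main(args: Array[String]): Unit = {
    val ds = fill(3, 1.5)
    println(ds.getClass.getSimpleName) // double[]
    println(ds.mkString(", "))        // 1.5, 1.5, 1.5
  }
}
```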

    Boxing is so important from a performance perspective because it creates garbage: java.lang.Integer.valueOf only caches values from -128 to 127, and OpenJDK does not cache boxed doubles at all. It also adds a level of indirection (extra allocations, pointer dereferences, method calls, etc.). But consider that you probably don't give a hoot unless you are absolutely, positively sure that you definitely do (i.e. most code does not need such micro-optimization).
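To make the caching point concrete, here is a sketch comparing boxed values by reference (eq). The -128 to 127 cache is part of the Integer.valueOf contract; the other two results reflect default OpenJDK behavior and are left without expected-output comments where they depend on JVM settings:

```scala
object BoxingDemo {
  def main(args: Array[String]): Unit = {
    // Integer.valueOf always caches -128..127, so both calls return one object.
    val a = java.lang.Integer.valueOf(127)
    val b = java.lang.Integer.valueOf(127)
    println(a eq b) // true

    // Outside that range a default JVM allocates a fresh Integer per call
    // (the upper bound can be raised with -XX:AutoBoxCacheMax).
    val c = java.lang.Integer.valueOf(1000)
    val d = java.lang.Integer.valueOf(1000)
    println(c eq d)

    // OpenJDK's Double.valueOf does not cache, so each boxed double is a new
    // allocation: this is the garbage a boxed Array[Double] loop produces.
    val e = java.lang.Double.valueOf(1.0)
    val f = java.lang.Double.valueOf(1.0)
    println(e eq f)
  }
}
```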


    So, back to the question: to do this with no boxing/unboxing, you must use Array, which stores primitives unboxed (List is not specialized, and would be more object-hungry anyway, even if it were!). Calling zipped on a pair of collections gives a view whose map applies a two-argument function to corresponding elements, without building an intermediate collection of Tuple2s.
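For reference, the most direct non-generic translation is sketched below, assuming the question's (truncated) Java method computes sqrt(r[i]*r[i] + im[i]*im[i]) for each index, as the code in this answer implies. Array[Double] compiles to a JVM double[] and the while loop touches only primitive doubles, so nothing is boxed. The length check is an added assumption, not part of the question:

```scala
object AbsDouble {
  // Non-generic version: Array[Double] is a JVM double[], and every value in
  // the loop stays a primitive double, so no boxing occurs.
  def abs(r: Array[Double], im: Array[Double]): Array[Double] = {
    require(r.length == im.length, "arrays must have equal length")
    val t = new Array[Double](r.length)
    var i = 0
    while (i < r.length) {
      t(i) = math.sqrt(r(i) * r(i) + im(i) * im(i))
      i += 1
    }
    t
  }

  def main(args: Array[String]): Unit =
    println(abs(Array(3.0, 5.0), Array(4.0, 12.0)).mkString(", ")) // 5.0, 13.0
}
```

The trade-off is that this works for Double only; a Float or Int version would need its own copy.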

    To do this generically (i.e. across various numeric types), you need context bounds on the type parameter requiring that a Numeric[T] and a Manifest[T] can be found (the Manifest is required for array creation). So I started along these lines...
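Context bounds are shorthand for an implicit parameter list. A small sketch of both spellings (the method names are hypothetical), using ClassTag in place of Manifest, which is sufficient for building the result array:

```scala
import scala.reflect.ClassTag

object ContextBounds {
  // Context-bound form: T must have a Numeric instance and a ClassTag.
  def squares[T: Numeric: ClassTag](xs: Array[T]): Array[T] = {
    val num = implicitly[Numeric[T]]
    xs.map(x => num.times(x, x)) // map needs the ClassTag to build Array[T]
  }

  // Equivalent explicit form: the compiler rewrites the context bounds
  // above into an implicit parameter list like this one.
  def squaresExplicit[T](xs: Array[T])(implicit num: Numeric[T], ct: ClassTag[T]): Array[T] =
    xs.map(x => num.times(x, x))
}
```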

    def abs[T : Numeric : Manifest](rs : Array[T], ims : Array[T]) : Array[T] = {
        import math._
        val num = implicitly[Numeric[T]]
        (rs, ims).zipped.map { (r, i) => sqrt(num.plus(num.times(r,r), num.times(i,i))) }
        //                               ^^^^ doesn't compile: math.sqrt takes a Double, and Numeric[T] has no sqrt
    }
    

    ...but it doesn't compile. A generic Numeric[T] offers operations like plus and times, but nothing like sqrt, so you can only take the square root once you know you have a Double. For example:

    scala> def almostAbs[T : Manifest : Numeric](rs : Array[T], ims : Array[T]) : Array[T] = {
     | import math._
     | val num = implicitly[Numeric[T]]
     | (rs, ims).zipped.map { (r, i) => num.plus(num.times(r,r), num.times(i,i)) }
     | }
    almostAbs: [T](rs: Array[T],ims: Array[T])(implicit evidence$1: Manifest[T],implicit evidence$2: Numeric[T])Array[T]
    

    Excellent - now see this purely generic method do some stuff!

    scala> val rs = Array(1.2, 3.4, 5.6); val is = Array(6.5, 4.3, 2.1)
    rs: Array[Double] = Array(1.2, 3.4, 5.6)
    is: Array[Double] = Array(6.5, 4.3, 2.1)
    
    scala> almostAbs(rs, is)
    res0: Array[Double] = Array(43.69, 30.049999999999997, 35.769999999999996)
    

    Now we can sqrt the result, because we have an Array[Double]:

    scala> res0.map(math.sqrt(_))
    res1: Array[Double] = Array(6.609841147864296, 5.481788029466298, 5.980802621722272)
    

    And to prove that this would work even with another Numeric type:

    scala> import math._
    import math._
    scala> val rs = Array(BigDecimal(1.2), BigDecimal(3.4), BigDecimal(5.6)); val is = Array(BigDecimal(6.5), BigDecimal(4.3), BigDecimal(2.1))
    rs: Array[scala.math.BigDecimal] = Array(1.2, 3.4, 5.6)
    is: Array[scala.math.BigDecimal] = Array(6.5, 4.3, 2.1)
    
    scala> almostAbs(rs, is)
    res6: Array[scala.math.BigDecimal] = Array(43.69, 30.05, 35.77)
    
    scala> res6.map(d => math.sqrt(d.toDouble))
    res7: Array[Double] = Array(6.609841147864296, 5.481788029466299, 5.9808026217222725)
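Putting the two steps together, here is a sketch of a generic version that does compile, assuming Scala 2.13 or Scala 3. It does the arithmetic in T (using the infix operators from scala.math.Numeric.Implicits) and converts to Double only for the square root, so it returns Array[Double] rather than Array[T]. Like almostAbs, the arithmetic goes through Numeric[T], so intermediate values of T are typically boxed; the object name and length check are assumptions:

```scala
import scala.math.Numeric.Implicits._

object GenericAbs {
  // Works for any T with a Numeric instance (Int, Double, BigDecimal, ...).
  // r * r + i * i is computed in T; .toDouble converts once for math.sqrt.
  def abs[T: Numeric](rs: Array[T], ims: Array[T]): Array[Double] = {
    require(rs.length == ims.length, "arrays must have equal length")
    Array.tabulate(rs.length) { k =>
      val r = rs(k)
      val i = ims(k)
      math.sqrt((r * r + i * i).toDouble)
    }
  }
}
```

For example, GenericAbs.abs(Array(3, 5), Array(4, 12)) yields Array(5.0, 13.0). No ClassTag bound on T is needed here, because the only array created is an Array[Double].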
    
