How to define a custom aggregation function to sum a column of Vectors?

闹比i 2020-11-27 04:23

I have a DataFrame with two columns: ID of type Int and Vec of type Vector (org.apache.spark.mllib.linalg.Vector). I would like to do a groupBy on the ID column and sum the vectors within each group. How can I define a custom aggregation function to do that?
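
For example (illustrative data, three-dimensional vectors), given an input like

    ID | Vec
    ---+----------------
    1  | [0.0, 0.0, 5.0]
    1  | [4.0, 0.0, 1.0]
    2  | [7.0, 5.0, 0.0]

the desired output would be

    ID | Vec
    ---+----------------
    1  | [4.0, 0.0, 6.0]
    2  | [7.0, 5.0, 0.0]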

2 Answers
  • 2020-11-27 05:03

    I suggest the following (works on Spark 2.0.2 onward). It could probably be optimized further, but it works nicely. One thing you have to know in advance is the vector size, which you pass when creating the UDAF instance:

    import org.apache.spark.ml.linalg._
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types._

    class VectorAggregate(val numFeatures: Int)
        extends UserDefinedAggregateFunction {

      // The buffer is a sparse map from vector index to the running sum at that index.
      private type B = Map[Int, Double]

      def inputSchema: StructType =
        StructType(StructField("vec", SQLDataTypes.VectorType) :: Nil)

      def bufferSchema: StructType =
        StructType(StructField("agg", MapType(IntegerType, DoubleType)) :: Nil)

      def initialize(buffer: MutableAggregationBuffer): Unit =
        buffer.update(0, Map.empty[Int, Double])

      // Add one input vector into the buffer, handling dense and sparse vectors alike.
      def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        val zero = buffer.getAs[B](0)
        input match {
          case Row(DenseVector(values)) =>
            buffer.update(0, values.zipWithIndex.foldLeft(zero) {
              case (acc, (v, i)) => acc.updated(i, v + acc.getOrElse(i, 0d))
            })
          case Row(SparseVector(_, indices, values)) =>
            buffer.update(0, values.zip(indices).foldLeft(zero) {
              case (acc, (v, i)) => acc.updated(i, v + acc.getOrElse(i, 0d))
            })
        }
      }

      // Merge two partial buffers index by index.
      def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        val zero = buffer1.getAs[B](0)
        buffer1.update(0, buffer2.getAs[B](0).foldLeft(zero) {
          case (acc, (i, v)) => acc.updated(i, v + acc.getOrElse(i, 0d))
        })
      }

      def deterministic: Boolean = true

      // Build the final vector; compressed returns whichever of dense/sparse is smaller.
      def evaluate(buffer: Row): Any = {
        val Row(agg: B) = buffer
        val indices = agg.keys.toArray.sorted
        Vectors.sparse(numFeatures, indices, indices.map(agg)).compressed
      }

      def dataType: DataType = SQLDataTypes.VectorType
    }
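
    A minimal usage sketch (assuming spark.implicits._ is in scope for the $ syntax, and a DataFrame df with the ID and Vec columns from the question holding vectors of size 3):

    val vecAgg = new VectorAggregate(3)  // the vector size must be known up front
    df.groupBy($"ID").agg(vecAgg($"Vec").alias("Vec")).show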
    
  • 2020-11-27 05:06

    Spark >= 3.0

    You can use Summarizer with sum

    import org.apache.spark.ml.stat.Summarizer
    
    df
      .groupBy($"id")
      .agg(Summarizer.sum($"vec").alias("vec"))
    

    Spark < 3.0

    Personally, I wouldn't bother with UDAFs. They are more than verbose and not exactly fast (see Spark UDAF with ArrayType as bufferSchema performance issues). Instead, I would simply use reduceByKey / foldByKey:

    import org.apache.spark.sql.Row
    import breeze.linalg.{DenseVector => BDV}
    import org.apache.spark.ml.linalg.{Vector, Vectors}
    import spark.implicits._  // provides $ and toDF; already imported in spark-shell
    
    def dv(values: Double*): Vector = Vectors.dense(values.toArray)
    
    val df = spark.createDataFrame(Seq(
        (1, dv(0,0,5)), (1, dv(4,0,1)), (1, dv(1,2,1)),
        (2, dv(7,5,0)), (2, dv(3,3,4)), 
        (3, dv(0,8,1)), (3, dv(0,0,1)), (3, dv(7,7,7)))
      ).toDF("id", "vec")
    
    val aggregated = df
      .rdd
      .map{ case Row(k: Int, v: Vector) => (k, BDV(v.toDense.values)) }
      .foldByKey(BDV.zeros[Double](3))(_ += _)  // zero value: 3 must match the vector length
      .mapValues(v => Vectors.dense(v.toArray))
      .toDF("id", "vec")
    
    aggregated.show
    
    // +---+--------------+
    // | id|           vec|
    // +---+--------------+
    // |  1| [5.0,2.0,7.0]|
    // |  2|[10.0,8.0,4.0]|
    // |  3|[7.0,15.0,9.0]|
    // +---+--------------+
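
    A reduceByKey variant (mentioned above) is sketched here with the same imports and df; it avoids hard-coding the vector length because it only ever combines vectors that are already present for a key:

    // Sketch: reduceByKey instead of foldByKey; assumes all vectors in a group have the same length.
    val aggregatedReduce = df
      .rdd
      .map { case Row(k: Int, v: Vector) => (k, BDV(v.toDense.values.clone)) } // clone so += never mutates shared arrays
      .reduceByKey(_ += _)                                                     // in-place element-wise sum per key
      .mapValues(v => Vectors.dense(v.toArray))
      .toDF("id", "vec")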
    

    And just for comparison a "simple" UDAF. Required imports:

    import org.apache.spark.sql.expressions.{MutableAggregationBuffer,
      UserDefinedAggregateFunction}
    import org.apache.spark.ml.linalg.{Vector, Vectors, SQLDataTypes}
    import org.apache.spark.sql.types.{StructType, ArrayType, DoubleType}
    import org.apache.spark.sql.Row
    import scala.collection.mutable.WrappedArray
    

    Class definition:

    class VectorSum (n: Int) extends UserDefinedAggregateFunction {
        def inputSchema = new StructType().add("v", SQLDataTypes.VectorType)
        def bufferSchema = new StructType().add("buff", ArrayType(DoubleType))
        def dataType = SQLDataTypes.VectorType
        def deterministic = true 
    
        def initialize(buffer: MutableAggregationBuffer) = {
          buffer.update(0, Array.fill(n)(0.0))
        }
    
        def update(buffer: MutableAggregationBuffer, input: Row) = {
          if (!input.isNullAt(0)) {
            val buff = buffer.getAs[WrappedArray[Double]](0) 
            val v = input.getAs[Vector](0).toSparse
            for (i <- v.indices) {
              buff(i) += v(i)
            }
            buffer.update(0, buff)
          }
        }
    
        def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
          val buff1 = buffer1.getAs[WrappedArray[Double]](0) 
          val buff2 = buffer2.getAs[WrappedArray[Double]](0) 
          for ((x, i) <- buff2.zipWithIndex) {
            buff1(i) += x
          }
          buffer1.update(0, buff1)
        }
    
        def evaluate(buffer: Row) =  Vectors.dense(
          buffer.getAs[Seq[Double]](0).toArray)
    } 
    

    And an example usage:

    df.groupBy($"id").agg(new VectorSum(3)($"vec") alias "vec").show
    
    // +---+--------------+
    // | id|           vec|
    // +---+--------------+
    // |  1| [5.0,2.0,7.0]|
    // |  2|[10.0,8.0,4.0]|
    // |  3|[7.0,15.0,9.0]|
    // +---+--------------+
    

    See also: How to find mean of grouped Vector columns in Spark SQL?.
