Spark 3 Typed User Defined Aggregate Function over Window

Submitted by *爱你&永不变心* on 2021-02-11 15:12:56

Question


I am trying to use a custom user defined aggregator over a window. When I use an untyped aggregator, the query works. However, I am unable to use a typed UDAF as a window function - I get an error stating `The query operator Project contains one or more unsupported expression types Aggregate, Window or Generate`.

The following basic program showcases the problem. I think it could work using UserDefinedAggregateFunction rather than Aggregator, but the former is deprecated.

import scala.collection.mutable.Set
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.{Aggregator, Window}
import org.apache.spark.sql.functions

case class myType(time: Long, id: String, day: String)

object MyTypedUDAF extends Aggregator[myType, Set[String], Long] {
  def zero: Set[String] = Set()
  def reduce(buffer: Set[String], row: myType): Set[String] = buffer += row.id
  def merge(b1: Set[String], b2: Set[String]): Set[String] = b1 ++= b2
  def finish(reduction: Set[String]): Long = reduction.size
  def bufferEncoder: Encoder[Set[String]] = Encoders.javaSerialization
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

object MyUntypedUDAF extends Aggregator[String, Set[String], Long] {
  def zero: Set[String] = Set()
  def reduce(buffer: Set[String], id: String): Set[String] = buffer += id
  def merge(b1: Set[String], b2: Set[String]): Set[String] = b1 ++= b2
  def finish(reduction: Set[String]): Long = reduction.size
  def bufferEncoder: Encoder[Set[String]] = Encoders.javaSerialization
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

object Example {
  def main(args: Array[String]) {
    val spark = SparkSession.builder.appName("Simple Application").getOrCreate()
    import spark.implicits._

    val df = Seq(
      (1L, "1", "0"),
      (2L, "1", "0"),
      (3L, "1", "0"),
      (1L, "2", "0"),
      (2L, "2", "0")
    ).toDF("time", "id", "day").as[myType]

    df.createOrReplaceTempView("mydf")
    println("Viewing dataframe")
    df.show()

    // Using the untyped:
    println("Using the untyped")
    spark.udf.register("myUntypedUDAF", functions.udaf(MyUntypedUDAF))

    val untypedResult = spark.sql("SELECT myUntypedUDAF(id) OVER (ORDER BY time ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as myuntypedudaf FROM mydf")
    untypedResult.show()

    // Using the typed without window
    println("Using the typed without window")
    val myTypedUDAF = (MyTypedUDAF.toColumn).name("myudaf")
    val typedResult = df.select(myTypedUDAF)
    typedResult.show()

    // Using the typed with window
    println("Unsing the typed with window")
    val window = Window.orderBy('time).rowsBetween(Window.unboundedPreceding, Window.currentRow)
    val myTypedUDAFWithWindow = (MyTypedUDAF.toColumn over window).name("myudaf")
    val typedWindowResult = df.select(myTypedUDAFWithWindow)

    typedWindowResult.show()

    spark.stop()

  }
}

The output is:

Viewing dataframe
+----+---+---+
|time| id|day|
+----+---+---+
|   1|  1|  0|
|   2|  1|  0|
|   3|  1|  0|
|   1|  2|  0|
|   2|  2|  0|
+----+---+---+

Using the untyped
+-------------+
|myuntypedudaf|
+-------------+
|            1|
|            2|
|            2|
|            2|
|            2|
+-------------+

Using the typed without window
+------+
|myudaf|
+------+
|     2|
+------+

Using the typed with window
Exception in thread "main" org.apache.spark.sql.AnalysisException: 
The query operator `Project` contains one or more unsupported
expression types Aggregate, Window or Generate.
Invalid expressions: [mytypedudaf(encodeusingserializer(input[0, java.lang.Object, true], false), decodeusingserializer(input[0, binary, true], scala.collection.mutable.Set, false), boundreference()) OVER (ORDER BY `time` ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), mytypedudaf(encodeusingserializer(input[0, java.lang.Object, true], false), decodeusingserializer(input[0, binary, true], scala.collection.mutable.Set, false), boundreference())];;
'Project [mytypedudaf(MyTypedUDAF$@1c05097c, None, None, None, encodeusingserializer(input[0, java.lang.Object, true], false), decodeusingserializer(input[0, binary, true], scala.collection.mutable.Set, false), input[0, bigint, false], LongType, false, 0, 0) windowspecdefinition(time#10L ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS myudaf#507]
+- Project [_1#3L AS time#10L, _2#4 AS id#11, _3#5 AS day#12]
   +- LocalRelation [_1#3L, _2#4, _3#5]

    at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.failAnalysis(CheckAnalysis.scala:49)
    at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.failAnalysis$(CheckAnalysis.scala:48)
    at org.apache.spark.sql.catalyst.analysis.Analyzer.failAnalysis(Analyzer.scala:130)
    at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis$1(CheckAnalysis.scala:656)
    at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis$1$adapted(CheckAnalysis.scala:92)
    at org.apache.spark.sql.catalyst.trees.TreeNode.foreachUp(TreeNode.scala:177)
    at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.checkAnalysis(CheckAnalysis.scala:92)
    at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.checkAnalysis$(CheckAnalysis.scala:89)
    at org.apache.spark.sql.catalyst.analysis.Analyzer.checkAnalysis(Analyzer.scala:130)
    at org.apache.spark.sql.catalyst.analysis.Analyzer.$anonfun$executeAndCheck$1(Analyzer.scala:156)
    at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper$.markInAnalyzer(AnalysisHelper.scala:201)
    at org.apache.spark.sql.catalyst.analysis.Analyzer.executeAndCheck(Analyzer.scala:153)
    at org.apache.spark.sql.execution.QueryExecution.$anonfun$analyzed$1(QueryExecution.scala:68)
    at org.apache.spark.sql.catalyst.QueryPlanningTracker.measurePhase(QueryPlanningTracker.scala:111)
    at org.apache.spark.sql.execution.QueryExecution.$anonfun$executePhase$1(QueryExecution.scala:133)
    at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:764)
    at org.apache.spark.sql.execution.QueryExecution.executePhase(QueryExecution.scala:133)
    at org.apache.spark.sql.execution.QueryExecution.analyzed$lzycompute(QueryExecution.scala:68)
    at org.apache.spark.sql.execution.QueryExecution.analyzed(QueryExecution.scala:66)
    at org.apache.spark.sql.execution.QueryExecution.assertAnalyzed(QueryExecution.scala:58)
    at org.apache.spark.sql.Dataset$.$anonfun$ofRows$1(Dataset.scala:91)
    at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:764)
    at org.apache.spark.sql.Dataset$.ofRows(Dataset.scala:89)
    at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$withPlan(Dataset.scala:3646)
    at org.apache.spark.sql.Dataset.select(Dataset.scala:1456)
    at Example$.main(Example.scala:60)
    at Example.main(Example.scala)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at org.apache.spark.deploy.JavaMainApplication.start(SparkApplication.scala:52)
    at org.apache.spark.deploy.SparkSubmit.org$apache$spark$deploy$SparkSubmit$$runMain(SparkSubmit.scala:928)
    at org.apache.spark.deploy.SparkSubmit.doRunMain$1(SparkSubmit.scala:180)
    at org.apache.spark.deploy.SparkSubmit.submit(SparkSubmit.scala:203)
    at org.apache.spark.deploy.SparkSubmit.doSubmit(SparkSubmit.scala:90)
    at org.apache.spark.deploy.SparkSubmit$$anon$2.doSubmit(SparkSubmit.scala:1007)
    at org.apache.spark.deploy.SparkSubmit$.main(SparkSubmit.scala:1016)
    at org.apache.spark.deploy.SparkSubmit.main(SparkSubmit.scala)

What am I doing wrong?


Answer 1:


Try using a registered UserDefinedFunction to call your aggregation over the window. For that you can use the same approach as for MyUntypedUDAF:

import org.apache.spark.sql.expressions.UserDefinedFunction

val mMyTypedUDAFUDF: UserDefinedFunction = spark.udf.register(
  "myUU",
  functions.udaf(MyTypedUDAF)
)

and call it as an aggregate over your window:

df.select(mMyTypedUDAFUDF.apply('time, 'id, 'day).over(window).name("myudaf"))
  .show()
// output:
+------+
|myudaf|
+------+
|     1|
|     2|
|     2|
|     2|
|     2|
+------+
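Presumably this works because MyTypedUDAF.toColumn produces a typed aggregate expression that the analyzer rejects as a window function inside a Project (as the AnalysisException above shows), whereas functions.udaf wraps the same Aggregator as an ordinary untyped UserDefinedFunction that the window planner accepts, just like MyUntypedUDAF in the question.

Since the Aggregator is now also registered under the SQL name "myUU", the same window aggregation should be usable from SQL too, mirroring the myUntypedUDAF query above. A minimal sketch (not verified here), reusing the mydf temp view from the question:

// Sketch: call the registered typed UDAF from SQL over the same window.
// "myUU" is the name registered above; time, id, day are the columns of mydf.
spark.sql(
  """SELECT myUU(time, id, day) OVER
    |  (ORDER BY time ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS myudaf
    |FROM mydf""".stripMargin
).show()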


Source: https://stackoverflow.com/questions/65156320/spark-3-typed-user-defined-aggregate-function-over-window
