Rewrite a LogicalPlan to push a UDF down out of an Aggregate

前端 未结 1 1868
谎友^
谎友^ 2021-01-07 04:47

I have defined a UDF named "inc" that increments its input value by one. This is the code of my UDF:

spark.udf.register("inc", (x: Long) => x + 1)


        
相关标签:
1条回答
  • 2021-01-07 05:13

    OK, I finally found a way to answer this question.

    Although a ScalaUDF cannot be cast to a NamedExpression, an Alias can (Alias is itself a NamedExpression).

    So I wrap each ScalaUDF in an Alias and then build a Project from those aliases.

    import org.apache.log4j.{Level, Logger}
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
    import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, ExpectsInputTypes, ExprId, Expression, NamedExpression, ScalaUDF}
    import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation, LogicalPlan, Project, Subquery}
    import org.apache.spark.sql.catalyst.rules.Rule
    import org.apache.spark.sql.types.{AbstractDataType, DataType}
    
    import scala.collection.mutable
    
    object RewritePlanTest {

      /**
       * Optimizer rule that evaluates ScalaUDF calls appearing inside a global (un-grouped,
       * single-output) Aggregate in a new Project placed directly below that Aggregate, so
       * the Aggregate itself only aggregates plain attributes.
       *
       * Example: `Aggregate [sum(inc(vals))]` becomes
       * `Aggregate [sum(udf0)]` over `Project [inc(vals) AS udf0]`.
       *
       * @param spark the session the rule is built for; not used by the rule, but required so
       *              the companion `UdfRule` has the `SparkSession => Rule[LogicalPlan]` shape
       *              expected by `injectOptimizerRule`.
       */
      case class UdfRule(spark: SparkSession) extends Rule[LogicalPlan] {
        import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression

        /**
         * Collects the outermost ScalaUDF calls in `e` that can be evaluated before aggregation.
         *
         * A UDF whose argument tree contains an aggregate (e.g. `inc(sum(x))`) needs the
         * aggregate's result, so it cannot be moved below the Aggregate. It is not collected;
         * the search continues into its children instead.
         */
        def collectUDFs(e: Expression): Seq[Expression] = e match {
          case udf: ScalaUDF if udf.find(_.isInstanceOf[AggregateExpression]).isEmpty => Seq(udf)
          case _ => e.children.flatMap(collectUDFs)
        }

        /**
         * Rewrites every matching Aggregate anywhere in the tree, not only one at the root.
         * For example, `df.show()` wraps the aggregate in Limit/Project operators.
         */
        override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
          case agg @ Aggregate(g, a, _) if g.isEmpty && a.length == 1 => pushDownUDFs(agg)
        }

        /**
         * Moves the pushable UDF calls of `agg` into a Project inserted between `agg` and its child.
         *
         * @return `agg` unchanged when it has no pushable UDF; otherwise the rewritten Aggregate.
         */
        private def pushDownUDFs(agg: Aggregate): LogicalPlan = {
          val child = agg.child
          // Keep only UDFs computable from the child's columns; identical calls are computed once.
          val udfs = agg.expressions
            .flatMap(collectUDFs)
            .filter(_.references.subsetOf(child.outputSet))
            .distinct

          if (udfs.isEmpty) {
            agg
          } else {
            // ScalaUDF is not a NamedExpression, so wrap each one in an Alias for the Project.
            val aliases = udfs.zipWithIndex.map { case (udf, i) => Alias(udf, s"udf$i")() }
            val replacements: Map[Expression, Attribute] =
              udfs.zip(aliases.map(_.toAttribute)).toMap

            // Top-down, so an outer UDF is replaced as a whole before its inner calls are visited.
            val rewritten = agg.transformExpressionsDown {
              case udf: ScalaUDF if replacements.contains(udf) => replacements(udf)
            }

            // Child columns still referenced outside the UDFs (e.g. `vals` in
            // `sum(inc(vals)) + max(vals)`) must be passed through the new Project;
            // otherwise the Aggregate would reference attributes its new child does not output.
            val stillNeeded = rewritten.references
            val passthrough = child.output.filter(attr => stillNeeded.contains(attr))
            val projectList: Seq[NamedExpression] = passthrough ++ aliases

            val newAgg = rewritten.withNewChildren(Seq(Project(projectList, child)))
            println("====== new agg ======")
            println(newAgg)
            newAgg
          }
        }
      }


      /** Demo: registers `inc`, runs an aggregate query over it and prints the optimized plan. */
      def main(args: Array[String]): Unit = {
        Logger.getLogger("org").setLevel(Level.WARN)

        val spark = SparkSession
          .builder()
          .master("local[*]")
          .appName("Rewrite plan test")
          .withExtensions(e => e.injectOptimizerRule(UdfRule))
          .getOrCreate()

        // Always release the local session, even if the query fails.
        try {
          import spark.implicits._
          Seq(100L, 200L, 300L).toDF("vals").createOrReplaceTempView("data")

          spark.udf.register("inc", (x: Long) => x + 1)

          val df = spark.sql("select sum(inc(vals)) from data where vals > 100")
          df.explain(true)
          df.show()
        } finally {
          spark.stop()
        }
      }
    }
    

    This code outputs the LogicalPlan that I wanted:

    ====== new agg ======
    Aggregate [sum(udf0#9L) AS sum(inc(vals))#7L]
    +- Project [inc(vals#4L) AS udf0#9L]
       +- LocalRelation [vals#4L]
    
    0 讨论(0)
提交回复
热议问题