Spark SQL: apply aggregate functions to a list of columns

抹茶落季 2020-11-22 10:40

Is there a way to apply an aggregate function to all (or a list of) columns of a dataframe, when doing a groupBy? In other words, is there a way to avoid doing

3 Answers
  •  北海茫月
    2020-11-22 11:18

    There are multiple ways of applying aggregate functions to multiple columns.

    The GroupedData class provides methods for the most common functions, including count, max, min, mean and sum, which can be used directly as follows:

    • Python:

      df = sqlContext.createDataFrame(
          [(1.0, 0.3, 1.0), (1.0, 0.5, 0.0), (-1.0, 0.6, 0.5), (-1.0, 5.6, 0.2)],
          ("col1", "col2", "col3"))
      
      df.groupBy("col1").sum().show()
      
      ## +----+---------+-----------------+---------+
      ## |col1|sum(col1)|        sum(col2)|sum(col3)|
      ## +----+---------+-----------------+---------+
      ## | 1.0|      2.0|              0.8|      1.0|
      ## |-1.0|     -2.0|6.199999999999999|      0.7|
      ## +----+---------+-----------------+---------+
      
    • Scala

      val df = sc.parallelize(Seq(
        (1.0, 0.3, 1.0), (1.0, 0.5, 0.0),
        (-1.0, 0.6, 0.5), (-1.0, 5.6, 0.2))
      ).toDF("col1", "col2", "col3")
      
      df.groupBy($"col1").min().show
      
      // +----+---------+---------+---------+
      // |col1|min(col1)|min(col2)|min(col3)|
      // +----+---------+---------+---------+
      // | 1.0|      1.0|      0.3|      0.0|
      // |-1.0|     -1.0|      0.6|      0.2|
      // +----+---------+---------+---------+
      

    Optionally, you can pass a list of the columns that should be aggregated:

    df.groupBy("col1").sum("col2", "col3")
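
When the list of columns is not known in advance, it can be derived from the schema. The sketch below is an illustration rather than part of the original answer: `numeric_columns` is a hypothetical helper that takes the `(name, type)` pairs returned by `df.dtypes` and keeps the numeric columns that are not grouping keys. The set of type names treated as numeric is an assumption and may need adjusting.

```python
# Hypothetical helper (not part of the Spark API): pick the columns to
# aggregate from the (name, type) pairs that DataFrame.dtypes returns.
NUMERIC_TYPES = {"tinyint", "smallint", "int", "bigint", "float", "double"}


def is_numeric(dtype):
    """True for the simple numeric type names and for decimal(p,s)."""
    return dtype in NUMERIC_TYPES or dtype.startswith("decimal")


def numeric_columns(dtypes, group_cols):
    """Return numeric columns that are not grouping keys, in schema order."""
    skip = set(group_cols)
    return [name for name, dtype in dtypes
            if name not in skip and is_numeric(dtype)]


# Stand-in for df.dtypes on the example DataFrame plus one string column.
dtypes = [("col1", "double"), ("col2", "double"),
          ("col3", "double"), ("label", "string")]
cols = numeric_columns(dtypes, ["col1"])
print(cols)  # ['col2', 'col3']
```

With a real DataFrame the result would be unpacked into the call, for example `df.groupBy("col1").sum(*numeric_columns(df.dtypes, ["col1"]))`.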
    

    You can also pass a dictionary / map with the columns as the keys and the functions as the values:

    • Python

      exprs = {x: "sum" for x in df.columns}
      df.groupBy("col1").agg(exprs).show()
      
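      ## Column order in the result may differ.
      ## +----+---------+-----------------+---------+
      ## |col1|sum(col1)|        sum(col2)|sum(col3)|
      ## +----+---------+-----------------+---------+
      ## | 1.0|      2.0|              0.8|      1.0|
      ## |-1.0|     -2.0|6.199999999999999|      0.7|
      ## +----+---------+-----------------+---------+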
      
    • Scala

      val exprs = df.columns.map((_ -> "mean")).toMap
      df.groupBy($"col1").agg(exprs).show()
      
      // +----+---------+------------------+---------+
      // |col1|avg(col1)|         avg(col2)|avg(col3)|
      // +----+---------+------------------+---------+
      // | 1.0|      1.0|               0.4|      0.5|
      // |-1.0|     -1.0|3.0999999999999996|     0.35|
      // +----+---------+------------------+---------+
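
The dictionary form maps each column name to a single function name, so different columns can use different functions, but one column cannot receive two functions in the same dictionary. A minimal sketch, assuming a hypothetical helper `build_agg_spec` that skips the grouping key and accepts per-column overrides:

```python
def build_agg_spec(columns, group_cols, default="sum", overrides=None):
    """Map every non-key column to the name of an aggregate function.

    `overrides` (hypothetical parameter) replaces the default function
    for selected columns; its entries are applied as given.
    """
    skip = set(group_cols)
    spec = {c: default for c in columns if c not in skip}
    spec.update(overrides or {})
    return spec


# Stand-in for df.columns from the example above.
exprs = build_agg_spec(["col1", "col2", "col3"], ["col1"],
                       overrides={"col3": "max"})
print(exprs)  # {'col2': 'sum', 'col3': 'max'}
```

The resulting dictionary would be passed on unchanged, as in `df.groupBy("col1").agg(exprs)`.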
      

    Finally you can use varargs:

    • Python

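      # Import the module under an alias so that Python's built-in min is not shadowed.
      from pyspark.sql import functions as F
      
      # One min() expression per column, unpacked into agg() as varargs.
      exprs = [F.min(x) for x in df.columns]
      df.groupBy("col1").agg(*exprs).show()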
      
    • Scala

      import org.apache.spark.sql.functions.sum
      
      val exprs = df.columns.map(sum(_))
      df.groupBy($"col1").agg(exprs.head, exprs.tail: _*)
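
The varargs form also allows several functions to be applied to the same column and the results to be given readable names. The sketch below builds SQL expression strings in plain Python; `aliased_exprs` and the `func_col` naming scheme are assumptions made for illustration, and the column names are assumed to be plain identifiers that need no quoting.

```python
def aliased_exprs(columns, funcs):
    """Build 'func(col) AS func_col' strings for every column/function pair."""
    return [f"{func}({col}) AS {func}_{col}"
            for col in columns
            for func in funcs]


exprs = aliased_exprs(["col2", "col3"], ["min", "max"])
for e in exprs:
    print(e)
# min(col2) AS min_col2
# max(col2) AS max_col2
# min(col3) AS min_col3
# max(col3) AS max_col3
```

In PySpark each string can be turned into a column expression with `pyspark.sql.functions.expr`, so a call along the lines of `df.groupBy("col1").agg(*[F.expr(e) for e in exprs])` (with `from pyspark.sql import functions as F`) should produce one named result column per pair.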
      

    There are some other ways to achieve a similar effect, but these should be more than enough most of the time.

    See also:

    • Multiple Aggregate operations on the same column of a spark dataframe
