How to pivot Spark DataFrame?

闹比i 2020-11-21 06:43

I am starting to use Spark DataFrames and I need to pivot the data to create multiple columns out of one column with multiple rows. There is built-in functionality

  •  失恋的感觉
    2020-11-21 07:09

    The built-in Spark pivot function is inefficient. The implementation below works on Spark 2.4+. The idea is to aggregate the key/value pairs of each group into a map and then extract the map values as columns. The only limitation is that it does not handle aggregate functions in the pivoted columns, only column expressions.

    On an 8M table, these functions run in 3 seconds, versus 40 minutes for the built-in Spark version:

    # PySpark version
    from pyspark.sql.functions import col, collect_list, expr, map_from_entries, struct

    # Pass an optional list of levels (key values) to avoid computing them from the data.
    def pivot(df, group_by, key, aggFunction, levels=None):
        if not levels:
            levels = [row[key] for row in df.filter(col(key).isNotNull()).select(key).distinct().collect()]
        return (df.filter(col(key).isin(*levels))
                  .groupBy(group_by)
                  .agg(map_from_entries(collect_list(struct(key, expr(aggFunction)))).alias("group_map"))
                  .select([group_by] + ["group_map." + l for l in levels]))

    # Usage
    pivot(df, "id", "key", "value")
    pivot(df, "id", "key", "array(value)")
    // Scala version
    import org.apache.spark.sql.{Column, DataFrame}
    import org.apache.spark.sql.functions.{col, collect_list, expr, map_from_entries, struct}

    // Pass an optional list of levels (key values) to avoid computing them from the data.
    def pivot(df: DataFrame, groupBy: Column, key: Column, aggFunct: String, _levels: List[String] = Nil): DataFrame = {
      val levels =
        if (_levels.isEmpty) df.filter(key.isNotNull).select(key).distinct().collect().map(row => row.getString(0)).toList
        else _levels
      df.filter(key.isInCollection(levels))
        .groupBy(groupBy)
        .agg(map_from_entries(collect_list(struct(key, expr(aggFunct)))).alias("group_map"))
        .select(groupBy.toString, levels.map(f => "group_map." + f): _*)
    }

    // Usage:
    pivot(df, col("id"), col("key"), "value")
    pivot(df, col("id"), col("key"), "array(value)")
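
To make the "aggregate a map, then extract the values as columns" idea concrete, here is a minimal plain-Python sketch of the same steps on an in-memory list of rows. It is not Spark code and not part of the answer's implementation: `pivot_rows` and the sample data are hypothetical names used only for illustration, and it assumes each (id, key) pair appears at most once.

```python
from collections import defaultdict

def pivot_rows(rows, group_by, key, value, levels=None):
    """Plain-Python analogue of the map-based pivot (illustration only, not Spark)."""
    if not levels:
        # Discover the distinct non-null key values, keeping first-seen order.
        levels = list(dict.fromkeys(r[key] for r in rows if r[key] is not None))
    # Step 1: keep only rows whose key is a requested level, then collect
    # each group's (key, value) entries into one map.
    group_maps = defaultdict(dict)
    for r in rows:
        if r[key] in levels:
            group_maps[r[group_by]][r[key]] = r[value]
    # Step 2: read each level out of the map as its own column (None if absent).
    return [
        {group_by: g, **{lvl: m.get(lvl) for lvl in levels}}
        for g, m in group_maps.items()
    ]

# Hypothetical sample data in long format.
rows = [
    {"id": 1, "key": "a", "value": 10},
    {"id": 1, "key": "b", "value": 20},
    {"id": 2, "key": "a", "value": 30},
]
result = pivot_rows(rows, "id", "key", "value")
print(result)
# [{'id': 1, 'a': 10, 'b': 20}, {'id': 2, 'a': 30, 'b': None}]
```

As in the Spark functions above, rows are filtered to the requested levels before grouping, so a group that has none of those levels does not appear in the result at all.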
