How to calculate group-based quantiles?


I'm using spark-sql-2.4.1v, and I'm trying to find quantiles, i.e. percentile 0, percentile 25, etc., on each column of my given data.

My dataframe df:

1 Answer
  • 2020-12-07 07:05

    A possible solution could be:

    scala> input.show
    +---+---------+-----------+----------+-----------+-----+
    | id|     date|    revenue|con_dist_1| con_dist_2|state|
    +---+---------+-----------+----------+-----------+-----+
    | 10|1/15/2018|0.010680705|         6|0.019875458|   TX|
    | 10|1/15/2018|0.006628853|         4|0.816039063|   AZ|
    | 10|1/15/2018| 0.01378215|         4|0.082049528|   TX|
    | 10|1/15/2018|0.010680705|         6|0.019875458|   TX|
    | 10|1/15/2018|0.006628853|         4|0.816039063|   AZ|
    +---+---------+-----------+----------+-----------+-----+
    
    scala> val df1 = input.groupBy("state").agg(collect_list("con_dist_1").as("combined_1"), collect_list("con_dist_2").as("combined_2"))
    df1: org.apache.spark.sql.DataFrame = [state: string, combined_1: array<int> ... 1 more field]
    
    scala> df1.show
    +-----+----------+--------------------+                                         
    |state|combined_1|          combined_2|
    +-----+----------+--------------------+
    |   AZ|    [4, 4]|[0.816039063, 0.8...|
    |   TX| [6, 4, 6]|[0.019875458, 0.0...|
    +-----+----------+--------------------+
    
    scala> df1.
         | withColumn("comb1_Q1", sort_array($"combined_1")(((size($"combined_1")-1)*0.25).cast("int"))).
         | withColumn("comb1_Q2", sort_array($"combined_1")(((size($"combined_1")-1)*0.5).cast("int"))).
         | withColumn("comb1_Q3", sort_array($"combined_1")(((size($"combined_1")-1)*0.75).cast("int"))).
         | withColumn("comb_2_Q1", sort_array($"combined_2")(((size($"combined_2")-1)*0.25).cast("int"))).
         | withColumn("comb_2_Q2", sort_array($"combined_2")(((size($"combined_2")-1)*0.5).cast("int"))).
         | withColumn("comb_2_Q3", sort_array($"combined_2")(((size($"combined_2")-1)*0.75).cast("int"))).
         | show
    +-----+----------+--------------------+--------+--------+--------+-----------+-----------+-----------+
    |state|combined_1|          combined_2|comb1_Q1|comb1_Q2|comb1_Q3|  comb_2_Q1|  comb_2_Q2|  comb_2_Q3|
    +-----+----------+--------------------+--------+--------+--------+-----------+-----------+-----------+
    |   AZ|    [4, 4]|[0.816039063, 0.8...|       4|       4|       4|0.816039063|0.816039063|0.816039063|
    |   TX| [6, 4, 6]|[0.019875458, 0.0...|       4|       6|       6|0.019875458|0.019875458|0.019875458|
    +-----+----------+--------------------+--------+--------+--------+-----------+-----------+-----------+
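
    The six withColumn calls above all repeat the same nearest-rank pattern: sort the collected array and pick the element at index floor((n - 1) * p). They can be factored into a small helper; a minimal sketch (the quantileAt name and the foldLeft loop are mine, not part of the original session):

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.functions.{size, sort_array}

    // Nearest-rank quantile over a collected array column:
    // sort the array, then index it at floor((n - 1) * p).
    // $-interpolation assumes spark.implicits._, already in scope in the shell.
    def quantileAt(arr: Column, p: Double): Column =
      sort_array(arr)(((size(arr) - 1) * p).cast("int"))

    val withQuantiles = Seq(0.25 -> "Q1", 0.5 -> "Q2", 0.75 -> "Q3")
      .foldLeft(df1) { case (df, (p, label)) =>
        df.withColumn(s"comb1_$label", quantileAt($"combined_1", p))
          .withColumn(s"comb2_$label", quantileAt($"combined_2", p))
      }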
    

    EDIT

    I don't think we can achieve this using the approxQuantile method, because you want the quantiles per state: that requires grouping by the state column and aggregating the con_dist columns, and approxQuantile expects a plain numeric column of integers or floats, not an array type.
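
    If you only have a handful of states, DataFrameStatFunctions.approxQuantile can still be applied per group by filtering one state at a time on the driver. This workaround is mine rather than part of the answer above, and it launches a separate job per state, so it scales poorly with many groups:

    // One filtered approxQuantile call per state; relativeError = 0.0
    // requests exact quantiles. Shown only to illustrate the limitation.
    val states = input.select("state").distinct.as[String].collect()

    val quantilesByState = states.map { s =>
      s -> input.filter($"state" === s)
             .stat.approxQuantile(Array("con_dist_1", "con_dist_2"),
                                  Array(0.25, 0.5, 0.75), 0.0)
    }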

    Another solution is to use Spark SQL, as shown below:

    scala> input.show
    +---+---------+-----------+----------+-----------+-----+
    | id|     date|    revenue|con_dist_1| con_dist_2|state|
    +---+---------+-----------+----------+-----------+-----+
    | 10|1/15/2018|0.010680705|         6|0.019875458|   TX|
    | 10|1/15/2018|0.006628853|         4|0.816039063|   AZ|
    | 10|1/15/2018| 0.01378215|         4|0.082049528|   TX|
    | 10|1/15/2018|0.010680705|         6|0.019875458|   TX|
    | 10|1/15/2018|0.006628853|         4|0.816039063|   AZ|
    +---+---------+-----------+----------+-----------+-----+
    
    
    scala> input.createOrReplaceTempView("input")
    
    scala> :paste
    // Entering paste mode (ctrl-D to finish)
    
    val query = "select state, percentile_approx(con_dist_1,0.25) as col1_quantile_1, " +
      "percentile_approx(con_dist_1,0.5) as col1_quantile_2," +
      "percentile_approx(con_dist_1,0.75) as col1_quantile_3, " +
      "percentile_approx(con_dist_2,0.25) as col2_quantile_1,"+
      "percentile_approx(con_dist_2,0.5) as col2_quantile_2," +
      "percentile_approx(con_dist_2,0.75) as col2_quantile_3 " +
      "from input group by state"
    
    // Exiting paste mode, now interpreting.
    
    query: String = select state, percentile_approx(con_dist_1,0.25) as col1_quantile_1, percentile_approx(con_dist_1,0.5) as col1_quantile_2,percentile_approx(con_dist_1,0.75) as col1_quantile_3, percentile_approx(con_dist_2,0.25) as col2_quantile_1,percentile_approx(con_dist_2,0.5) as col2_quantile_2,percentile_approx(con_dist_2,0.75) as col2_quantile_3 from input group by state
    
    scala> val df2 = spark.sql(query)
    df2: org.apache.spark.sql.DataFrame = [state: string, col1_quantile_1: int ... 5 more fields]
    
    scala> df2.show
    +-----+---------------+---------------+---------------+---------------+---------------+---------------+
    |state|col1_quantile_1|col1_quantile_2|col1_quantile_3|col2_quantile_1|col2_quantile_2|col2_quantile_3|
    +-----+---------------+---------------+---------------+---------------+---------------+---------------+
    |   AZ|              4|              4|              4|    0.816039063|    0.816039063|    0.816039063|
    |   TX|              4|              6|              6|    0.019875458|    0.019875458|    0.082049528|
    +-----+---------------+---------------+---------------+---------------+---------------+---------------+
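
    The same aggregation can also be expressed in the DataFrame API with expr, avoiding the temp view; a sketch of the equivalent call (not part of the original session):

    import org.apache.spark.sql.functions.expr

    // percentile_approx is a built-in SQL aggregate, so it can be embedded
    // in agg via expr even though Spark 2.4 has no dedicated Scala function
    // for it.
    val df3 = input.groupBy("state").agg(
      expr("percentile_approx(con_dist_1, 0.25)").as("col1_quantile_1"),
      expr("percentile_approx(con_dist_1, 0.5)").as("col1_quantile_2"),
      expr("percentile_approx(con_dist_1, 0.75)").as("col1_quantile_3"),
      expr("percentile_approx(con_dist_2, 0.25)").as("col2_quantile_1"),
      expr("percentile_approx(con_dist_2, 0.5)").as("col2_quantile_2"),
      expr("percentile_approx(con_dist_2, 0.75)").as("col2_quantile_3")
    )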
    

    Let me know if it helps!!
