How to calculate rolling sum with varying window sizes in PySpark


I have a Spark dataframe that contains sales prediction data for some products in some stores over a time period. How do I calculate the rolling sum of Prediction over a window whose size varies per row, given by the N column (i.e. for each row, sum the predictions of the next N days, including the current one)?
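
For context, here is a minimal sketch of sample data consistent with the outputs shown in the answers below (the DataFrame name predictions and the string-typed Date column are assumptions):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # hypothetical sample data, reconstructed from the outputs in the answers below
    predictions = spark.createDataFrame(
        [
            (1, 100, "2019-07-01", 0.92, 2),
            (1, 100, "2019-07-02", 0.62, 2),
            (1, 100, "2019-07-03", 0.89, 2),
            (1, 100, "2019-07-04", 0.57, 2),
            (2, 200, "2019-07-01", 1.39, 3),
            (2, 200, "2019-07-02", 1.22, 3),
            (2, 200, "2019-07-03", 1.33, 3),
            (2, 200, "2019-07-04", 1.61, 3),
        ],
        ["ProductId", "StoreId", "Date", "Prediction", "N"],
    )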

2 Answers
  • 2021-02-10 04:39

    It might not be the most efficient approach, but you can collect the distinct values of the "N" column and loop over them as below.

    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.expressions.Window

    val arr = df.select("N").distinct.collect

    // one pass per distinct window size; order by Date so the window follows time order
    for (n <- arr) {
      val size = n.get(0).toString.toLong
      df.filter(col("N") === n.get(0))
        .withColumn("RollingSum", sum(col("Prediction"))
          .over(Window.partitionBy("N").orderBy("Date").rowsBetween(Window.currentRow, size - 1)))
        .show
    }

    This will give you output like:

    +---------+-------+----------+----------+---+------------------+
    |ProductId|StoreId|      Date|Prediction|  N|        RollingSum|
    +---------+-------+----------+----------+---+------------------+
    |        2|    200|2019-07-01|      1.39|  3|              3.94|
    |        2|    200|2019-07-02|      1.22|  3|              4.16|
    |        2|    200|2019-07-03|      1.33|  3|2.9400000000000004|
    |        2|    200|2019-07-04|      1.61|  3|              1.61|
    +---------+-------+----------+----------+---+------------------+
    
    +---------+-------+----------+----------+---+----------+
    |ProductId|StoreId|      Date|Prediction|  N|RollingSum|
    +---------+-------+----------+----------+---+----------+
    |        1|    100|2019-07-01|      0.92|  2|      1.54|
    |        1|    100|2019-07-02|      0.62|  2|      1.51|
    |        1|    100|2019-07-03|      0.89|  2|      1.46|
    |        1|    100|2019-07-04|      0.57|  2|      0.57|
    +---------+-------+----------+----------+---+----------+
    

    Then you can union all of the per-N dataframes produced in the loop, as sketched below.
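
    For reference, the same loop-and-union idea in PySpark (the language of the question) might look like the following sketch; it assumes N is an integer column and that the per-N results are collected into a list before being unioned:

    from functools import reduce
    from pyspark.sql import functions as F
    from pyspark.sql.window import Window

    # hypothetical PySpark version of the loop above: one frame per window size, unioned at the end
    frames = []
    for (n,) in df.select("N").distinct().collect():
        w = Window.partitionBy("N").orderBy("Date").rowsBetween(Window.currentRow, int(n) - 1)
        frames.append(df.filter(F.col("N") == n).withColumn("RollingSum", F.sum("Prediction").over(w)))

    rolling_sums = reduce(lambda a, b: a.union(b), frames)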

  • 2021-02-10 04:44

    If you're using Spark 2.4+, you can use the higher-order array functions slice and aggregate to implement this efficiently without any UDFs:

    from pyspark.sql import functions as F
    from pyspark.sql.window import Window

    # collect the remaining predictions per product/store, then sum only the first N of them
    w = Window.partitionBy("ProductId", "StoreId").orderBy("Date").rowsBetween(Window.currentRow, Window.unboundedFollowing)
    summed_predictions = predictions \
        .withColumn("summed", F.collect_list("Prediction").over(w)) \
        .withColumn("summed", F.expr("aggregate(slice(summed, 1, N), cast(0 as double), (acc, d) -> acc + d)"))
    
    summed_predictions.show()
    +---------+-------+-------------------+----------+---+------------------+
    |ProductId|StoreId|               Date|Prediction|  N|            summed|
    +---------+-------+-------------------+----------+---+------------------+
    |        1|    100|2019-07-01 00:00:00|      0.92|  2|              1.54|
    |        1|    100|2019-07-02 00:00:00|      0.62|  2|              1.51|
    |        1|    100|2019-07-03 00:00:00|      0.89|  2|              1.46|
    |        1|    100|2019-07-04 00:00:00|      0.57|  2|              0.57|
    |        2|    200|2019-07-01 00:00:00|      1.39|  3|              3.94|
    |        2|    200|2019-07-02 00:00:00|      1.22|  3|              4.16|
    |        2|    200|2019-07-03 00:00:00|      1.33|  3|2.9400000000000004|
    |        2|    200|2019-07-04 00:00:00|      1.61|  3|              1.61|
    +---------+-------+-------------------+----------+---+------------------+
    