I have a Spark DataFrame that contains sales prediction data for some products in some stores over a time period. How do I calculate the rolling sum of Prediction for a window whose size is given by the column N?
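If you're using Spark 2.4+, you can combine the array function slice with the higher-order function aggregate to implement this efficiently without any UDFs. The idea is to collect the current and all following predictions of each product/store pair into an array, keep only the first N elements with slice, and sum them with aggregate: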
from pyspark.sql import Window, functions as F

# Per product/store, ordered by date: the current row and every row after it
w = Window.partitionBy("ProductId", "StoreId").orderBy("Date").rowsBetween(Window.currentRow, Window.unboundedFollowing)
summed_predictions = predictions\
    .withColumn("summed", F.collect_list("Prediction").over(w))\
    .withColumn("summed", F.expr("aggregate(slice(summed,1,N), cast(0 as double), (acc,d) -> acc + d)"))
summed_predictions.show()
+---------+-------+-------------------+----------+---+------------------+
|ProductId|StoreId|               Date|Prediction|  N|            summed|
+---------+-------+-------------------+----------+---+------------------+
|        1|    100|2019-07-01 00:00:00|      0.92|  2|              1.54|
|        1|    100|2019-07-02 00:00:00|      0.62|  2|              1.51|
|        1|    100|2019-07-03 00:00:00|      0.89|  2|              1.46|
|        1|    100|2019-07-04 00:00:00|      0.57|  2|              0.57|
|        2|    200|2019-07-01 00:00:00|      1.39|  3|              3.94|
|        2|    200|2019-07-02 00:00:00|      1.22|  3|              4.16|
|        2|    200|2019-07-03 00:00:00|      1.33|  3|2.9400000000000004|
|        2|    200|2019-07-04 00:00:00|      1.61|  3|              1.61|
+---------+-------+-------------------+----------+---+------------------+
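To check the logic without a Spark session, the same three steps (collect the following values, slice to N, fold into a sum) can be mimicked in plain Python. This is only an illustrative sketch, not part of the Spark solution: the function name forward_rolling_sum and the tuple layout are made up for this example, and the rows are copied from the output above with the time part of Date dropped.

```python
from functools import reduce
from itertools import groupby
from operator import itemgetter

# Sample rows copied from the output above:
# (ProductId, StoreId, Date, Prediction, N)
rows = [
    (1, 100, "2019-07-01", 0.92, 2),
    (1, 100, "2019-07-02", 0.62, 2),
    (1, 100, "2019-07-03", 0.89, 2),
    (1, 100, "2019-07-04", 0.57, 2),
    (2, 200, "2019-07-01", 1.39, 3),
    (2, 200, "2019-07-02", 1.22, 3),
    (2, 200, "2019-07-03", 1.33, 3),
    (2, 200, "2019-07-04", 1.61, 3),
]


def forward_rolling_sum(rows):
    """Append to each row the sum of its own and the next N - 1 predictions.

    Hypothetical helper for illustration; it mirrors the Spark expression
    step by step rather than aiming for efficiency.
    """
    result = []
    # partitionBy("ProductId", "StoreId").orderBy("Date")
    ordered = sorted(rows, key=itemgetter(0, 1, 2))
    for _, group in groupby(ordered, key=itemgetter(0, 1)):
        group = list(group)
        for i, (product, store, date, prediction, n) in enumerate(group):
            # collect_list over currentRow .. unboundedFollowing
            following = [r[3] for r in group[i:]]
            # slice(summed, 1, N): keep at most the first N elements
            window = following[:n]
            # aggregate(..., cast(0 as double), (acc, d) -> acc + d): a left fold
            total = reduce(lambda acc, d: acc + d, window, 0.0)
            result.append((product, store, date, prediction, n, total))
    return result


for row in forward_rolling_sum(rows):
    print(row)
```

Near the end of each partition fewer than N rows remain, so the slice is simply shorter and the sum covers only the rows that exist, which matches the last rows of each product in the Spark output.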