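I have a Spark DataFrame that contains sales prediction data for some products in some stores over a time period. How do I calculate the rolling sum of Prediction over a forward-looking window of N rows (the current row plus the following N - 1 rows, with N taken from the N column) for each product and store?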
It might not be the most efficient approach, but you can collect the distinct values of the "N" column and loop over them:
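import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, sum}
val arr = df.select("N").distinct.collect
// one pass per distinct N, because rowsBetween takes fixed Long bounds, not a column
for (n <- arr) df.filter(col("N") === n.get(0))
  .withColumn("RollingSum", sum(col("Prediction"))
    .over(Window.partitionBy("ProductId", "StoreId").orderBy("Date").rowsBetween(Window.currentRow, n.get(0).toString.toLong - 1))).show()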
This gives output like the following (one table per distinct N):
+---------+-------+----------+----------+---+------------------+
|ProductId|StoreId| Date|Prediction| N| RollingSum|
+---------+-------+----------+----------+---+------------------+
| 2| 200|2019-07-01| 1.39| 3| 3.94|
| 2| 200|2019-07-02| 1.22| 3| 4.16|
| 2| 200|2019-07-03| 1.33| 3|2.9400000000000004|
| 2| 200|2019-07-04| 1.61| 3| 1.61|
+---------+-------+----------+----------+---+------------------+
+---------+-------+----------+----------+---+----------+
|ProductId|StoreId| Date|Prediction| N|RollingSum|
+---------+-------+----------+----------+---+----------+
| 1| 100|2019-07-01| 0.92| 2| 1.54|
| 1| 100|2019-07-02| 0.62| 2| 1.51|
| 1| 100|2019-07-03| 0.89| 2| 1.46|
| 1| 100|2019-07-04| 0.57| 2| 0.57|
+---------+-------+----------+----------+---+----------+
To get a single DataFrame, union the per-N DataFrames built inside the loop instead of showing each one.
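To make the intended window semantics concrete, here is a plain-Python sketch (no Spark needed) of the result: for each product and store, in date order, every row sums its own Prediction and the next N - 1 predictions, and the frame simply stops at the end of the partition. The helper name forward_rolling_sum is hypothetical, and the rows are copied from the tables above; this illustrates the computation only and is not a replacement for the Spark code.

```python
from collections import defaultdict

# Rows copied from the tables above: (ProductId, StoreId, Date, Prediction, N)
rows = [
    (1, 100, "2019-07-01", 0.92, 2),
    (1, 100, "2019-07-02", 0.62, 2),
    (1, 100, "2019-07-03", 0.89, 2),
    (1, 100, "2019-07-04", 0.57, 2),
    (2, 200, "2019-07-01", 1.39, 3),
    (2, 200, "2019-07-02", 1.22, 3),
    (2, 200, "2019-07-03", 1.33, 3),
    (2, 200, "2019-07-04", 1.61, 3),
]


def forward_rolling_sum(rows):
    """Hypothetical helper mimicking sum(Prediction) over a frame of
    rowsBetween(currentRow, N - 1), partitioned by (ProductId, StoreId)
    and ordered by Date."""
    groups = defaultdict(list)
    for row in rows:
        groups[(row[0], row[1])].append(row)
    result = []
    for key in sorted(groups):
        group = sorted(groups[key], key=lambda r: r[2])  # order by Date (ISO strings sort correctly)
        for i, (pid, sid, date, _pred, n) in enumerate(group):
            # The frame is the current row plus up to N - 1 following rows;
            # near the end of the partition fewer rows are available.
            frame = group[i:i + n]
            result.append((pid, sid, date, sum(r[3] for r in frame)))
    return result


for pid, sid, date, rolling in forward_rolling_sum(rows):
    print(pid, sid, date, round(rolling, 2))
```

Rounding is only for display; the raw sums can carry floating-point noise, just as Spark prints 2.9400000000000004 in the tables.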
If you're using Spark 2.4+, you can use the higher-order array functions slice and aggregate to implement this efficiently, without any UDFs and without looping over the distinct N values:
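from pyspark.sql import functions as F
from pyspark.sql.window import Window
# frame: the current row plus all later rows of the same product/store, in date order
w = Window.partitionBy("ProductId", "StoreId").orderBy("Date").rowsBetween(Window.currentRow, Window.unboundedFollowing)
summed_predictions = (predictions
    .withColumn("summed", F.collect_list("Prediction").over(w))  # array of this and all later predictions
    .withColumn("summed", F.expr("aggregate(slice(summed, 1, N), cast(0 as double), (acc, d) -> acc + d)")))
summed_predictions.show()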
+---------+-------+-------------------+----------+---+------------------+
|ProductId|StoreId| Date|Prediction| N| summed|
+---------+-------+-------------------+----------+---+------------------+
| 1| 100|2019-07-01 00:00:00| 0.92| 2| 1.54|
| 1| 100|2019-07-02 00:00:00| 0.62| 2| 1.51|
| 1| 100|2019-07-03 00:00:00| 0.89| 2| 1.46|
| 1| 100|2019-07-04 00:00:00| 0.57| 2| 0.57|
| 2| 200|2019-07-01 00:00:00| 1.39| 3| 3.94|
| 2| 200|2019-07-02 00:00:00| 1.22| 3| 4.16|
| 2| 200|2019-07-03 00:00:00| 1.33| 3|2.9400000000000004|
| 2| 200|2019-07-04 00:00:00| 1.61| 3| 1.61|
+---------+-------+-------------------+----------+---+------------------+
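The two withColumn steps can be mirrored in plain Python to show what each SQL function contributes: collect_list over the unbounded-following frame gives each row a list of its own and all later predictions, slice(summed, 1, N) keeps the first N elements (slice uses a 1-based start), and aggregate folds them into a sum starting from 0.0. The helper name sliced_aggregate is hypothetical, and the input is a single product/store partition already ordered by Date, with values taken from the table above.

```python
from functools import reduce

# One partition (ProductId 2, StoreId 200), already ordered by Date, with N = 3
predictions = [1.39, 1.22, 1.33, 1.61]
n = 3


def sliced_aggregate(predictions, n):
    """Hypothetical plain-Python analogue of collect_list + slice + aggregate."""
    sums = []
    for i in range(len(predictions)):
        collected = predictions[i:]  # collect_list over rowsBetween(currentRow, unboundedFollowing)
        sliced = collected[:n]       # slice(summed, 1, N): start at element 1, keep N elements
        # aggregate(..., cast(0 as double), (acc, d) -> acc + d)
        sums.append(reduce(lambda acc, d: acc + d, sliced, 0.0))
    return sums


print([round(s, 2) for s in sliced_aggregate(predictions, n)])  # prints [3.94, 4.16, 2.94, 1.61]
```

One design consequence worth noting: because the frame runs to the end of the partition, each row's collected array holds every later row, so the total array size grows quadratically with the number of rows per product/store, even though only the first N elements are used.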