We have a DataFrame like the one below:
+------+--------------------+
| Flag | value|
+------+--------------------+
|1 |5 |
|1 |4
It's probably best to do this with a pandas_udf here.
from pyspark.sql.functions import pandas_udf, PandasUDFType
# Sample input: one group (flag == 1) of 13 rows, with ``id`` giving the
# order in which the values must be accumulated.
pdf = pd.DataFrame(
    {
        'flag': [1] * 13,
        'id': range(13),
        'value': [5, 4, 3, 5, 6, 4, 7, 5, 2, 3, 2, 6, 9],
    }
)
# Add a placeholder ``cumsum`` column (a float literal) up front, so the
# DataFrame's schema already contains the column the UDF will fill in.
df = spark.createDataFrame(pdf).withColumn('cumsum', F.lit(math.inf))
@pandas_udf(df.schema, PandasUDFType.GROUPED_MAP)
def _calc_cumsum(pdf):
    """Fill ``cumsum`` with a running total of ``value`` that restarts at 20.

    The rows of one group are sorted by ``id`` (in place), then ``value`` is
    accumulated row by row. As soon as the running total reaches 20 or more,
    the *next* row starts a fresh total from its own value.

    Args:
        pdf: pandas DataFrame holding one group; must contain the columns
            ``id`` and ``value``.

    Returns:
        The same DataFrame, sorted by ``id``, with the ``cumsum`` column set.
    """
    pdf.sort_values(by=['id'], inplace=True, ascending=True)

    cumsums = []
    running = 0
    # True means "the next row starts a new run"; this covers the first row.
    restart = True
    for v in pdf['value'].values:
        running = v if restart else running + v
        cumsums.append(running)
        # Evaluate the threshold on every row, including the first one, so a
        # leading value that is already >= 20 still restarts the total.
        restart = running >= 20

    pdf['cumsum'] = cumsums
    return pdf
# Run the grouped-map UDF once per ``flag`` group, then print the result.
grouped = df.groupby('flag')
df = grouped.apply(_calc_cumsum)
df.show()
The results:
+----+---+-----+------+
|flag| id|value|cumsum|
+----+---+-----+------+
| 1| 0| 5| 5.0|
| 1| 1| 4| 9.0|
| 1| 2| 3| 12.0|
| 1| 3| 5| 17.0|
| 1| 4| 6| 23.0|
| 1| 5| 4| 4.0|
| 1| 6| 7| 11.0|
| 1| 7| 5| 16.0|
| 1| 8| 2| 18.0|
| 1| 9| 3| 21.0|
| 1| 10| 2| 2.0|
| 1| 11| 6| 8.0|
| 1| 12| 9| 17.0|
+----+---+-----+------+
The only way to do this without a UDF
is to use higher-order functions:
Click here to see step by step on Databricks (valid until 30/06/2021)
from pyspark.sql import Row
from pyspark.sql.window import Window
import pyspark.sql.functions as f
# Sample input: a single group (Flag == 1); the list order is the order in
# which the values are meant to be accumulated.
df = spark.createDataFrame(
    [Row(Flag=1, value=v) for v in [5, 4, 3, 5, 6, 4, 7, 5, 2, 3, 2, 6, 9]]
)

# The data has no ordering column, so capture the current row order in one.
# NOTE(review): this assumes the DataFrame still reflects the intended input
# order at this point; prefer a real ordering column when one exists.
df = df.withColumn('row_order', f.monotonically_increasing_id())

# Running frame: every row from the start of its partition up to itself, in
# an explicit order. Ordering by the partition key alone (all ties) would
# leave the contents of each prefix up to the engine.
running = (
    Window.partitionBy('flag')
    .orderBy('row_order')
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)

# Prefix of values seen so far. The cast to array<int> keeps the element type
# aligned with the integer literal 0 used as the fold's start value.
df = df.withColumn(
    'sliced_array',
    f.collect_list('value').over(running).cast('array<int>'),
)

# Fold the prefix: keep adding while the total is below 20; once it has
# reached 20 or more, the next value starts a new total.
# AGGREGATE is used in place of REDUCE so the expression also runs on Spark
# versions that do not provide the REDUCE alias.
expr = "AGGREGATE(sliced_array, 0, (c, n) -> IF(c < 20, c + n, n))"
df = df.select('flag', 'value', f.expr(expr).alias('cumsum'))
df.show()
Output:
+----+-----+------+
|flag|value|cumsum|
+----+-----+------+
| 1| 5| 5|
| 1| 4| 9|
| 1| 3| 12|
| 1| 5| 17|
| 1| 6| 23|
| 1| 4| 4|
| 1| 7| 11|
| 1| 5| 16|
| 1| 2| 18|
| 1| 3| 21|
| 1| 2| 2|
| 1| 6| 8|
| 1| 9| 17|
+----+-----+------+