We have a dataframe like the one below:
+------+-------+
| Flag | value |
+------+-------+
|    1 |     5 |
|    1 |     4 |
|  ... |   ... |
+------+-------+
It's probably best to do this with a pandas_udf here.
import math
import pandas as pd
from pyspark.sql import functions as F
from pyspark.sql.functions import pandas_udf, PandasUDFType

pdf = pd.DataFrame({'flag': [1] * 13, 'id': range(13), 'value': [5, 4, 3, 5, 6, 4, 7, 5, 2, 3, 2, 6, 9]})
df = spark.createDataFrame(pdf)

# Add a placeholder cumsum column so df.schema already contains the output column.
df = df.withColumn('cumsum', F.lit(math.inf))
@pandas_udf(df.schema, PandasUDFType.GROUPED_MAP)
def _calc_cumsum(pdf):
    # Process each group in id order so the running sum follows the row order.
    pdf.sort_values(by=['id'], inplace=True, ascending=True)
    cumsums = []
    prev = None
    reset = False
    for v in pdf['value'].values:
        if prev is None:
            # First row of the group starts the running sum.
            cumsums.append(v)
            prev = v
        else:
            # Restart the sum on the row after the threshold was reached.
            prev = prev + v if not reset else v
            cumsums.append(prev)
        # Flag a reset once the running sum reaches 20.
        reset = prev >= 20
    pdf['cumsum'] = cumsums
    return pdf
# Apply the grouped-map UDF to each flag group.
df = df.groupby('flag').apply(_calc_cumsum)
df.show()
The results:
+----+---+-----+------+
|flag| id|value|cumsum|
+----+---+-----+------+
| 1| 0| 5| 5.0|
| 1| 1| 4| 9.0|
| 1| 2| 3| 12.0|
| 1| 3| 5| 17.0|
| 1| 4| 6| 23.0|
| 1| 5| 4| 4.0|
| 1| 6| 7| 11.0|
| 1| 7| 5| 16.0|
| 1| 8| 2| 18.0|
| 1| 9| 3| 21.0|
| 1| 10| 2| 2.0|
| 1| 11| 6| 8.0|
| 1| 12| 9| 17.0|
+----+---+-----+------+
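
On Spark 3.x, GROUPED_MAP pandas UDFs are deprecated in favor of DataFrame.groupby(...).applyInPandas. A minimal sketch of the same logic in that style, assuming the same df as above (the function name _calc_cumsum_plain and the reuse of df.schema are just illustrative choices):

def _calc_cumsum_plain(pdf):
    # Same grouped-map logic as above, written as a plain function (no decorator).
    pdf = pdf.sort_values(by=['id'])
    cumsums = []
    prev = None
    reset = False
    for v in pdf['value'].values:
        if prev is None:
            prev = v
        else:
            prev = prev + v if not reset else v
        cumsums.append(prev)
        reset = prev >= 20
    pdf['cumsum'] = cumsums
    return pdf

# The output schema is passed to applyInPandas instead of the decorator.
df = df.groupby('flag').applyInPandas(_calc_cumsum_plain, schema=df.schema)
df.show()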