PySpark: Cumulative Sum with reset condition

盖世英雄少女心 2021-02-09 04:49

We have a dataframe like the one below:

+------+--------------------+
| Flag |               value|
+------+--------------------+
|1     |5                   |
|1     |4                   |
|1     |3                   |
|1     |5                   |
|1     |6                   |
|1     |4                   |
|1     |7                   |
|1     |5                   |
|1     |2                   |
|1     |3                   |
|1     |2                   |
|1     |6                   |
|1     |9                   |
+------+--------------------+

We want a cumulative sum of value that resets once the running total reaches 20, starting a new sum from the next row.

2 Answers
  •  醉酒成梦
    2021-02-09 05:33

    It's probably best to do this with a pandas_udf here.

    import math

    import pandas as pd
    from pyspark.sql import functions as F
    from pyspark.sql.functions import pandas_udf, PandasUDFType

    pdf = pd.DataFrame({'flag': [1]*13, 'id': range(13),
                        'value': [5, 4, 3, 5, 6, 4, 7, 5, 2, 3, 2, 6, 9]})
    df = spark.createDataFrame(pdf)
    # Placeholder column so df.schema already contains `cumsum` when it is
    # passed to the UDF as the output schema.
    df = df.withColumn('cumsum', F.lit(math.inf))

    @pandas_udf(df.schema, PandasUDFType.GROUPED_MAP)
    def _calc_cumsum(pdf):
        pdf.sort_values(by=['id'], ascending=True, inplace=True)
        cumsums = []
        running = 0
        reset = False
        for v in pdf['value'].values:
            # Start a new sum from the current value if the previous row
            # reached the threshold; otherwise keep accumulating.
            running = v if reset else running + v
            cumsums.append(running)
            reset = running >= 20
        pdf['cumsum'] = cumsums
        return pdf

    df = df.groupby('flag').apply(_calc_cumsum)
    df.show()
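
    Note that this can't be expressed with Spark's built-in window functions alone: because of the reset, each row's result depends on the previous row's result, so the whole group has to be processed sequentially, which is exactly what the pandas UDF does.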
    

    The results:

    +----+---+-----+------+
    |flag| id|value|cumsum|
    +----+---+-----+------+
    |   1|  0|    5|   5.0|
    |   1|  1|    4|   9.0|
    |   1|  2|    3|  12.0|
    |   1|  3|    5|  17.0|
    |   1|  4|    6|  23.0|
    |   1|  5|    4|   4.0|
    |   1|  6|    7|  11.0|
    |   1|  7|    5|  16.0|
    |   1|  8|    2|  18.0|
    |   1|  9|    3|  21.0|
    |   1| 10|    2|   2.0|
    |   1| 11|    6|   8.0|
    |   1| 12|    9|  17.0|
    +----+---+-----+------+
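
    Side note: GROUPED_MAP pandas UDFs are deprecated since Spark 3.0 in favour of groupby(...).applyInPandas. A minimal sketch of the same logic with that API, assuming the data and the threshold of 20 from the code above (no placeholder `cumsum` column is needed, since the output schema is passed explicitly):

    import pandas as pd

    def calc_cumsum(pdf: pd.DataFrame) -> pd.DataFrame:
        pdf = pdf.sort_values('id')
        running, reset, out = 0, False, []
        for v in pdf['value']:
            running = v if reset else running + v
            out.append(running)
            reset = running >= 20  # restart the sum on the next row
        return pdf.assign(cumsum=out)

    df2 = spark.createDataFrame(pdf)  # the original data, without the placeholder
    df2.groupby('flag').applyInPandas(
        calc_cumsum,
        schema='flag long, id long, value long, cumsum long').show()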
    
    
