Pyspark : Cumulative Sum with reset condition

盖世英雄少女心 · 2021-02-09 04:49

We have a dataframe like the one below:

+------+--------------------+
| Flag |               value|
+------+--------------------+
|1     |5                   |
|1     |4                   |
|1     |3                   |
|1     |5                   |
|1     |6                   |
|1     |4                   |
|1     |7                   |
|1     |5                   |
|1     |2                   |
|1     |3                   |
|1     |2                   |
|1     |6                   |
|1     |9                   |
+------+--------------------+

We need a cumulative sum of the value column that restarts from the next row once the running total reaches 20.

2 Answers
  • 2021-02-09 05:33

    It's probably best to do this with a pandas_udf here.

    import math
    import pandas as pd
    import pyspark.sql.functions as F
    from pyspark.sql.functions import pandas_udf, PandasUDFType
    
    pdf = pd.DataFrame({'flag': [1] * 13, 'id': range(13), 'value': [5, 4, 3, 5, 6, 4, 7, 5, 2, 3, 2, 6, 9]})
    df = spark.createDataFrame(pdf)
    # placeholder column so the output schema already contains `cumsum`
    df = df.withColumn('cumsum', F.lit(math.inf))
    
    @pandas_udf(df.schema, PandasUDFType.GROUPED_MAP)
    def _calc_cumsum(pdf):
        pdf.sort_values(by=['id'], inplace=True, ascending=True)
        cumsums = []
        prev = None
        reset = False
        for v in pdf['value'].values:
            if prev is None:
                # first row: start the running sum
                cumsums.append(v)
                prev = v
            else:
                # restart from the current value on the row after the total reached 20
                prev = prev + v if not reset else v
                cumsums.append(prev)
                reset = True if prev >= 20 else False
                
        pdf['cumsum'] = cumsums
        return pdf
    
    df = df.groupby('flag').apply(_calc_cumsum)
    df.show()
    

    The results:

    +----+---+-----+------+
    |flag| id|value|cumsum|
    +----+---+-----+------+
    |   1|  0|    5|   5.0|
    |   1|  1|    4|   9.0|
    |   1|  2|    3|  12.0|
    |   1|  3|    5|  17.0|
    |   1|  4|    6|  23.0|
    |   1|  5|    4|   4.0|
    |   1|  6|    7|  11.0|
    |   1|  7|    5|  16.0|
    |   1|  8|    2|  18.0|
    |   1|  9|    3|  21.0|
    |   1| 10|    2|   2.0|
    |   1| 11|    6|   8.0|
    |   1| 12|    9|  17.0|
    +----+---+-----+------+
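
    Note that on Spark 3.x the GROUPED_MAP pandas_udf form used above is deprecated in favour of applyInPandas. Below is a minimal sketch of the same reset-at-20 logic under that API, reusing the df built at the top of this answer (before the groupby().apply call); the function is now a plain Python function with no decorator.

    def _calc_cumsum(pdf):
        # same running sum as above: restart from the current value
        # on the row after the total reaches 20
        pdf = pdf.sort_values(by=['id'], ascending=True)
        cumsums, prev, reset = [], None, False
        for v in pdf['value'].values:
            prev = v if prev is None or reset else prev + v
            cumsums.append(prev)
            reset = prev >= 20
        pdf['cumsum'] = cumsums
        return pdf
    
    df = df.groupby('flag').applyInPandas(_calc_cumsum, schema=df.schema)
    df.show()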
    
    
  • The only way to do this without a UDF is to use higher-order functions:


    Spark >= 2.4.x

    from pyspark.sql import Row
    from pyspark.sql.window import Window
    import pyspark.sql.functions as f
    
    
    df = spark.createDataFrame(
        [Row(Flag=1, value=5), Row(Flag=1, value=4), Row(Flag=1, value=3), Row(Flag=1, value=5), Row(Flag=1, value=6),
         Row(Flag=1, value=4), Row(Flag=1, value=7), Row(Flag=1, value=5), Row(Flag=1, value=2), Row(Flag=1, value=3),
         Row(Flag=1, value=2), Row(Flag=1, value=6), Row(Flag=1, value=9)]
    )
    
    # window over the whole `flag` partition; with no natural ordering column,
    # row_number() here just follows the incoming row order
    window = Window.partitionBy('flag')
    df = df.withColumn('row_id', f.row_number().over(window.orderBy('flag')).cast('int'))
    df = df.withColumn('values', f.collect_list('value').over(window).cast('array<int>'))
    
    # keep the first `row_id` elements of the collected array for each row
    # (the TRANSFORM is just an identity pass over the slice)
    expr = "TRANSFORM(slice(values, 1, row_id), sliced_array -> sliced_array)"
    df = df.withColumn('sliced_array', f.expr(expr))
    
    # fold the sliced array, restarting the running sum once it has reached 20
    # (AGGREGATE is the Spark 2.4 name of this function; REDUCE is only an alias in newer releases)
    expr = "AGGREGATE(sliced_array, 0, (c, n) -> IF(c < 20, c + n, n))"
    df = df.select('flag', 'value', f.expr(expr).alias('cumsum'))
    
    df.show()
    

    Output:

    +----+-----+------+
    |flag|value|cumsum|
    +----+-----+------+
    |   1|    5|     5|
    |   1|    4|     9|
    |   1|    3|    12|
    |   1|    5|    17|
    |   1|    6|    23|
    |   1|    4|     4|
    |   1|    7|    11|
    |   1|    5|    16|
    |   1|    2|    18|
    |   1|    3|    21|
    |   1|    2|     2|
    |   1|    6|     8|
    |   1|    9|    17|
    +----+-----+------+
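
    For reference, from Spark 3.1 onwards the same fold can be expressed with the Python DataFrame API instead of SQL expression strings. A minimal sketch, assuming the `values` and `row_id` columns built above:

    # a minimal sketch, assuming the `values` and `row_id` columns built above
    cumsum = f.aggregate(
        f.slice(f.col('values'), f.lit(1), f.col('row_id')),   # prefix of the collected array
        f.lit(0),
        lambda acc, x: f.when(acc < 20, acc + x).otherwise(x)  # restart once the total reaches 20
    )
    df.select('flag', 'value', cumsum.alias('cumsum')).show()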
    