PySpark: Custom window function

时光取名叫无心 2021-02-10 08:12

I am currently trying to extract series of consecutive occurrences in a PySpark DataFrame and order/rank them as shown below (for convenience I have ordered the initial dataframe …

2 answers
  • 2021-02-10 09:02

    This is a common pattern (often called "gaps and islands") and can be expressed with window functions in a few steps. First, import the required functions:

    from pyspark.sql.functions import sum as sum_, lag, col, coalesce, lit
    from pyspark.sql.window import Window
    

    Next, define a window partitioned by user and ordered by time:

    w = Window.partitionBy("user_id").orderBy("timestamp")
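
    Conceptually, this window specification means every expression evaluated .over(w) sees only the rows with the current row's user_id, visited in timestamp order. A minimal plain-Python sketch of that grouping (no Spark involved; the helper name partition_and_order is hypothetical and only mimics what the window describes):

```python
from collections import defaultdict


def partition_and_order(rows):
    """Group (user_id, timestamp, action) tuples by user_id and sort each
    group by timestamp, mimicking Window.partitionBy(...).orderBy(...)."""
    partitions = defaultdict(list)
    for user_id, timestamp, action in rows:
        partitions[user_id].append((timestamp, action))
    # Sort each partition by timestamp, as orderBy("timestamp") would
    return {user: sorted(items) for user, items in partitions.items()}


rows = [
    (217498, 100000025, 'A'), (854123, 100000007, 'A'),
    (217498, 100000001, 'A'), (854123, 100000005, 'A'),
]
print(partition_and_order(rows))
```

    Each partition is handled independently, which is why the run numbers computed later restart at 1 for every user.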
    

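    Mark the first row of each group of consecutive identical actions (1 where the action differs from the previous row or there is no previous row, 0 otherwise):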

    is_first = coalesce(
      (lag("actions", 1).over(w) != col("actions")).cast("bigint"),
      lit(1)
    )
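
    In plain terms, lag("actions", 1) returns the previous row's action within the partition (null on the first row), the != comparison cast to bigint gives 1 when the action changed and 0 otherwise, and coalesce(..., lit(1)) turns the null on the first row into 1. A plain-Python sketch of the same rule for one partition that is already sorted by timestamp (mark_first is a hypothetical name used only for illustration):

```python
def mark_first(actions):
    """Return 1 where a new run of identical actions starts, else 0."""
    flags = []
    for i, action in enumerate(actions):
        if i == 0:
            # lag() is null on the first row; coalesce(..., lit(1)) -> 1
            flags.append(1)
        else:
            # (previous != current).cast("bigint") -> 1 on change, 0 otherwise
            flags.append(int(actions[i - 1] != action))
    return flags


print(mark_first(['A', 'A', 'A', 'B', 'C', 'C', 'A', 'B', 'B']))
# [1, 0, 0, 1, 1, 0, 1, 1, 0]
```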
    

    Define the order of each group as a running sum of the is_first flags:

    order = sum_("is_first").over(w)
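
    Because the window has an orderBy, Spark's default frame runs from the start of the partition up to the current row, so this sum is a running total: it goes up by one at the start of every new run, numbering the runs 1, 2, 3, and so on. (That default is a RANGE frame, so rows tied on timestamp would share the same total.) A plain-Python sketch of the running sum with itertools.accumulate, using the flags for user 217498 (illustrative only, not Spark code):

```python
from itertools import accumulate


def run_order(is_first_flags):
    """Running sum of the is_first flags = run number of each row."""
    return list(accumulate(is_first_flags))


flags = [1, 0, 0, 1, 1, 0, 1, 1, 0]  # is_first for user 217498
print(run_order(flags))
# [1, 1, 1, 2, 3, 3, 4, 5, 5]
```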
    

    Finally, combine all the parts with an aggregation:

    (df
        .withColumn("is_first", is_first)   # 1 where a new run starts, else 0
        .withColumn("order", order)         # run number within each user_id
        .groupBy("user_id", "actions", "order")
        .count())                           # count = length of each run
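
    The final groupBy(...).count() collapses each run into one row and counts its members. Since order already identifies a run within a user, including actions in the key simply carries the action along. A plain-Python sketch of that aggregation with collections.Counter, assuming the rows already carry the order computed above (count_runs is a hypothetical name):

```python
from collections import Counter


def count_runs(rows):
    """rows: (user_id, action, order) tuples -> Counter of run lengths."""
    return Counter(rows)


rows = [
    (217498, 'A', 1), (217498, 'A', 1), (217498, 'A', 1),
    (217498, 'B', 2), (217498, 'C', 3), (217498, 'C', 3),
]
counts = count_runs(rows)
# Print runs in order of their run number
for (user_id, action, order), n in sorted(counts.items(), key=lambda kv: kv[0][2]):
    print(user_id, action, order, n)
```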
    

    If you define df as:

    df = sc.parallelize([
        (217498, 100000001, 'A'), (217498, 100000025, 'A'), (217498, 100000124, 'A'),
        (217498, 100000152, 'B'), (217498, 100000165, 'C'), (217498, 100000177, 'C'),
        (217498, 100000182, 'A'), (217498, 100000197, 'B'), (217498, 100000210, 'B'),
        (854123, 100000005, 'A'), (854123, 100000007, 'A')
    ]).toDF(["user_id", "timestamp", "actions"])
    

    and sort the result by user_id and order, you'll get:

    +-------+-------+-----+-----+ 
    |user_id|actions|order|count|
    +-------+-------+-----+-----+
    | 217498|      A|    1|    3|
    | 217498|      B|    2|    1|
    | 217498|      C|    3|    2|
    | 217498|      A|    4|    1|
    | 217498|      B|    5|    2|
    | 854123|      A|    1|    2|
    +-------+-------+-----+-----+
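
    To check the logic end to end without a Spark cluster, here is a plain-Python sketch that chains the same steps (partition and sort, flag run starts, running sum, count) on the sample rows above. It is only an illustration of the window semantics, not a replacement for the PySpark code; consecutive_runs is a hypothetical name.

```python
from collections import Counter, defaultdict
from itertools import accumulate


def consecutive_runs(rows):
    """rows: (user_id, timestamp, action) tuples.
    Returns (user_id, action, order, count) tuples sorted by user_id, order."""
    # partitionBy("user_id").orderBy("timestamp")
    partitions = defaultdict(list)
    for user_id, timestamp, action in rows:
        partitions[user_id].append((timestamp, action))

    counts = Counter()
    for user_id, items in partitions.items():
        actions = [action for _, action in sorted(items)]
        # is_first: 1 on the first row and wherever the action changes
        is_first = [1 if i == 0 or actions[i - 1] != a else 0
                    for i, a in enumerate(actions)]
        # order: running sum of is_first
        orders = accumulate(is_first)
        # groupBy("user_id", "actions", "order").count()
        counts.update((user_id, a, o) for a, o in zip(actions, orders))

    return sorted(((u, a, o, n) for (u, a, o), n in counts.items()),
                  key=lambda t: (t[0], t[2]))


rows = [
    (217498, 100000001, 'A'), (217498, 100000025, 'A'), (217498, 100000124, 'A'),
    (217498, 100000152, 'B'), (217498, 100000165, 'C'), (217498, 100000177, 'C'),
    (217498, 100000182, 'A'), (217498, 100000197, 'B'), (217498, 100000210, 'B'),
    (854123, 100000005, 'A'), (854123, 100000007, 'A'),
]
for row in consecutive_runs(rows):
    print(row)
```

    The printed tuples match the rows of the table above.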
    
  • 2021-02-10 09:04

    I'm afraid this is not possible with the standard DataFrame window functions, but you can still use the older RDD API's groupByKey() to achieve the transformation (df_ini below is the initial DataFrame, with user_id, timestamp and actions columns):

    >>> from itertools import groupby
    >>> 
    >>> def recalculate(records):
    ...     actions = [r.actions for r in sorted(records[1], key=lambda r: r.timestamp)]
    ...     groups = [list(g) for k, g in groupby(actions)]
    ...     return [(records[0], g[0], len(g), i+1) for i, g in enumerate(groups)]
    ... 
    >>> df_ini.rdd.map(lambda row: (row.user_id, row)) \
    ...     .groupByKey().flatMap(recalculate) \
    ...     .toDF(['user_id', 'actions', 'nf_of_occ', 'order']).show()
    +-------+-------+---------+-----+
    |user_id|actions|nf_of_occ|order|
    +-------+-------+---------+-----+
    | 217498|      A|        3|    1|
    | 217498|      B|        1|    2|
    | 217498|      C|        2|    3|
    | 217498|      A|        1|    4|
    | 217498|      B|        2|    5|
    | 854123|      A|        2|    1|
    +-------+-------+---------+-----+
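
    The key ingredient in recalculate is itertools.groupby, which merges only adjacent equal items, which is exactly the "consecutive occurrences" asked for. A small plain-Python sketch on the actions of user 217498 (already sorted by timestamp) shows the (action, nf_of_occ, order) triples it yields, next to a plain frequency count that would wrongly merge non-adjacent runs:

```python
from collections import Counter
from itertools import groupby

actions = ['A', 'A', 'A', 'B', 'C', 'C', 'A', 'B', 'B']  # user 217498

# groupby() merges only adjacent equal items, so each group is one run;
# enumerate() supplies the run number, as in recalculate()
runs = [(action, len(list(group)), i + 1)
        for i, (action, group) in enumerate(groupby(actions))]
print(runs)
# [('A', 3, 1), ('B', 1, 2), ('C', 2, 3), ('A', 1, 4), ('B', 2, 5)]

# By contrast, a frequency count merges non-adjacent occurrences of 'A'
totals = Counter(actions)
print(totals['A'])
# 4
```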
    