I am currently trying to extract series of consecutive occurrences in a PySpark dataframe and order/rank them as shown below (for convenience I have ordered the initial dataframe ...).
This is a pretty common pattern and can be expressed with window functions in a few steps. First, import the required functions:
from pyspark.sql.functions import sum as sum_, lag, col, coalesce, lit
from pyspark.sql.window import Window
Next, define a window:
w = Window.partitionBy("user_id").orderBy("timestamp")
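In plain terms, partitionBy("user_id") splits the rows by user and orderBy("timestamp") sorts each split, so lag and the running sum below only ever look at one user's history in time order. The following is a plain-Python model of just that part of the window, not Spark code; it assumes rows are (user_id, timestamp, actions) tuples, and partition_and_order is a hypothetical helper name.

```python
from collections import defaultdict


def partition_and_order(rows):
    """Model of Window.partitionBy("user_id").orderBy("timestamp"):
    one list of (timestamp, actions) per user, sorted by timestamp."""
    partitions = defaultdict(list)
    for user_id, timestamp, actions in rows:
        partitions[user_id].append((timestamp, actions))
    # Sort each partition by timestamp, as the window's orderBy does.
    return {user: sorted(items) for user, items in partitions.items()}


# Deliberately shuffled input: users interleaved, timestamps out of order.
rows = [
    (854123, 100000007, 'A'),
    (217498, 100000025, 'A'),
    (854123, 100000005, 'A'),
    (217498, 100000001, 'A'),
]
print(partition_and_order(rows))
```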
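Mark the first row of each run of consecutive identical actions. lag returns null on a user's first row, which makes the comparison null, so coalesce falls back to 1 and that row also starts a run:
is_first = coalesce(
    (lag("actions", 1).over(w) != col("actions")).cast("bigint"),
    lit(1)
)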
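To make the flag concrete, here is a plain-Python model (not Spark) of what is_first evaluates to within one partition that is already ordered by timestamp. It assumes the actions column contains no nulls; is_first_flags is a hypothetical helper name. None stands in for the null that lag produces on the first row.

```python
def is_first_flags(actions):
    """Model of coalesce((lag(actions) != actions).cast("bigint"), lit(1))
    over one partition already ordered by timestamp."""
    flags = []
    previous = None  # lag("actions", 1) is null on the partition's first row
    for current in actions:
        if previous is None:
            flags.append(1)  # null comparison, so coalesce falls back to 1
        else:
            flags.append(int(previous != current))  # boolean cast to 0/1
        previous = current
    return flags


print(is_first_flags(['A', 'A', 'A', 'B', 'C', 'C', 'A', 'B', 'B']))
# [1, 0, 0, 1, 1, 0, 1, 1, 0]
```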
Define the run number (order) as a running sum of is_first over the same window:
order = sum_("is_first").over(w)
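Because the window has an orderBy, sum_(...).over(w) uses the default frame from the start of the partition up to the current row, so with distinct timestamps it is a cumulative sum: every row in the same run shares one number, and the number grows by 1 at each change of action. A plain-Python model using itertools.accumulate on the flags of user 217498 (run_order is a hypothetical helper name):

```python
from itertools import accumulate


def run_order(flags):
    """Model of sum_("is_first").over(w): cumulative sum of the 0/1 flags,
    so all rows of one run of identical actions get the same number."""
    return list(accumulate(flags))


flags = [1, 0, 0, 1, 1, 0, 1, 1, 0]
print(run_order(flags))  # [1, 1, 1, 2, 3, 3, 4, 5, 5]
```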
And combine all parts together with an aggregation:
(df
    .withColumn("is_first", is_first)
    .withColumn("order", order)
    .groupBy("user_id", "actions", "order")
    .count())
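To check the whole pipeline without a Spark session, the sketch below chains the same steps in plain Python: partition by user, order by timestamp, flag changes, take a running sum, then count rows per (user_id, actions, order). It is a model for checking the logic, not a replacement for the Spark job; the rows are copied from the df definition in this answer, and consecutive_runs is a hypothetical helper name.

```python
from collections import Counter, defaultdict
from itertools import accumulate

rows = [
    (217498, 100000001, 'A'), (217498, 100000025, 'A'), (217498, 100000124, 'A'),
    (217498, 100000152, 'B'), (217498, 100000165, 'C'), (217498, 100000177, 'C'),
    (217498, 100000182, 'A'), (217498, 100000197, 'B'), (217498, 100000210, 'B'),
    (854123, 100000005, 'A'), (854123, 100000007, 'A'),
]


def consecutive_runs(rows):
    """Return (user_id, actions, order, count) tuples sorted by user_id and order."""
    partitions = defaultdict(list)
    for user_id, timestamp, actions in rows:
        partitions[user_id].append((timestamp, actions))

    counts = Counter()
    for user_id, items in partitions.items():
        actions = [a for _, a in sorted(items)]  # order by timestamp
        # 1 where there is no previous row or the action changed, else 0.
        flags = [1 if i == 0 or actions[i - 1] != a else 0
                 for i, a in enumerate(actions)]
        # Running sum gives the run number; count rows per run (groupBy + count).
        for action, order in zip(actions, accumulate(flags)):
            counts[(user_id, action, order)] += 1

    return sorted(((u, a, o, n) for (u, a, o), n in counts.items()),
                  key=lambda t: (t[0], t[2]))


for row in consecutive_runs(rows):
    print(row)
```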
If you define df as follows (sc is the SparkContext available in the PySpark shell):
df = sc.parallelize([
(217498, 100000001, 'A'), (217498, 100000025, 'A'), (217498, 100000124, 'A'),
(217498, 100000152, 'B'), (217498, 100000165, 'C'), (217498, 100000177, 'C'),
(217498, 100000182, 'A'), (217498, 100000197, 'B'), (217498, 100000210, 'B'),
(854123, 100000005, 'A'), (854123, 100000007, 'A')
]).toDF(["user_id", "timestamp", "actions"])
and order the result by user_id and order, you'll get:
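+-------+-------+-----+-----+
|user_id|actions|order|count|
+-------+-------+-----+-----+
| 217498|      A|    1|    3|
| 217498|      B|    2|    1|
| 217498|      C|    3|    2|
| 217498|      A|    4|    1|
| 217498|      B|    5|    2|
| 854123|      A|    1|    2|
+-------+-------+-----+-----+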