Suppose I have a DataFrame of events with the time difference between each row. The main rule is that an event counts toward the current visit only if it occurred within 5 minutes of the previous event; otherwise a new visit begins.
So if I understand this correctly, you essentially want to end each group whenever TimeDiff > 300 seconds? This is relatively straightforward with window functions:
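To make that concrete, here is a hypothetical one-user sequence (TimeDiff in seconds, made up purely for illustration); any gap larger than 300 starts a new visit:

eventtime:   10   130   520   580   2000
TimeDiff:     -   120   390    60   1420
visit:        1     1     2     2      3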
First, some imports:
from pyspark.sql.window import Window
import pyspark.sql.functions as func
Then set up the window; I assumed you would partition by userid and order by eventtime:
w = Window.partitionBy("userid").orderBy("eventtime")
Then figure out which subgroup each observation falls into: first mark every row that starts a new group, then take a running sum of that indicator column.
# Flag rows that start a new visit; the first event per user has no previous
# event, so a null TimeDiff is coalesced to 0 (it does not start a new group)
indicator = func.coalesce((func.col("TimeDiff") > 300).cast("integer"), func.lit(0))
# Running sum of the flags yields a per-user visit id
subgroup = func.sum(indicator).over(w).alias("subgroup")
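With the hypothetical numbers from the example above, the two columns evaluate like this (the missing TimeDiff on the first event is coalesced to 0, so the subgroup ids start at 0):

TimeDiff:     -   120   390    60   1420
indicator:    0     0     1     0      1
subgroup:     0     0     1     1      2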
Then some aggregation functions and you should be done. Note that the groupBy includes userid as well as subgroup, since the running sum restarts for each user:
DF = DF.select("*", subgroup)\
    .groupBy("userid", "subgroup")\
    .agg(
        func.min("eventtime").alias("start_time"),
        func.max("eventtime").alias("end_time"),
        func.count(func.lit(1)).alias("events")
    )
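If it helps, here is a minimal end-to-end sketch with made-up data. The column names userid, eventtime, and TimeDiff follow the snippets above; eventtime is treated as a plain unix timestamp in seconds for simplicity, and TimeDiff is derived with lag in case it is not already precomputed:

from pyspark.sql import SparkSession
from pyspark.sql.window import Window
import pyspark.sql.functions as func

spark = SparkSession.builder.getOrCreate()

# Hypothetical events: (userid, eventtime as unix seconds)
df = spark.createDataFrame(
    [("u1", 10), ("u1", 130), ("u1", 520), ("u1", 580), ("u1", 2000)],
    ["userid", "eventtime"],
)

w = Window.partitionBy("userid").orderBy("eventtime")

# Derive TimeDiff from the previous event in case it is not precomputed
df = df.withColumn("TimeDiff", func.col("eventtime") - func.lag("eventtime").over(w))

indicator = func.coalesce((func.col("TimeDiff") > 300).cast("integer"), func.lit(0))
subgroup = func.sum(indicator).over(w).alias("subgroup")

df.select("*", subgroup)\
    .groupBy("userid", "subgroup")\
    .agg(
        func.min("eventtime").alias("start_time"),
        func.max("eventtime").alias("end_time"),
        func.count(func.lit(1)).alias("events")
    )\
    .show()

# Expected result: three visits for u1, e.g.
# (u1, 0): start 10,   end 130,  2 events
# (u1, 1): start 520,  end 580,  2 events
# (u1, 2): start 2000, end 2000, 1 event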