Pyspark window function with condition

广开言路 2021-02-08 23:13

Suppose I have a DataFrame of events with the time difference between each row. The main rule is that a visit is counted only if the event occurred within 5 minutes of the previous one.

3 Answers
  • 2021-02-08 23:45

    One approach is to group the DataFrame based on your timeline criterion.

    You can create a DataFrame with the rows that break the 5-minute timeline. Those rows are the criteria for grouping the records, and they set the starttime and endtime of each group.

    Then find the count and the maximum timestamp (endtime) for each group.
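
    A minimal sketch of one way to realize this grouping idea, using a break flag plus a running sum rather than a separate boundary DataFrame, and assuming a TimeDiff column in seconds plus the userid/eventtime columns used in the other answers (df is the events DataFrame):

    from pyspark.sql import functions as F
    from pyspark.sql.window import Window

    w = Window.partitionBy("userid").orderBy("eventtime")

    # rows breaking the 5-minute rule become group boundaries;
    # a running sum of that flag assigns every row to a visit
    visits = (
        df.withColumn("break", (F.col("TimeDiff") > 300).cast("int"))
          .withColumn("visit", F.sum("break").over(w))
          .groupBy("userid", "visit")
          .agg(
              F.count("*").alias("events"),           # records per visit
              F.min("eventtime").alias("starttime"),  # start of the visit
              F.max("eventtime").alias("endtime"),    # end of the visit (max timestamp)
          )
    )

    Depending on whether TimeDiff measures the gap to the previous or to the next event, you may want the boundary row assigned to the old group instead; the last answer below discusses that shift.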

  • 2021-02-08 23:56

    So if I understand this correctly you essentially want to end each group when TimeDiff > 300? This seems relatively straightforward with rolling window functions:

    First some imports

    from pyspark.sql.window import Window
    import pyspark.sql.functions as func
    

    Then setting windows, I assumed you would partition by userid

    w = Window.partitionBy("userid").orderBy("eventtime")
    

    Then figuring out what subgroup each observation falls into, by first marking the first member of each group, then summing the column.

    # flag rows that break the 5-minute rule; a running sum of the flag numbers the groups
    indicator = (func.col("TimeDiff") > 300).cast("integer")
    subgroup = func.sum(indicator).over(w).alias("subgroup")
    

    Then some aggregation functions and you should be done

    DF = DF.select("*", subgroup) \
        .groupBy("subgroup") \
        .agg(
            func.min("eventtime").alias("start_time"),
            func.max("eventtime").alias("end_time"),
            func.count(func.lit(1)).alias("events")
        )
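
    Note that because the window is partitioned by userid, the subgroup numbers restart for every user; if the DataFrame contains several users, a hedged variation of the snippet above would group by both columns so that subgroups from different users are not merged:

    # group by userid as well, so identical subgroup numbers from
    # different users end up in separate visits
    DF = DF.select("*", subgroup) \
        .groupBy("userid", "subgroup") \
        .agg(
            func.min("eventtime").alias("start_time"),
            func.max("eventtime").alias("end_time"),
            func.count(func.lit(1)).alias("events")
        )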
    
  • 2021-02-09 00:07

    You'll need one extra window function and a groupBy to achieve this. What we want is for every line with timeDiff greater than 300 to be the end of a group and the start of a new one. Aku's solution should work, except that the indicators mark the start of a group instead of the end. To change this you'll have to do a cumulative sum up to n-1 instead of n (n being your current line); a small worked trace of that shift follows the output below:

    w = Window.partitionBy("userid").orderBy("eventtime")
    DF = DF.withColumn("indicator", (DF.timeDiff > 300).cast("int"))
    DF = DF.withColumn("subgroup", func.sum("indicator").over(w) - func.col("indicator"))
    DF = DF.groupBy("subgroup").agg(
        func.min("eventtime").alias("start_time"), 
        func.max("eventtime").alias("end_time"),
        func.count("*").alias("events")
    )
    
    +--------+-------------------+-------------------+------+
    |subgroup|         start_time|           end_time|events|
    +--------+-------------------+-------------------+------+
    |       0|2017-06-04 03:00:00|2017-06-04 03:07:00|     6|
    |       1|2017-06-04 03:14:00|2017-06-04 03:15:00|     2|
    |       2|2017-06-04 03:34:00|2017-06-04 03:34:00|     1|
    |       3|2017-06-04 03:53:00|2017-06-04 03:53:00|     1|
    +--------+-------------------+-------------------+------+
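
    To see why subtracting the indicator shifts the boundary, here is a tiny worked trace with hypothetical flag values (plain Python, just to illustrate the arithmetic):

    from itertools import accumulate

    indicator = [0, 0, 1, 0, 1]                    # 1 where timeDiff > 300
    running = list(accumulate(indicator))          # [0, 0, 1, 1, 2] -> break row starts a new group
    shifted = [s - i for s, i in zip(running, indicator)]
    print(shifted)                                 # [0, 0, 0, 1, 1] -> break row stays in (ends) the old group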
    

    It seems that you also want to filter out groups with only one event, hence:

    DF = DF.filter("events != 1")
    
    +--------+-------------------+-------------------+------+
    |subgroup|         start_time|           end_time|events|
    +--------+-------------------+-------------------+------+
    |       0|2017-06-04 03:00:00|2017-06-04 03:07:00|     6|
    |       1|2017-06-04 03:14:00|2017-06-04 03:15:00|     2|
    +--------+-------------------+-------------------+------+
    