Grouping consecutive rows in a PySpark DataFrame

2021-01-18 03:16

I have the following example Spark DataFrame:

rdd = sc.parallelize([(1, "19:00:00", "19:30:00", 30), (1, "19:30:00", "19:40:00", 10), (1, "19:40:00", "19:43:00", 3), (2, "20:00:00", "20:10:00", 10), (1, "20:05:00", "20:15:00", 10), (1, "20:15:00", "20:35:00", 20)])
df = spark.createDataFrame(rdd, ["user_id", "start_time", "end_time", "duration"])

I want to merge runs of consecutive rows (rows with the same user_id whose start_time equals the previous row's end_time) and aggregate their durations. How can I do this?

2 Answers
  • 2021-01-18 03:24

    Here is a working solution derived from Pault's answer:

    Create the DataFrame:

    rdd = sc.parallelize([(1,"19:00:00", "19:30:00", 30), (1,"19:30:00", "19:40:00", 10),(1,"19:40:00", "19:43:00", 3), (2,"20:00:00", "20:10:00", 10), (1,"20:05:00", "20:15:00", 10),(1,"20:15:00", "20:35:00", 20)])
    
    df = spark.createDataFrame(rdd, ["user_id", "start_time", "end_time", "duration"])
    
    df.show()
    
    +-------+----------+--------+--------+
    |user_id|start_time|end_time|duration|
    +-------+----------+--------+--------+
    |      1|  19:00:00|19:30:00|      30|
    |      1|  19:30:00|19:40:00|      10|
    |      1|  19:40:00|19:43:00|       3|
    |      2|  20:00:00|20:10:00|      10|
    |      1|  20:05:00|20:15:00|      10|
    |      1|  20:15:00|20:35:00|      20|
    +-------+----------+--------+--------+
    

    Create an indicator column that flags rows whose start_time does not match the previous row's end_time, then take a cumulative sum of the indicator so that each run of consecutive rows gets a unique group id:

    import pyspark.sql.functions as f
    from pyspark.sql import Window
    
    # one window per user, ordered by start_time
    w1 = Window.partitionBy('user_id').orderBy('start_time')
    df = df.withColumn(
            # 1 when this row does not start where the previous row ended
            "indicator",
            (f.col("start_time") != f.lag("end_time").over(w1)).cast("int")
        )\
        .fillna(
            0,  # the first row per user has no previous row, so lag() returns null
            subset=["indicator"]
        )\
        .withColumn(
            # running total of the indicator gives each consecutive run its own id
            "group",
            f.sum(f.col("indicator")).over(w1.rangeBetween(Window.unboundedPreceding, 0))
        )
    df.show()
    
    +-------+----------+--------+--------+---------+-----+
    |user_id|start_time|end_time|duration|indicator|group|
    +-------+----------+--------+--------+---------+-----+
    |      1|  19:00:00|19:30:00|      30|        0|    0|
    |      1|  19:30:00|19:40:00|      10|        0|    0|
    |      1|  19:40:00|19:43:00|       3|        0|    0|
    |      1|  20:05:00|20:15:00|      10|        1|    1|
    |      1|  20:15:00|20:35:00|      20|        0|    1|
    |      2|  20:00:00|20:10:00|      10|        0|    0|
    +-------+----------+--------+--------+---------+-----+
    

    Now group by user_id and the group column, taking the minimum start_time, the maximum end_time, and the summed duration.
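
    A minimal sketch of that aggregation (it mirrors the groupBy/agg call shown in the second answer below):

    df.groupBy("user_id", "group")\
        .agg(
            f.min("start_time").alias("start_time"),  # earliest start in the run
            f.max("end_time").alias("end_time"),      # latest end in the run
            f.sum("duration").alias("duration")       # total duration of the run
        )\
        .drop("group")\
        .show()

    which yields: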

    +-------+----------+--------+--------+
    |user_id|start_time|end_time|duration|
    +-------+----------+--------+--------+
    |      1|  19:00:00|19:43:00|      43|
    |      1|  20:05:00|20:35:00|      30|
    |      2|  20:00:00|20:10:00|      10|
    +-------+----------+--------+--------+
    
    
  • 2021-01-18 03:48

    Here's one approach:

    Gather rows into groups, where a group is a set of consecutive rows with the same user_id (each row's start_time matches the previous row's end_time). You can then use this group to do your aggregation.

    One way to get there is to create intermediate indicator columns that flag whether the user has changed or the times are not consecutive, then take a cumulative sum over the combined indicator to create the group id.

    For example:

    import pyspark.sql.functions as f
    from pyspark.sql import Window
    
    # global ordering with no partitionBy: all rows are pulled into a single partition
    w1 = Window.orderBy("start_time")
    df = df.withColumn(
            # 1 when the user_id differs from the previous row's user_id
            "userChange",
            (f.col("user_id") != f.lag("user_id").over(w1)).cast("int")
        )\
        .withColumn(
            # 1 when this row does not start where the previous row ended
            "timeChange",
            (f.col("start_time") != f.lag("end_time").over(w1)).cast("int")
        )\
        .fillna(
            0,  # lag() returns null for the very first row
            subset=["userChange", "timeChange"]
        )\
        .withColumn(
            # 1 when either the user or the time continuity changed
            "indicator",
            (~((f.col("userChange") == 0) & (f.col("timeChange") == 0))).cast("int")
        )\
        .withColumn(
            # running total of the indicator assigns a unique id to each consecutive run
            "group",
            f.sum(f.col("indicator")).over(w1.rangeBetween(Window.unboundedPreceding, 0))
        )
    df.show()
    #+-------+----------+--------+--------+----------+----------+---------+-----+
    #|user_id|start_time|end_time|duration|userChange|timeChange|indicator|group|
    #+-------+----------+--------+--------+----------+----------+---------+-----+
    #|      1|  19:00:00|19:30:00|      30|         0|         0|        0|    0|
    #|      1|  19:30:00|19:40:00|      10|         0|         0|        0|    0|
    #|      1|  19:40:00|19:43:00|       3|         0|         0|        0|    0|
    #|      2|  20:00:00|20:10:00|      10|         1|         1|        1|    1|
    #|      1|  20:05:00|20:15:00|      10|         1|         1|        1|    2|
    #|      1|  20:15:00|20:35:00|      20|         0|         0|        0|    2|
    #+-------+----------+--------+--------+----------+----------+---------+-----+
    

    Now that we have the group column, we can aggregate as follows to get the desired result:

    df.groupBy("user_id", "group")\
        .agg(
            f.min("start_time").alias("start_time"),
            f.max("end_time").alias("end_time"),
            f.sum("duration").alias("duration")
        )\
        .drop("group")\
        .show()
    #+-------+----------+--------+--------+
    #|user_id|start_time|end_time|duration|
    #+-------+----------+--------+--------+
    #|      1|  19:00:00|19:43:00|      43|
    #|      1|  20:05:00|20:35:00|      30|
    #|      2|  20:00:00|20:10:00|      10|
    #+-------+----------+--------+--------+
    