Spark Advanced Window with dynamic last

星月不相逢 2021-02-09 11:14

Problem: Given time-series data, a clickstream of user activity stored in Hive, the task is to enrich the data with a session id using Spark.

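Session Definition

1. A session expires after 1 hour of inactivity.
2. A session remains active for a total duration of at most 2 hours.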
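Both rules can be checked without Spark. The plain-Python sketch below assumes the two rules used by the answers (a gap of more than 1 hour between consecutive clicks, or more than 2 hours since the session's first click, starts a new session). The function name `session_numbers` and the strict `>` at both boundaries are assumptions, since the answers differ on whether exactly 1 or 2 hours already starts a new session.

```python
from datetime import datetime, timedelta

# Assumed thresholds (not from a library): rule 1 and rule 2 of the definition.
INACTIVITY = timedelta(hours=1)    # max gap between consecutive clicks
MAX_DURATION = timedelta(hours=2)  # max span from a session's first click


def session_numbers(click_times):
    """Return a 1-based session number for each click of ONE user.

    click_times must be sorted ascending. A click opens a new session when it
    is the first click, when the gap since the previous click exceeds
    INACTIVITY, or when it lies more than MAX_DURATION after the first click
    of the current session.
    """
    numbers = []
    session = 0
    session_start = prev = None
    for t in click_times:
        if prev is None or t - prev > INACTIVITY or t - session_start > MAX_DURATION:
            session += 1
            session_start = t
        numbers.append(session)
        prev = t
    return numbers


u1 = [datetime(2018, 1, 1, h, m) for h, m in
      [(11, 0), (12, 10), (13, 0), (13, 50), (14, 40), (15, 30), (16, 20), (16, 50)]]
print(session_numbers(u1))  # [1, 2, 2, 2, 3, 3, 3, 4]
```

Prefixing the user id to these numbers (for example `u1-1`) gives ids in the format the answers below produce.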

4 Answers
  • 2021-02-09 11:50

    Not a straightforward problem to solve, but here's one approach:

    1. Use a Window lag on the timestamp to compute the time difference to the previous click per user, marking session starts with 0 for rule #1
    2. Group the dataset to assemble the timestamp diff list per user
    3. Process the timestamp diff list in a UDF to identify sessions for rule #2 and create all session ids per user
    4. Expand the grouped dataset via Spark's explode
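    Step 1 only compares consecutive clicks of the same user. The plain-Python sketch below (no Spark; `ts_diffs` is a hypothetical name, and the clicks are assumed to be sorted already) computes the same per-user `ts_diff` values that step 1 produces with `lag` and `row_number`, using 0 to mark a session start.

```python
from datetime import datetime

TMO1 = 60 * 60  # rule-1 inactivity timeout in seconds (tmo1 in the answer's code)


def ts_diffs(click_times, tmo1=TMO1):
    """Step 1 for ONE user's clicks, assumed sorted ascending.

    Returns the seconds since the previous click, with 0 marking a session
    start: the first click, or any click at least tmo1 after the previous one.
    """
    diffs = []
    for prev, cur in zip([None] + click_times[:-1], click_times):
        if prev is None:
            diffs.append(0)  # first click (row_number === 1 in the Spark code)
        else:
            gap = int((cur - prev).total_seconds())
            diffs.append(0 if gap >= tmo1 else gap)
    return diffs


u1 = [datetime(2018, 1, 1, h, m) for h, m in
      [(11, 0), (12, 10), (13, 0), (13, 50), (14, 40), (15, 30), (16, 20), (16, 50)]]
print(ts_diffs(u1))  # [0, 0, 3000, 3000, 3000, 3000, 3000, 1800]
```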

    Sample code below:

    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.expressions.Window
    import spark.implicits._
    
    val userActivity = Seq(
      ("2018-01-01 11:00:00", "u1"),
      ("2018-01-01 12:10:00", "u1"),
      ("2018-01-01 13:00:00", "u1"),
      ("2018-01-01 13:50:00", "u1"),
      ("2018-01-01 14:40:00", "u1"),
      ("2018-01-01 15:30:00", "u1"),
      ("2018-01-01 16:20:00", "u1"),
      ("2018-01-01 16:50:00", "u1"),
      ("2018-01-01 11:00:00", "u2"),
      ("2018-01-02 11:00:00", "u2")
    ).toDF("click_time", "user_id")
    
    def clickSessList(tmo: Long) = udf{ (uid: String, clickList: Seq[String], tsList: Seq[Long]) =>
      def sid(n: Long) = s"$uid-$n"
    
      val sessList = tsList.foldLeft( (List[String](), 0L, 0L) ){ case ((ls, j, k), i) =>
        if (i == 0 || j + i >= tmo) (sid(k + 1) :: ls, 0L, k + 1) else
           (sid(k) :: ls, j + i, k)
      }._1.reverse
    
      clickList zip sessList
    }
    

    Note that the accumulator for foldLeft in the UDF is a Tuple of (ls, j, k), where:

    • ls is the list of formatted session ids to be returned
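    • j and k carry conditionally changing values over to the next iteration: j is the time accumulated in the current session (a running sum of the timestamp diffs, reset to 0 when a new session starts) and k is the current session number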
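    The same accumulator can be written with plain loop variables. This Python sketch is a hypothetical port, not part of the answer: it keeps `j` and `k` as locals and appends ids in click order, which replaces the Scala version's prepend-then-`reverse`. Like the UDF, it assumes `ts_list` is already in click-time order.

```python
def click_sess_list(uid, click_list, ts_list, tmo):
    """Loop version of the clickSessList fold for ONE user.

    ts_list holds the step-1 differences in seconds (0 = session start).
    Returns (click, session_id) pairs, like the UDF's `clickList zip sessList`.
    """
    ls = []  # session ids, built in click order (no reverse needed)
    j = 0    # seconds accumulated in the current session
    k = 0    # current session number
    for i in ts_list:
        if i == 0 or j + i >= tmo:
            k, j = k + 1, 0  # start a new session
        else:
            j += i           # extend the current session
        ls.append(f"{uid}-{k}")
    return list(zip(click_list, ls))


pairs = click_sess_list(
    "u1",
    ["11:00", "12:10", "13:00", "13:50", "14:40", "15:30", "16:20", "16:50"],
    [0, 0, 3000, 3000, 3000, 3000, 3000, 1800],
    2 * 60 * 60,
)
print([sid for _, sid in pairs])
# ['u1-1', 'u1-2', 'u1-2', 'u1-2', 'u1-3', 'u1-3', 'u1-3', 'u1-4']
```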

    Step 1:

    val tmo1: Long = 60 * 60
    val tmo2: Long = 2 * 60 * 60
    
    val win1 = Window.partitionBy("user_id").orderBy("click_time")
    
    val df1 = userActivity.
      withColumn("ts_diff", unix_timestamp($"click_time") - unix_timestamp(
        lag($"click_time", 1).over(win1))
      ).
      withColumn("ts_diff", when(row_number.over(win1) === 1 || $"ts_diff" >= tmo1, 0L).
        otherwise($"ts_diff")
      )
    
    df1.show
    // +-------------------+-------+-------+
    // |         click_time|user_id|ts_diff|
    // +-------------------+-------+-------+
    // |2018-01-01 11:00:00|     u1|      0|
    // |2018-01-01 12:10:00|     u1|      0|
    // |2018-01-01 13:00:00|     u1|   3000|
    // |2018-01-01 13:50:00|     u1|   3000|
    // |2018-01-01 14:40:00|     u1|   3000|
    // |2018-01-01 15:30:00|     u1|   3000|
    // |2018-01-01 16:20:00|     u1|   3000|
    // |2018-01-01 16:50:00|     u1|   1800|
    // |2018-01-01 11:00:00|     u2|      0|
    // |2018-01-02 11:00:00|     u2|      0|
    // +-------------------+-------+-------+
    

    Steps 2-4:

    val df2 = df1.
      groupBy("user_id").agg(
        collect_list($"click_time").as("click_list"), collect_list($"ts_diff").as("ts_list")
      ).
      withColumn("click_sess_id",
        explode(clickSessList(tmo2)($"user_id", $"click_list", $"ts_list"))
      ).
      select($"user_id", $"click_sess_id._1".as("click_time"), $"click_sess_id._2".as("sess_id"))
    
    df2.show
    // +-------+-------------------+-------+
    // |user_id|click_time         |sess_id|
    // +-------+-------------------+-------+
    // |u1     |2018-01-01 11:00:00|u1-1   |
    // |u1     |2018-01-01 12:10:00|u1-2   |
    // |u1     |2018-01-01 13:00:00|u1-2   |
    // |u1     |2018-01-01 13:50:00|u1-2   |
    // |u1     |2018-01-01 14:40:00|u1-3   |
    // |u1     |2018-01-01 15:30:00|u1-3   |
    // |u1     |2018-01-01 16:20:00|u1-3   |
    // |u1     |2018-01-01 16:50:00|u1-4   |
    // |u2     |2018-01-01 11:00:00|u2-1   |
    // |u2     |2018-01-02 11:00:00|u2-2   |
    // +-------+-------------------+-------+
    

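    Also note that click_time is "passed thru" in steps 2-4 so that it is included in the final dataset. Keep in mind that collect_list does not guarantee element order after a groupBy (the order depends on the row order, which can change after a shuffle), so click_list and ts_list are not guaranteed to be in click_time order. One way to make the order explicit is to collect structs with sort_array(collect_list(struct($"click_time", $"ts_diff"))) and adapt the UDF to read the fields from them.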

  • 2021-02-09 11:51

    Though the answer provided by Leo works perfectly, I feel it's a complicated approach to solve the problem with collect and explode. It can be solved the Spark way with a UDAF, which also keeps it easy to modify in the future. Please take a look at a solution along those lines below:

    scala> //Importing Packages
    
    scala> import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.expressions.Window
    
    scala> import org.apache.spark.sql.functions._
    import org.apache.spark.sql.functions._
    
    scala> import org.apache.spark.sql.types._
    import org.apache.spark.sql.types._
    
    scala> // Create a UDAF to calculate total session duration based on SessionInactiveFlag and current session duration
    
    scala> import org.apache.spark.sql.expressions.MutableAggregationBuffer
    import org.apache.spark.sql.expressions.MutableAggregationBuffer
    
    scala> import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
    import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
    
    scala> import org.apache.spark.sql.Row
    import org.apache.spark.sql.Row
    
    scala> import org.apache.spark.sql.types._
    import org.apache.spark.sql.types._
    
    scala>
    
    scala> class TotalSessionDuration extends UserDefinedAggregateFunction {
         |   // This is the input fields for your aggregate function.
         |   override def inputSchema: org.apache.spark.sql.types.StructType =
         |     StructType(
         |       StructField("sessiondur", LongType) :: StructField(
         |         "inactivityInd",
         |         IntegerType
         |       ) :: Nil
         |     )
         |
         |   // This is the internal fields you keep for computing your aggregate.
         |   override def bufferSchema: StructType = StructType(
         |     StructField("sessionSum", LongType) :: Nil
         |   )
         |
         |   // This is the output type of your aggregation function.
         |   override def dataType: DataType = LongType
         |
         |   override def deterministic: Boolean = true
         |
         |   // This is the initial value for your buffer schema.
         |   override def initialize(buffer: MutableAggregationBuffer): Unit = {
         |     buffer(0) = 0L
         |   }
         |
         |   // This is how to update your buffer schema given an input.
         |   override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
         |     if (input.getAs[Int](1) == 1)
         |       buffer(0) = 0L
         |     else if (buffer.getAs[Long](0) >= 7200L)
         |       buffer(0) = input.getAs[Long](0)
         |     else
         |       buffer(0) = buffer.getAs[Long](0) + input.getAs[Long](0)
         |   }
         |
         |   // This is how to merge two objects with the bufferSchema type.
         |   override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
         |     if (buffer2.getAs[Int](1) == 1)
         |       buffer1(0) = 0L
         |     else if (buffer2.getAs[Long](0) >= 7200)
         |       buffer1(0) = buffer2.getAs[Long](0)
         |     else
         |       buffer1(0) = buffer1.getAs[Long](0) + buffer2.getAs[Long](0)
         |   }
         |   // This is where you output the final value, given the final value of your bufferSchema.
         |   override def evaluate(buffer: Row): Any = {
         |     buffer.getLong(0)
         |   }
         | }
    defined class TotalSessionDuration
    
    scala> // Create a handle for using the UDAF defined above
    
    scala> val sessionSum=spark.udf.register("sessionSum", new TotalSessionDuration)
    sessionSum: org.apache.spark.sql.expressions.UserDefinedAggregateFunction = TotalSessionDuration@64a9719a
    
    scala> //Create Session Dataframe
    
    scala> val clickstream = Seq(
         |   ("2018-01-01T11:00:00Z", "u1"),
         |   ("2018-01-01T12:10:00Z", "u1"),
         |   ("2018-01-01T13:00:00Z", "u1"),
         |   ("2018-01-01T13:50:00Z", "u1"),
         |   ("2018-01-01T14:40:00Z", "u1"),
         |   ("2018-01-01T15:30:00Z", "u1"),
         |   ("2018-01-01T16:20:00Z", "u1"),
         |   ("2018-01-01T16:50:00Z", "u1"),
         |   ("2018-01-01T11:00:00Z", "u2"),
         |   ("2018-01-02T11:00:00Z", "u2")
         | ).toDF("timestamp", "userid").withColumn("curr_timestamp",unix_timestamp($"timestamp", "yyyy-MM-dd'T'HH:mm:ss'Z'").cast(TimestampType)).drop("timestamp")
    clickstream: org.apache.spark.sql.DataFrame = [userid: string, curr_timestamp: timestamp]
    
    scala>
    
    scala> clickstream.show(false)
    +------+-------------------+
    |userid|curr_timestamp     |
    +------+-------------------+
    |u1    |2018-01-01 11:00:00|
    |u1    |2018-01-01 12:10:00|
    |u1    |2018-01-01 13:00:00|
    |u1    |2018-01-01 13:50:00|
    |u1    |2018-01-01 14:40:00|
    |u1    |2018-01-01 15:30:00|
    |u1    |2018-01-01 16:20:00|
    |u1    |2018-01-01 16:50:00|
    |u2    |2018-01-01 11:00:00|
    |u2    |2018-01-02 11:00:00|
    +------+-------------------+
    
    
    scala> //Generate column SEF with values 0 or 1 depending on whether difference between current and previous activity time is greater than 1 hour=3600 sec
    
    scala>
    
    scala> //Window on Current Timestamp when last activity took place
    
    scala> val windowOnTs = Window.partitionBy("userid").orderBy("curr_timestamp")
    windowOnTs: org.apache.spark.sql.expressions.WindowSpec = org.apache.spark.sql.expressions.WindowSpec@41dabe47
    
    scala> //Create Lag Expression to find previous timestamp for the User
    
    scala> val lagOnTS = lag(col("curr_timestamp"), 1).over(windowOnTs)
    lagOnTS: org.apache.spark.sql.Column = lag(curr_timestamp, 1, NULL) OVER (PARTITION BY userid ORDER BY curr_timestamp ASC NULLS FIRST unspecifiedframe$())
    
    scala> //Compute Timestamp for previous activity and subtract the same from Timestamp for current activity to get difference between 2 activities
    
    scala> val diff_secs_col = col("curr_timestamp").cast("long") - col("prev_timestamp").cast("long")
    diff_secs_col: org.apache.spark.sql.Column = (CAST(curr_timestamp AS BIGINT) - CAST(prev_timestamp AS BIGINT))
    
    scala> val UserActWindowed=clickstream.withColumn("prev_timestamp", lagOnTS).withColumn("last_session_activity_after", diff_secs_col ).na.fill(0, Array("last_session_activity_after"))
    UserActWindowed: org.apache.spark.sql.DataFrame = [userid: string, curr_timestamp: timestamp ... 2 more fields]
    
    scala> //Generate Flag Column SEF (Session Expiry Flag) to indicate Session Has Expired due to inactivity for more than 1 hour
    
    scala> val UserSessionFlagWhenInactive=UserActWindowed.withColumn("SEF",when(col("last_session_activity_after")>3600, 1).otherwise(0)).withColumn("tempsessid",sum(col("SEF"))  over windowOnTs)
    UserSessionFlagWhenInactive: org.apache.spark.sql.DataFrame = [userid: string, curr_timestamp: timestamp ... 4 more fields]
    
    scala> UserSessionFlagWhenInactive.show(false)
    +------+-------------------+-------------------+---------------------------+---+----------+
    |userid|curr_timestamp     |prev_timestamp     |last_session_activity_after|SEF|tempsessid|
    +------+-------------------+-------------------+---------------------------+---+----------+
    |u1    |2018-01-01 11:00:00|null               |0                          |0  |0         |
    |u1    |2018-01-01 12:10:00|2018-01-01 11:00:00|4200                       |1  |1         |
    |u1    |2018-01-01 13:00:00|2018-01-01 12:10:00|3000                       |0  |1         |
    |u1    |2018-01-01 13:50:00|2018-01-01 13:00:00|3000                       |0  |1         |
    |u1    |2018-01-01 14:40:00|2018-01-01 13:50:00|3000                       |0  |1         |
    |u1    |2018-01-01 15:30:00|2018-01-01 14:40:00|3000                       |0  |1         |
    |u1    |2018-01-01 16:20:00|2018-01-01 15:30:00|3000                       |0  |1         |
    |u1    |2018-01-01 16:50:00|2018-01-01 16:20:00|1800                       |0  |1         |
    |u2    |2018-01-01 11:00:00|null               |0                          |0  |0         |
    |u2    |2018-01-02 11:00:00|2018-01-01 11:00:00|86400                      |1  |1         |
    +------+-------------------+-------------------+---------------------------+---+----------+
    
    
    scala> // Compute total session duration using the UDAF TotalSessionDuration such that:
    
    scala> // (i) the counter is reset to 0 if SEF is set to 1,
    
    scala> // (ii) or set to the current gap (the new session's duration so far) once the accumulated duration reaches 2 hours,
    
    scala> // (iii) and, if neither applies, the gaps are accumulated.
    
    scala> val UserSessionDur=UserSessionFlagWhenInactive.withColumn("sessionSum",sessionSum(col("last_session_activity_after"),col("SEF"))  over windowOnTs)
    UserSessionDur: org.apache.spark.sql.DataFrame = [userid: string, curr_timestamp: timestamp ... 5 more fields]
    
    scala> // Generate a session marker if SEF is 1 or sessionSum exceeds 2 hours (7200 seconds)
    
    scala> val UserNewSessionMarker=UserSessionDur.withColumn("SessionFlagChangeIndicator",when(col("SEF")===1 || col("sessionSum")>7200, 1).otherwise(0) )
    UserNewSessionMarker: org.apache.spark.sql.DataFrame = [userid: string, curr_timestamp: timestamp ... 6 more fields]
    
    scala> //Create New Session ID based on the marker
    
    scala> val computeSessionId=UserNewSessionMarker.drop("SEF","tempsessid","sessionSum").withColumn("sessid",concat(col("userid"),lit("-"),(sum(col("SessionFlagChangeIndicator"))  over windowOnTs)+1.toLong))
    computeSessionId: org.apache.spark.sql.DataFrame = [userid: string, curr_timestamp: timestamp ... 4 more fields]
    
    scala> computeSessionId.show(false)
    +------+-------------------+-------------------+---------------------------+--------------------------+------+
    |userid|curr_timestamp     |prev_timestamp     |last_session_activity_after|SessionFlagChangeIndicator|sessid|
    +------+-------------------+-------------------+---------------------------+--------------------------+------+
    |u1    |2018-01-01 11:00:00|null               |0                          |0                         |u1-1  |
    |u1    |2018-01-01 12:10:00|2018-01-01 11:00:00|4200                       |1                         |u1-2  |
    |u1    |2018-01-01 13:00:00|2018-01-01 12:10:00|3000                       |0                         |u1-2  |
    |u1    |2018-01-01 13:50:00|2018-01-01 13:00:00|3000                       |0                         |u1-2  |
    |u1    |2018-01-01 14:40:00|2018-01-01 13:50:00|3000                       |1                         |u1-3  |
    |u1    |2018-01-01 15:30:00|2018-01-01 14:40:00|3000                       |0                         |u1-3  |
    |u1    |2018-01-01 16:20:00|2018-01-01 15:30:00|3000                       |0                         |u1-3  |
    |u1    |2018-01-01 16:50:00|2018-01-01 16:20:00|1800                       |1                         |u1-4  |
    |u2    |2018-01-01 11:00:00|null               |0                          |0                         |u2-1  |
    |u2    |2018-01-02 11:00:00|2018-01-01 11:00:00|86400                      |1                         |u2-2  |
    +------+-------------------+-------------------+---------------------------+--------------------------+------+
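    Because the UDAF is evaluated over `windowOnTs` (ordered by `curr_timestamp`, with the default running frame), `update` sees each user's rows in time order. The plain-Python sketch below replays that running computation for one user so the marker and session columns can be checked by hand; `running_session_ids` is a hypothetical name, and the 3600/7200-second thresholds are copied from the transcript. (On Spark 3.x, `UserDefinedAggregateFunction` is deprecated in favor of an `Aggregator` registered through `functions.udaf`.)

```python
from itertools import accumulate


def running_session_ids(gaps, idle=3600, max_dur=7200):
    """Replay the UDAF-based pipeline for ONE user's rows in timestamp order.

    gaps: last_session_activity_after values in seconds (0 for the first row).
    Returns (SessionFlagChangeIndicator values, session numbers).
    """
    session_sum = 0  # the single-field UDAF buffer
    markers = []
    for gap in gaps:
        sef = 1 if gap > idle else 0  # SEF column
        if sef == 1:
            session_sum = 0           # update(): reset after inactivity
        elif session_sum >= max_dur:
            session_sum = gap         # update(): limit reached, restart from this gap
        else:
            session_sum += gap        # update(): keep accumulating
        markers.append(1 if sef == 1 or session_sum > max_dur else 0)
    # sum(SessionFlagChangeIndicator) over windowOnTs, plus 1
    ids = [total + 1 for total in accumulate(markers)]
    return markers, ids


markers, ids = running_session_ids([0, 4200, 3000, 3000, 3000, 3000, 3000, 1800])
print(markers, ids)  # [0, 1, 0, 0, 1, 0, 0, 1] [1, 2, 2, 2, 3, 3, 3, 4]
```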
    
  • 2021-02-09 11:53

    Complete solution

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.types._
    import scala.collection.mutable.ListBuffer
    import java.sql.Timestamp
    // Assumes spark-shell, where `spark` and `sc` are predefined
    import spark.sqlContext.implicits._
    
    
    val interimSessionThreshold = 60    // rule 1: max minutes between clicks
    val totalSessionTimeThreshold = 120 // rule 2: max minutes per session
    
    val sparkSession = SparkSession.builder.master("local").appName("Window Function").getOrCreate()
    
    val clickDF = sparkSession.createDataFrame(Seq(
          ("2018-01-01T11:00:00Z","u1"),
            ("2018-01-01T12:10:00Z","u1"),
            ("2018-01-01T13:00:00Z","u1"),
            ("2018-01-01T13:50:00Z","u1"),
            ("2018-01-01T14:40:00Z","u1"),
            ("2018-01-01T15:30:00Z","u1"),
            ("2018-01-01T16:20:00Z","u1"),
            ("2018-01-01T16:50:00Z","u1"),
            ("2018-01-01T11:00:00Z","u2"),
            ("2018-01-02T11:00:00Z","u2")
        )).toDF("clickTime","user")
    
    val newDF=clickDF.withColumn("clickTimestamp",unix_timestamp($"clickTime", "yyyy-MM-dd'T'HH:mm:ss'Z'").cast(TimestampType).as("timestamp")).drop($"clickTime")  
    
    val partitionWindow = Window.partitionBy($"user").orderBy($"clickTimestamp".asc)
    
    val lagTest = lag($"clickTimestamp", 1, "0000-00-00 00:00:00").over(partitionWindow)
    val df_test=newDF.select($"*", ((unix_timestamp($"clickTimestamp")-unix_timestamp(lagTest))/60D cast "int") as "diff_val_with_previous")
    
    
    val distinctUser=df_test.select($"user").distinct.as[String].collect.toList
    
    val rankTest = rank().over(partitionWindow)
    val ddf = df_test.select($"*", rankTest as "rank")
    
    case class finalClick(User:String,clickTime:Timestamp,session:String)
    
    val rowList: ListBuffer[finalClick] = new ListBuffer()
    
    
    // Driver-side loop over users; each filter/count/head below runs a Spark job,
    // so this approach only suits small datasets.
    distinctUser.foreach { x =>
      val tempDf = ddf.filter($"user" === x)
      var cumulDiff: Int = 0  // minutes since the current session's first click
      var sessionIndex = 1
      val len = tempDf.count.toInt
      for (i <- 1 to len) {
        val r = tempDf.filter($"rank" === i).head()
        // For the first click the lag default does not parse as a timestamp,
        // so the diff is null; getAs[Int] then yields 0
        val dp = r.getAs[Int]("diff_val_with_previous")
        cumulDiff += dp
        if (dp > interimSessionThreshold || cumulDiff > totalSessionTimeThreshold) {
          // Rule 1 or rule 2 violated: this click opens a new session
          sessionIndex += 1
          cumulDiff = 0
        }
        rowList += finalClick(r.getAs[String]("user"), r.getAs[Timestamp]("clickTimestamp"),
          r.getAs[String]("user") + sessionIndex)
      }
    }
    
    
    val dataFrame = sc.parallelize(rowList.toList).toDF("user","clickTimestamp","session")
    
    dataFrame.show
    
    +----+-------------------+-------+
    |user|     clickTimestamp|session|
    +----+-------------------+-------+
    |  u1|2018-01-01 11:00:00|    u11|
    |  u1|2018-01-01 12:10:00|    u12|
    |  u1|2018-01-01 13:00:00|    u12|
    |  u1|2018-01-01 13:50:00|    u12|
    |  u1|2018-01-01 14:40:00|    u13|
    |  u1|2018-01-01 15:30:00|    u13|
    |  u1|2018-01-01 16:20:00|    u13|
    |  u1|2018-01-01 16:50:00|    u14|
    |  u2|2018-01-01 11:00:00|    u21|
    |  u2|2018-01-02 11:00:00|    u22|
    +----+-------------------+-------+
    
    
    
    
  • 2021-02-09 11:57

    ----- Solution without using explode -----

    In my view, explode is a heavy operation, and to apply it you first have to do a groupBy and collect_list.

    import pyspark.sql.functions as f
    from pyspark.sql.window import Window

    streaming_data = [("U1", "2019-01-01T11:00:00Z"),
                      ("U1", "2019-01-01T11:15:00Z"),
                      ("U1", "2019-01-01T12:00:00Z"),
                      ("U1", "2019-01-01T12:20:00Z"),
                      ("U1", "2019-01-01T15:00:00Z"),
                      ("U2", "2019-01-01T11:00:00Z"),
                      ("U2", "2019-01-02T11:00:00Z"),
                      ("U2", "2019-01-02T11:25:00Z"),
                      ("U2", "2019-01-02T11:50:00Z"),
                      ("U2", "2019-01-02T12:15:00Z"),
                      ("U2", "2019-01-02T12:40:00Z"),
                      ("U2", "2019-01-02T13:05:00Z"),
                      ("U2", "2019-01-02T13:20:00Z")]
    schema = ("UserId", "Click_Time")
    window_spec = Window.partitionBy("UserId").orderBy("Click_Time")
    df_stream = spark.createDataFrame(streaming_data, schema)
    df_stream = df_stream.withColumn("Click_Time", df_stream["Click_Time"].cast("timestamp"))

    # Hours since the previous click (null for a user's first click, filled with 0)
    df_stream = df_stream \
        .withColumn("time_diff",
                    (f.unix_timestamp("Click_Time") - f.unix_timestamp(f.lag(f.col("Click_Time"), 1).over(window_spec))) / (60 * 60)) \
        .na.fill(0)

    # Rule 1: a gap of more than 1 hour starts a new group; a running sum numbers the groups
    df_stream = df_stream.withColumn("cond_", f.when(f.col("time_diff") > 1, 1).otherwise(0))
    df_stream = df_stream.withColumn("temp_session", f.sum(f.col("cond_")).over(window_spec))

    # Rule 2: within each group, flag clicks more than 2 hours after the group's first click
    new_spec = Window.partitionBy("UserId", "temp_session").orderBy("Click_Time")
    df_stream = df_stream.withColumn("first_time_click", f.first(f.col("Click_Time")).over(new_spec)) \
        .withColumn("final_session_groups",
                    f.when((f.unix_timestamp(f.col("Click_Time")) - f.unix_timestamp(f.col("first_time_click"))) / (2 * 60 * 60) > 1, 1)
                    .otherwise(0)) \
        .drop("first_time_click", "cond_")

    # Combine both groupings into a 1-based session number and build the session id
    df_stream = df_stream.withColumn("final_session", df_stream["temp_session"] + df_stream["final_session_groups"] + 1) \
        .drop("temp_session", "final_session_groups", "time_diff")
    df_stream = df_stream.withColumn("session_id", f.concat(f.col("UserId"), f.lit(" session_val----->"), f.col("final_session")))

    df_stream.show(20, 0)
    

    ---Steps taken to solve ---

    1. First, find the clicks that are less than one hour apart and group them into continuous groups.

    2. Then, within each group, apply the 2-hour condition to split it into further continuous groups.

    3. Add the two group numbers above plus 1 to populate the final_session column at the end of the algorithm, then concat as required to build the session_id.

    The result looks like this:

    +------+---------------------+-------------+---------------------+
    |UserId|Click_Time           |final_session|session_id           |
    +------+---------------------+-------------+---------------------+
    |U2    |2019-01-01 11:00:00.0|1            |U2 session_val----->1|
    |U2    |2019-01-02 11:00:00.0|2            |U2 session_val----->2|
    |U2    |2019-01-02 11:25:00.0|2            |U2 session_val----->2|
    |U2    |2019-01-02 11:50:00.0|2            |U2 session_val----->2|
    |U2    |2019-01-02 12:15:00.0|2            |U2 session_val----->2|
    |U2    |2019-01-02 12:40:00.0|2            |U2 session_val----->2|
    |U2    |2019-01-02 13:05:00.0|3            |U2 session_val----->3|
    |U2    |2019-01-02 13:20:00.0|3            |U2 session_val----->3|
    |U1    |2019-01-01 11:00:00.0|1            |U1 session_val----->1|
    |U1    |2019-01-01 11:15:00.0|1            |U1 session_val----->1|
    |U1    |2019-01-01 12:00:00.0|1            |U1 session_val----->1|
    |U1    |2019-01-01 12:20:00.0|1            |U1 session_val----->1|
    |U1    |2019-01-01 15:00:00.0|2            |U1 session_val----->2|
    +------+---------------------+-------------+---------------------+
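    The three steps can be traced in plain Python for one user's sorted clicks. This sketch (hypothetical name `two_level_sessions`, no Spark) returns two numberings: `summed` follows step 3 as written (`temp_session + final_session_groups + 1`), and `ranked` numbers the distinct `(temp_session, final_session_groups)` pairs in order, which is what `f.dense_rank().over(Window.partitionBy("UserId").orderBy("temp_session", "final_session_groups"))` would compute in PySpark. The example input suggests why that alternative may be worth considering: a group that overflows the 2-hour limit and the next inactivity group can end up with the same summed number. Step 2 also splits each group at most once, so all clicks more than 2 hours after a group's first click share one session even if they themselves span more than 2 hours.

```python
from datetime import datetime, timedelta


def two_level_sessions(click_times):
    """Trace the three steps for ONE user's clicks, assumed sorted ascending.

    Returns (summed, ranked): step 3 as written, and a dense-rank numbering
    of the (temp_session, final_session_groups) pairs.
    """
    temp, groups, firsts = [], [], {}
    t_sess = 0
    for i, t in enumerate(click_times):
        # Step 1: a gap of more than 1 hour starts the next continuous group
        if i > 0 and t - click_times[i - 1] > timedelta(hours=1):
            t_sess += 1
        temp.append(t_sess)
        firsts.setdefault(t_sess, t)  # first click of each group
        # Step 2: flag clicks more than 2 hours after the group's first click
        groups.append(1 if t - firsts[t_sess] > timedelta(hours=2) else 0)
    # Step 3 as written: temp_session + final_session_groups + 1
    summed = [a + b + 1 for a, b in zip(temp, groups)]
    # Alternative: number the distinct (group, flag) pairs in order, like dense_rank
    rank = {p: r for r, p in enumerate(sorted(set(zip(temp, groups))), start=1)}
    ranked = [rank[p] for p in zip(temp, groups)]
    return summed, ranked


# 50-minute gaps for 2.5 hours, then a 90-minute pause
clicks = [datetime(2019, 1, 3, h, m) for h, m in
          [(10, 0), (10, 50), (11, 40), (12, 30), (14, 0)]]
print(two_level_sessions(clicks))  # ([1, 1, 1, 2, 2], [1, 1, 1, 2, 3])
```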
    
