How to select the first row of each group?

Asked by 心在旅途 on 2020-11-21 05:49

I have a DataFrame generated as follows:

df.groupBy($"Hour", $"Category")
  .agg(sum($"value") as "TotalValue")
  .sort($"Hour".asc, $"TotalValue".desc)

How can I select the top row of each group, i.e. the row with the largest TotalValue for each Hour?

8 Answers
  • 2020-11-21 06:29

    A nice way of doing this with the DataFrame API is to use argmax-style logic, like so:

      import org.apache.spark.sql.functions.{max, struct}
      // spark.implicits._ (for $ and toDF) is assumed to be in scope, as in spark-shell

      val df = Seq(
        (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
        (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
        (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
        (3,"cat8",35.6)).toDF("Hour", "Category", "TotalValue")
    
      df.groupBy($"Hour")
        .agg(max(struct($"TotalValue", $"Category")).as("argmax"))
        .select($"Hour", $"argmax.*").show
    
     +----+----------+--------+
     |Hour|TotalValue|Category|
     +----+----------+--------+
     |   1|      28.5|   cat67|
     |   3|      35.6|    cat8|
     |   2|      39.6|   cat56|
     |   0|      30.9|   cat26|
     +----+----------+--------+
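
    This works because max on a struct compares fields left to right, so the ordering column (TotalValue) has to come first in the struct. If you want the output back in the original Hour, Category, TotalValue order, a small sketch:

      df.groupBy($"Hour")
        .agg(max(struct($"TotalValue", $"Category")).as("argmax"))
        .select($"Hour", $"argmax.Category", $"argmax.TotalValue")  // pull struct fields out by name
        .show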
    
  • 2020-11-21 06:30

    For Spark 2.0.2 with grouping by multiple columns:

    import org.apache.spark.sql.functions.row_number
    import org.apache.spark.sql.expressions.Window
    
    val w = Window.partitionBy($"col1", $"col2", $"col3").orderBy($"timestamp".desc)
    
    val refined_df = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")
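
    row_number breaks ties arbitrarily; if several rows share the same latest timestamp and you want to keep all of them, a hedged variant using rank instead (refined_with_ties is just an illustrative name):

    import org.apache.spark.sql.functions.rank

    // rank assigns 1 to every row tied for the latest timestamp within a partition
    val refined_with_ties = df.withColumn("rnk", rank().over(w)).where($"rnk" === 1).drop("rnk")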
    
  • 2020-11-21 06:34

    Here you can do it like this:

    // first() takes the first value encountered in each Hour group for every column,
    // then the Hour column name is restored afterwards.
    val data = df.groupBy("Hour")
      .agg(first("Hour").as("_1"), first("Category").as("Category"), first("TotalValue").as("TotalValue"))
      .drop("Hour")

    data.withColumnRenamed("_1", "Hour").show
    
  • 2020-11-21 06:39

    The solution below does only one groupBy and extracts the rows of your dataframe that contain the max value in one shot. No further joins or windows are needed.

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.encoders.RowEncoder
    import org.apache.spark.sql.DataFrame
    
    //df is the dataframe with Day, Category, TotalValue
    
    implicit val dfEnc = RowEncoder(df.schema)
    
    val res: DataFrame = df.groupByKey(r => r.getInt(0))
      .mapGroups[Row]{ (day: Int, rows: Iterator[Row]) => rows.maxBy(r => r.getDouble(2)) }
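
    A typed sketch of the same idea, assuming the schema from the question (Hour: Int, Category: String, TotalValue: Double); it avoids Row and RowEncoder entirely (Rec is a hypothetical helper class):

    case class Rec(Hour: Int, Category: String, TotalValue: Double)

    val resTyped = df.as[Rec]                       // needs spark.implicits._ in scope, as in spark-shell
      .groupByKey(_.Hour)
      .mapGroups((hour, rows) => rows.maxBy(_.TotalValue))  // keep the row with the largest TotalValue per Hour
      .toDF()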
    
  • 2020-11-21 06:41

    We can use the rank() window function (where you would keep rank = 1). rank just assigns a number to every row within a group (in this case the group would be the hour).

    Here's an example (from https://github.com/jaceklaskowski/mastering-apache-spark-book/blob/master/spark-sql-functions.adoc#rank):

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.rank

    val dataset = spark.range(9).withColumn("bucket", 'id % 3)

    val byBucket = Window.partitionBy('bucket).orderBy('id)

    dataset.withColumn("rank", rank over byBucket).show
    +---+------+----+
    | id|bucket|rank|
    +---+------+----+
    |  0|     0|   1|
    |  3|     0|   2|
    |  6|     0|   3|
    |  1|     1|   1|
    |  4|     1|   2|
    |  7|     1|   3|
    |  2|     2|   1|
    |  5|     2|   2|
    |  8|     2|   3|
    +---+------+----+
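
    Applied to the question's DataFrame this would look roughly as follows (a sketch; byHour is just an illustrative name): rank the rows within each Hour by TotalValue descending, then keep rank 1.

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.rank

    val byHour = Window.partitionBy($"Hour").orderBy($"TotalValue".desc)

    df.withColumn("rank", rank().over(byHour))
      .where($"rank" === 1)
      .drop("rank")
      .show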
    
  • 2020-11-21 06:46

    Window functions:

    Something like this should do the trick:

    import org.apache.spark.sql.functions.{row_number, max, broadcast}
    import org.apache.spark.sql.expressions.Window
    
    val df = sc.parallelize(Seq(
      (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
      (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
      (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
      (3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")
    
    val w = Window.partitionBy($"hour").orderBy($"TotalValue".desc)
    
    val dfTop = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")
    
    dfTop.show
    // +----+--------+----------+
    // |Hour|Category|TotalValue|
    // +----+--------+----------+
    // |   0|   cat26|      30.9|
    // |   1|   cat67|      28.5|
    // |   2|   cat56|      39.6|
    // |   3|    cat8|      35.6|
    // +----+--------+----------+
    

    This method will be inefficient in case of significant data skew.
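
    For reference, the same window logic can also be written in plain SQL (a sketch, assuming df has been registered as a temporary view; the view name grouped is arbitrary):

    df.createOrReplaceTempView("grouped")

    spark.sql("""
      SELECT Hour, Category, TotalValue
      FROM (
        SELECT *,
               ROW_NUMBER() OVER (PARTITION BY Hour ORDER BY TotalValue DESC) AS rn
        FROM grouped
      ) ranked
      WHERE rn = 1
    """).show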

    Plain SQL aggregation followed by join:

    Alternatively you can join with aggregated data frame:

    val dfMax = df.groupBy($"hour".as("max_hour")).agg(max($"TotalValue").as("max_value"))
    
    val dfTopByJoin = df.join(broadcast(dfMax),
        ($"hour" === $"max_hour") && ($"TotalValue" === $"max_value"))
      .drop("max_hour")
      .drop("max_value")
    
    dfTopByJoin.show
    
    // +----+--------+----------+
    // |Hour|Category|TotalValue|
    // +----+--------+----------+
    // |   0|   cat26|      30.9|
    // |   1|   cat67|      28.5|
    // |   2|   cat56|      39.6|
    // |   3|    cat8|      35.6|
    // +----+--------+----------+
    

    It will keep duplicate values (if there is more than one category per hour with the same total value). You can remove these as follows:

    dfTopByJoin
      .groupBy($"hour")
      .agg(
        first("category").alias("category"),
        first("TotalValue").alias("TotalValue"))
    

    Using ordering over structs:

    A neat, although not very well tested, trick which doesn't require joins or window functions:

    val dfTop = df.select($"Hour", struct($"TotalValue", $"Category").alias("vs"))
      .groupBy($"hour")
      .agg(max("vs").alias("vs"))
      .select($"Hour", $"vs.Category", $"vs.TotalValue")
    
    dfTop.show
    // +----+--------+----------+
    // |Hour|Category|TotalValue|
    // +----+--------+----------+
    // |   0|   cat26|      30.9|
    // |   1|   cat67|      28.5|
    // |   2|   cat56|      39.6|
    // |   3|    cat8|      35.6|
    // +----+--------+----------+
    

    With the Dataset API (Spark 1.6+, 2.0+):

    Spark 1.6:

    case class Record(Hour: Integer, Category: String, TotalValue: Double)
    
    df.as[Record]
      .groupBy($"hour")
      .reduce((x, y) => if (x.TotalValue > y.TotalValue) x else y)
      .show
    
    // +---+--------------+
    // | _1|            _2|
    // +---+--------------+
    // |[0]|[0,cat26,30.9]|
    // |[1]|[1,cat67,28.5]|
    // |[2]|[2,cat56,39.6]|
    // |[3]| [3,cat8,35.6]|
    // +---+--------------+
    

    Spark 2.0 or later:

    df.as[Record]
      .groupByKey(_.Hour)
      .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)
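
    reduceGroups returns a Dataset of (key, record) tuples; a small sketch to keep just the reduced records and go back to a DataFrame:

    df.as[Record]
      .groupByKey(_.Hour)
      .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)
      .map(_._2)       // drop the grouping key, keep the winning Record
      .toDF()
      .show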
    

    The last two methods can leverage map-side combine and don't require a full shuffle, so most of the time they should exhibit better performance compared to window functions and joins. They can also be used with Structured Streaming in complete output mode.

    Don't use:

    df.orderBy(...).groupBy(...).agg(first(...), ...)
    

    It may seem to work (especially in local mode) but it is unreliable (see SPARK-16207, credits to Tzach Zohar for linking the relevant JIRA issue, and SPARK-30335).

    The same note applies to

    df.orderBy(...).dropDuplicates(...)
    

    which internally uses an equivalent execution plan.
