I have a DataFrame generated as follow:
df.groupBy($\"Hour\", $\"Category\")
.agg(sum($\"value\") as \"TotalValue\")
.sort($\"Hour\".asc, $\"TotalValue\"
A nice way of doing this with the dataframe api is using the argmax logic like so
val df = Seq(
(0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
(1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
(2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
(3,"cat8",35.6)).toDF("Hour", "Category", "TotalValue")
df.groupBy($"Hour")
.agg(max(struct($"TotalValue", $"Category")).as("argmax"))
.select($"Hour", $"argmax.*").show
+----+----------+--------+
|Hour|TotalValue|Category|
+----+----------+--------+
| 1| 28.5| cat67|
| 3| 35.6| cat8|
| 2| 39.6| cat56|
| 0| 30.9| cat26|
+----+----------+--------+
For Spark 2.0.2 with grouping by multiple columns:
import org.apache.spark.sql.functions.row_number
import org.apache.spark.sql.expressions.Window
val w = Window.partitionBy($"col1", $"col2", $"col3").orderBy($"timestamp".desc)
val refined_df = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")
Here you can do like this -
val data = df.groupBy("Hour").agg(first("Hour").as("_1"),first("Category").as("Category"),first("TotalValue").as("TotalValue")).drop("Hour")
data.withColumnRenamed("_1","Hour").show
The solution below does only one groupBy and extract the rows of your dataframe that contain the maxValue in one shot. No need for further Joins, or Windows.
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.DataFrame
//df is the dataframe with Day, Category, TotalValue
implicit val dfEnc = RowEncoder(df.schema)
val res: DataFrame = df.groupByKey{(r) => r.getInt(0)}.mapGroups[Row]{(day: Int, rows: Iterator[Row]) => i.maxBy{(r) => r.getDouble(2)}}
We can use the rank() window function (where you would choose the rank = 1) rank just adds a number for every row of a group (in this case it would be the hour)
here's an example. ( from https://github.com/jaceklaskowski/mastering-apache-spark-book/blob/master/spark-sql-functions.adoc#rank )
val dataset = spark.range(9).withColumn("bucket", 'id % 3)
import org.apache.spark.sql.expressions.Window
val byBucket = Window.partitionBy('bucket).orderBy('id)
scala> dataset.withColumn("rank", rank over byBucket).show
+---+------+----+
| id|bucket|rank|
+---+------+----+
| 0| 0| 1|
| 3| 0| 2|
| 6| 0| 3|
| 1| 1| 1|
| 4| 1| 2|
| 7| 1| 3|
| 2| 2| 1|
| 5| 2| 2|
| 8| 2| 3|
+---+------+----+
Window functions:
Something like this should do the trick:
import org.apache.spark.sql.functions.{row_number, max, broadcast}
import org.apache.spark.sql.expressions.Window
val df = sc.parallelize(Seq(
(0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
(1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
(2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
(3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")
val w = Window.partitionBy($"hour").orderBy($"TotalValue".desc)
val dfTop = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")
dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// | 0| cat26| 30.9|
// | 1| cat67| 28.5|
// | 2| cat56| 39.6|
// | 3| cat8| 35.6|
// +----+--------+----------+
This method will be inefficient in case of significant data skew.
Plain SQL aggregation followed by join
:
Alternatively you can join with aggregated data frame:
val dfMax = df.groupBy($"hour".as("max_hour")).agg(max($"TotalValue").as("max_value"))
val dfTopByJoin = df.join(broadcast(dfMax),
($"hour" === $"max_hour") && ($"TotalValue" === $"max_value"))
.drop("max_hour")
.drop("max_value")
dfTopByJoin.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// | 0| cat26| 30.9|
// | 1| cat67| 28.5|
// | 2| cat56| 39.6|
// | 3| cat8| 35.6|
// +----+--------+----------+
It will keep duplicate values (if there is more than one category per hour with the same total value). You can remove these as follows:
dfTopByJoin
.groupBy($"hour")
.agg(
first("category").alias("category"),
first("TotalValue").alias("TotalValue"))
Using ordering over structs
:
Neat, although not very well tested, trick which doesn't require joins or window functions:
val dfTop = df.select($"Hour", struct($"TotalValue", $"Category").alias("vs"))
.groupBy($"hour")
.agg(max("vs").alias("vs"))
.select($"Hour", $"vs.Category", $"vs.TotalValue")
dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// | 0| cat26| 30.9|
// | 1| cat67| 28.5|
// | 2| cat56| 39.6|
// | 3| cat8| 35.6|
// +----+--------+----------+
With DataSet API (Spark 1.6+, 2.0+):
Spark 1.6:
case class Record(Hour: Integer, Category: String, TotalValue: Double)
df.as[Record]
.groupBy($"hour")
.reduce((x, y) => if (x.TotalValue > y.TotalValue) x else y)
.show
// +---+--------------+
// | _1| _2|
// +---+--------------+
// |[0]|[0,cat26,30.9]|
// |[1]|[1,cat67,28.5]|
// |[2]|[2,cat56,39.6]|
// |[3]| [3,cat8,35.6]|
// +---+--------------+
Spark 2.0 or later:
df.as[Record]
.groupByKey(_.Hour)
.reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)
The last two methods can leverage map side combine and don't require full shuffle so most of the time should exhibit a better performance compared to window functions and joins. These cane be also used with Structured Streaming in completed
output mode.
Don't use:
df.orderBy(...).groupBy(...).agg(first(...), ...)
It may seem to work (especially in the local
mode) but it is unreliable (see SPARK-16207, credits to Tzach Zohar for linking relevant JIRA issue, and SPARK-30335).
The same note applies to
df.orderBy(...).dropDuplicates(...)
which internally uses equivalent execution plan.