I have the following df:
+---+----+-----+
|sno|dept|color|
+---+----+-----+
| 1| fn| red|
| 2| fn| blue|
| 3| fn|green|
+---+----+-----+
Given:
val df = Seq(
(1, "fn", "red"),
(2, "fn", "blue"),
(3, "fn", "green"),
(4, "aa", "blue"),
(5, "aa", "green"),
(6, "bb", "red"),
(7, "bb", "red"),
(8, "aa", "blue")
).toDF("id", "fn", "color")
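Do the calculation:
import org.apache.spark.sql.functions._
import spark.implicits._ // assumes a SparkSession named `spark`, as in spark-shell

// For each group, collect the distinct colors and flag whether "red" is among them
val redOrNot = df.groupBy("fn")
  .agg(collect_set('color) as "values")
  .withColumn("hasRed", array_contains('values, "red"))
// `when` without `otherwise` gives null when the condition is false
val colorPicker = when('hasRed, "red")
val result = df.join(redOrNot, "fn")
  .withColumn("resultColor", colorPicker)
  .withColumn("color", coalesce('resultColor, 'color)) // coalesce skips the null and falls back to the original color
  .select('id, 'fn, 'color)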
The result looks as follows (which seems to be the answer):
scala> result.show
+---+---+-----+
| id| fn|color|
+---+---+-----+
| 1| fn| red|
| 2| fn| red|
| 3| fn| red|
| 4| aa| blue|
| 5| aa|green|
| 6| bb| red|
| 7| bb| red|
| 8| aa| blue|
+---+---+-----+
You can chain when operators and supply a default value with otherwise. Consult the scaladoc of the when operator.
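To illustrate the chaining, here is a minimal sketch, not tied to the question's data. The labels "warm", "cool" and "other" and the names colors and labelled are made up for the example, and it assumes it is pasted into spark-shell or run as a script with a local SparkSession.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.when

// In spark-shell, getOrCreate() simply returns the existing session
val spark = SparkSession.builder().master("local[*]").appName("when-otherwise-sketch").getOrCreate()
import spark.implicits._

val colors = Seq("red", "blue", "green").toDF("color")

// Conditions are checked in order and the first match wins;
// otherwise supplies the default (without it, unmatched rows get null)
val labelled = colors.withColumn("label",
  when($"color" === "red", "warm")
    .when($"color" === "blue", "cool")
    .otherwise("other"))

labelled.show()
```

With otherwise in place, the separate coalesce step from the answer above should no longer be needed, since the default can be the original column.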
I think you could do something very similar (and perhaps more efficient) using windowed operators or user-defined aggregate functions (UDAF), but...well...don't currently know how to do it. Leaving the comment here to inspire others ;-)
p.s. Learnt a lot! Thanks for the idea!
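For what it's worth, the windowed variant mentioned above could look roughly like the sketch below. It is an untuned sketch rather than a benchmarked alternative: it assumes the same sample data with the group column named fn, and the names byGroup, groupHasRed and windowed are made up here.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{max, when}

// In spark-shell, getOrCreate() simply returns the existing session
val spark = SparkSession.builder().master("local[*]").appName("window-sketch").getOrCreate()
import spark.implicits._

val df = Seq(
  (1, "fn", "red"), (2, "fn", "blue"), (3, "fn", "green"),
  (4, "aa", "blue"), (5, "aa", "green"),
  (6, "bb", "red"), (7, "bb", "red"), (8, "aa", "blue")
).toDF("id", "fn", "color")

// One window per group, so no join back to the input is needed
val byGroup = Window.partitionBy("fn")

// 1 for red rows, 0 otherwise; the max over the group is 1 if any row is red
val groupHasRed = max(when($"color" === "red", 1).otherwise(0)).over(byGroup) === 1

val windowed = df.withColumn("color", when(groupHasRed, "red").otherwise($"color"))
windowed.orderBy("id").show()
```

The window still shuffles the data by fn, so whether this beats the groupBy plus join is something to check against your own data.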
Since the filtered DataFrame may contain only a few rows, I'm adding a solution that combines isin() and .withColumn().
Sample DataFrame
val df = Seq(
(1, "fn", "red"),
(2, "fn", "blue"),
(3, "fn", "green"),
(4, "aa", "blue"),
(5, "aa", "green"),
(6, "bb", "red"),
(7, "bb", "red"),
(8, "aa", "blue")
).toDF("id", "dept", "color")
Now let's pick only the dept values that have at least one row with the color red and place them in a broadcast variable, as below.
val depts = sc.broadcast(df.filter($"color" === "red").select(collect_set("dept")).first.getSeq[String](0))
Update the color to red for the records of the filtered depts.
isin() takes varargs, so expand the list with depts.value:_*
// create the new column under a different name (clr) to make the difference visible
val result = df.withColumn("clr", when($"dept".isin(depts.value:_*), lit("red"))
  .otherwise($"color"))
result.show()
+---+----+-----+-----+
| id|dept|color| clr|
+---+----+-----+-----+
| 1| fn| red| red|
| 2| fn| blue| red|
| 3| fn|green| red|
| 4| aa| blue| blue|
| 5| aa|green|green|
| 6| bb| red| red|
| 7| bb| red| red|
| 8| aa| blue| blue|
+---+----+-----+-----+
You are conditionally updating the DataFrame if it satisfies a certain property. In this case the property is "the color column contains 'red'". The idiomatic way to express this is to filter with the desired predicate and then determine whether any rows satisfy it. There is no need for a join.
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.DataFrame
def makeAllRedIfAnyAreRed(df: DataFrame): DataFrame = {
  // true if at least one row has color "red"
  val containsRed = df.filter(df("color") === "red").count() > 0
  if (containsRed) df.withColumn("color", lit("red")) else df
}
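Since only the existence of a matching row matters, one possible refinement is to ask for a single row instead of counting them all. The sketch below is a hypothetical variant (the name makeAllRedIfAnyAreRedEarlyExit and the two small test frames are made up here); like the function above, it treats the DataFrame as a whole rather than per dept.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.lit

// Hypothetical variant: head(1) fetches at most one matching row instead of a full count
def makeAllRedIfAnyAreRedEarlyExit(df: DataFrame): DataFrame = {
  val containsRed = df.filter(df("color") === "red").head(1).nonEmpty
  if (containsRed) df.withColumn("color", lit("red")) else df
}

// In spark-shell, getOrCreate() simply returns the existing session
val spark = SparkSession.builder().master("local[*]").appName("early-exit-sketch").getOrCreate()
import spark.implicits._

val noRed = Seq((1, "aa", "blue"), (2, "aa", "green")).toDF("id", "dept", "color")
val someRed = Seq((1, "fn", "red"), (2, "fn", "blue")).toDF("id", "dept", "color")

makeAllRedIfAnyAreRedEarlyExit(noRed).show()   // colors are left unchanged
makeAllRedIfAnyAreRedEarlyExit(someRed).show() // every color becomes red
```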
An efficient solution that doesn't require expensive grouping:
// All groups with `red`
df.where($"color" === "red").select($"fn".alias("fn_")).distinct
  // Join with input
  .join(df.as("df"), $"fn" === $"fn_", "rightouter")
  // Replace `color`
  .withColumn("color", when($"fn_".isNull, $"color").otherwise(lit("red")))
  .drop("fn_")
Spark 2.2.0: Sample DataFrame (taken from the examples above)
val df = Seq(
(1, "fn", "red"),
(2, "fn", "blue"),
(3, "fn", "green"),
(4, "aa", "blue"),
(5, "aa", "green"),
(6, "bb", "red"),
(7, "bb", "red"),
(8, "aa", "blue")
).toDF("id", "dept", "color")
Create a UDF that performs the update by checking the condition.
import org.apache.spark.sql.functions.udf

// Sets color to "red" for rows of the hard-coded dept "fn"; null-safe for both arguments
val replace_val = udf((dept: String, color: String) =>
  if ("fn".equalsIgnoreCase(dept) && !"red".equalsIgnoreCase(color)) "red" else color)
val final_df = df.withColumn("color", replace_val($"dept", $"color"))
final_df.show()
output:
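Spark 1.6:
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions.udf

val conf = new SparkConf().setMaster("local").setAppName("My app")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
// For implicit conversions like converting RDDs to DataFrames
import sqlContext.implicits._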
val df = sc.parallelize(Seq(
(1, "fn", "red"),
(2, "fn", "blue"),
(3, "fn", "green"),
(4, "aa", "blue"),
(5, "aa", "green"),
(6, "bb", "red"),
(7, "bb", "red"),
(8, "aa", "blue")
) ).toDF("id","dept","color")
// Sets color to "red" for rows of the hard-coded dept "fn"; null-safe for both arguments
val replace_val = udf((dept: String, color: String) =>
  if ("fn".equalsIgnoreCase(dept) && !"red".equalsIgnoreCase(color)) "red" else color)
val final_df = df.withColumn("color", replace_val($"dept", $"color"))
final_df.show()