I have the following df:
+---+----+-----+
|sno|dept|color|
+---+----+-----+
|  1|  fn|  red|
|  2|  fn| blue|
|  3|  fn|green|
+---+----+-----+
You want to update the DataFrame only if it satisfies a certain property. In this case the property is "the color column contains 'red'". The idiomatic way to express this is to filter with the desired predicate and then check whether any rows satisfy it. There is no need for a join.
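import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.lit

def makeAllRedIfAnyAreRed(df: DataFrame): DataFrame = {
  // Check whether at least one row has color == "red".
  val containsRed = df.filter(df("color") === "red").count() > 0
  // If so, overwrite the whole column with "red"; otherwise return df unchanged.
  if (containsRed) df.withColumn("color", lit("red")) else df
}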
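
For completeness, here is a minimal, self-contained sketch of how this could be run against the sample data from the question. It is a variant rather than the exact function above: it swaps `count() > 0` for `head(1).nonEmpty`, which asks Spark for at most one matching row and so may avoid counting every match. The object name `MakeAllRedExample`, the local master setting, and the app name are assumptions made for illustration only.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, lit}

object MakeAllRedExample {
  // Variant of the function above: head(1) fetches at most one matching row
  // instead of counting all of them.
  def makeAllRedIfAnyAreRed(df: DataFrame): DataFrame = {
    val containsRed = df.filter(col("color") === "red").head(1).nonEmpty
    if (containsRed) df.withColumn("color", lit("red")) else df
  }

  def main(args: Array[String]): Unit = {
    // Local session for illustration; master and app name are assumptions.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("make-all-red")
      .getOrCreate()
    import spark.implicits._

    // The sample data from the question.
    val df = Seq((1, "fn", "red"), (2, "fn", "blue"), (3, "fn", "green"))
      .toDF("sno", "dept", "color")

    // Contains a red row, so every row's color is set to "red".
    makeAllRedIfAnyAreRed(df).show()

    // No red rows left after this filter, so the DataFrame comes back unchanged.
    val noRed = df.filter(col("color") =!= "red")
    makeAllRedIfAnyAreRed(noRed).show()

    spark.stop()
  }
}
```

Either form of the check triggers a Spark job before the update is applied, so if `df` is expensive to compute it may be worth caching it first.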