How do I efficiently create a new column in a DataFrame that is a function of other rows in Spark?
The problem here is that I don't know a way to tell Spark to compare the current row to the other rows in the DataFrame without first collecting the values into a list.
A UDF is not an option here (you cannot reference a distributed DataFrame inside a UDF). A direct translation of your logic is a Cartesian product followed by an aggregate:
from pyspark.sql.functions import levenshtein, col

# Pair every row with every other row, keep close matches,
# and count the matches per (id, word).
result = (spark_df.alias("l")
    .crossJoin(spark_df.alias("r"))
    .where(levenshtein("l.word", "r.word") < 2)  # edit distance 0 or 1
    .where(col("l.word") != col("r.word"))       # drop identical words
    .groupBy("l.id", "l.word")
    .count())
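For illustration, here is a minimal, hypothetical input (the rows are made up; spark is assumed to be an active SparkSession) and the output the snippet above would produce for it:

spark_df = spark.createDataFrame(
    [(0, "hello"), (1, "hallo"), (2, "world")],
    ["id", "word"])

# After running the snippet above (row order may vary):
result.show()
# +---+-----+-----+
# | id| word|count|
# +---+-----+-----+
# |  0|hello|    1|
# |  1|hallo|    1|
# +---+-----+-----+
# "world" is within distance 1 of nothing, so it disappears entirely.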
But in practice you should try to do something more efficient: Efficient string matching in Apache Spark. Depending on the problem, you should try to find other approximations that avoid the full Cartesian product; one candidate-generation sketch follows.
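One such approximation (a sketch, not part of the original answer): for a threshold of Levenshtein distance < 2, two words can only match if they share a "deletion variant", i.e. the word itself or the word with one character removed. Generating those variants and equi-joining on them yields a much smaller candidate set than the full cross join, and the exact distance check then removes the few false positives. Note that a plain string-to-strings UDF is fine here; the earlier caveat was about referencing the distributed DataFrame inside a UDF. The id and word columns are assumed to match the schema above:

from pyspark.sql.functions import explode, levenshtein, col, udf
from pyspark.sql.types import ArrayType, StringType

# The word itself plus every single-character deletion. Any two words
# within Levenshtein distance 1 share at least one of these variants.
def deletion_variants(word):
    if word is None:
        return []
    return [word] + [word[:i] + word[i + 1:] for i in range(len(word))]

variants = udf(deletion_variants, ArrayType(StringType()))

keyed = spark_df.withColumn("key", explode(variants("word")))

approx_result = (keyed.alias("l")
    .join(keyed.alias("r"), "key")             # equi-join on a shared variant
    .select("l.id", "l.word",
            col("r.id").alias("r_id"), col("r.word").alias("r_word"))
    .distinct()                                # a pair can share several keys
    .where(levenshtein("word", "r_word") < 2)  # drop false positives
    .where(col("word") != col("r_word"))       # drop identical words
    .groupBy("id", "word")
    .count())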
If you want to keep data without matches, you can skip the inequality filter: every word then matches itself at distance 0, so subtract that one self match from the count:
(spark_df.alias("l")
    .crossJoin(spark_df.alias("r"))
    .where(levenshtein("l.word", "r.word") < 2)
    .groupBy("l.id", "l.word")
    .count()
    .withColumn("count", col("count") - 1))  # remove the self match
or (slower, but more generic), join back with the reference data:
(spark_df
    .select("id", "word")
    .distinct()
    .join(result, ["id", "word"], "left")  # unmatched rows get a null count
    .na.fill(0))                           # replace null counts with 0
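With the made-up hello/hallo/world data from earlier, this last variant would also keep the unmatched word:

# Expected output for the hypothetical sample data (row order may vary):
# +---+-----+-----+
# | id| word|count|
# +---+-----+-----+
# |  0|hello|    1|
# |  1|hallo|    1|
# |  2|world|    0|
# +---+-----+-----+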