Comparing two array columns in Scala Spark

♀尐吖头ヾ submitted on 2019-12-24 10:36:45

Question


I have a DataFrame in the format shown below.

movieId1 | genreList1              | genreList2
--------------------------------------------------
1        |[Adventure,Comedy]       |[Adventure]
2        |[Animation,Drama,War]    |[War,Drama]
3        |[Adventure,Drama]        |[Drama,War]

I am trying to create a flag column that shows whether genreList2 is a subset of genreList1:

movieId1 | genreList1              | genreList2        | Flag
---------------------------------------------------------------
1        |[Adventure,Comedy]       | [Adventure]       |1
2        |[Animation,Drama,War]    | [War,Drama]       |1
3        |[Adventure,Drama]        | [Drama,War]       |0

I have tried this:

def intersect_check(a: Array[String], b: Array[String]): Int = {
  if (b.sameElements(a.intersect(b))) { return 1 } 
  else { return 2 }
}

def intersect_check_udf =
  udf((colvalue1: Array[String], colvalue2: Array[String]) => intersect_check(colvalue1, colvalue2))

data = data.withColumn("Flag", intersect_check_udf(col("genreList1"), col("genreList2")))

But this throws org.apache.spark.SparkException: Failed to execute user defined function. Any ideas on how to resolve this? P.S.: The function above (intersect_check) works when called directly on plain Scala Arrays.


Answer 1:


We can define a UDF that computes the intersection of the two array columns and checks whether its length equals the length of the second column. If so, every element of the second array also appears in the first, so the second array is a subset of the first.

Also, the inputs of your UDF need to be of type WrappedArray[String], not Array[String], because that is how Spark hands array columns to a Scala UDF:

import scala.collection.mutable.WrappedArray
import org.apache.spark.sql.functions.{col, udf}

val same_elements = udf { (a: WrappedArray[String], 
                           b: WrappedArray[String]) => 
  if (a.intersect(b).length == b.length) 1 else 0
}

df.withColumn("test",same_elements(col("genreList1"),col("genreList2")))
  .show(truncate = false)
+--------+-----------------------+------------+----+
|movieId1|genreList1             |genreList2  |test|
+--------+-----------------------+------------+----+
|1       |[Adventure, Comedy]    |[Adventure] |1   |
|2       |[Animation, Drama, War]|[War, Drama]|1   |
|3       |[Adventure, Drama]     |[Drama, War]|0   |
+--------+-----------------------+------------+----+

Data

val df = List((1,Array("Adventure","Comedy"), Array("Adventure")),
              (2,Array("Animation","Drama","War"), Array("War","Drama")),
              (3,Array("Adventure","Drama"),Array("Drama","War"))).toDF("movieId1","genreList1","genreList2")
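
One caveat of the length comparison is worth illustrating: Scala's intersect keeps duplicates (multiset semantics), so the check can return 0 when the second array repeats an element that the first array contains only once. The following plain-Scala sketch needs no Spark. The object and method names are hypothetical and chosen only for illustration. It compares the length-based check with a set-based one on ordinary sequences:

```scala
// Plain-Scala sketch (no Spark required). The names SubsetChecks,
// byIntersectLength and bySet are illustrative, not from the original thread.
object SubsetChecks {
  // Length-based check: compare the intersection's length with b's length.
  def byIntersectLength(a: Seq[String], b: Seq[String]): Int =
    if (a.intersect(b).length == b.length) 1 else 0

  // Set-based check: duplicates and element order are ignored.
  def bySet(a: Seq[String], b: Seq[String]): Int =
    if (b.toSet.subsetOf(a.toSet)) 1 else 0

  def main(args: Array[String]): Unit = {
    val genres = Seq("Adventure", "Comedy")

    // Without duplicates, both checks agree.
    println(byIntersectLength(genres, Seq("Adventure"))) // 1
    println(bySet(genres, Seq("Adventure")))             // 1

    // With a repeated element in the second list, they differ.
    val repeated = Seq("Adventure", "Adventure")
    println(byIntersectLength(genres, repeated)) // 0
    println(bySet(genres, repeated))             // 1
  }
}
```

If the genre lists are guaranteed to contain no duplicates, the two checks give the same result, and the choice is a matter of taste.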



Answer 2:


Here is a solution that converts both arrays to sets and uses subsetOf:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf

val spark =
  SparkSession.builder().master("local").appName("test").getOrCreate()

import spark.implicits._

val data = spark.sparkContext.parallelize(
  Seq(
    (1, Array("Adventure", "Comedy"), Array("Adventure")),
    (2, Array("Animation", "Drama", "War"), Array("War", "Drama")),
    (3, Array("Adventure", "Drama"), Array("Drama", "War"))
  )).toDF("movieId1", "genreList1", "genreList2")

// Array columns arrive in a Scala UDF as Seq[String], so declare them that way.
val subsetOf = udf((col1: Seq[String], col2: Seq[String]) => {
  if (col2.toSet.subsetOf(col1.toSet)) 1 else 0
})

data.withColumn("flag", subsetOf(data("genreList1"), data("genreList2"))).show()

Hope this helps!
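
As a possible alternative to a UDF, the same flag can be sketched with built-in column functions. This assumes Spark 2.4 or later, where array_except is available in org.apache.spark.sql.functions. The object name SubsetFlagBuiltin and the helper withSubsetFlag are hypothetical names used only for this example. The idea is that genreList2 is a subset of genreList1 exactly when removing genreList1's elements from genreList2 leaves an empty array:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{array_except, col, size, when}

// Sketch only: assumes Spark 2.4+ (for array_except). Names are illustrative.
object SubsetFlagBuiltin {
  // Adds a "Flag" column: 1 if every element of subCol occurs in superCol, else 0.
  def withSubsetFlag(df: DataFrame, superCol: String, subCol: String): DataFrame =
    df.withColumn(
      "Flag",
      // array_except(x, y) returns the elements of x that are not in y.
      when(size(array_except(col(subCol), col(superCol))) === 0, 1).otherwise(0)
    )

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("subset-flag").getOrCreate()
    import spark.implicits._

    val data = Seq(
      (1, Seq("Adventure", "Comedy"), Seq("Adventure")),
      (2, Seq("Animation", "Drama", "War"), Seq("War", "Drama")),
      (3, Seq("Adventure", "Drama"), Seq("Drama", "War"))
    ).toDF("movieId1", "genreList1", "genreList2")

    withSubsetFlag(data, "genreList1", "genreList2").show(truncate = false)
    spark.stop()
  }
}
```

Built-in functions stay visible to Spark's optimizer, which a UDF does not, so this variant may be preferable when the Spark version allows it. On older versions, the UDF approaches above remain the way to go.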



Source: https://stackoverflow.com/questions/44158623/comparing-two-array-columns-in-scala-spark
