Converting row values into a column array in spark dataframe


Question


I am working with Spark DataFrames and need to group by a column, collecting the column values of the grouped rows into an array as a new column. Example:

Input:

employee | Address
------------------
Micheal  |  NY
Micheal  |  NJ

Output:

employee | Address
------------------
Micheal  | (NY,NJ)

Any help is highly appreciated!


Answer 1:


Here is an alternative solution where I convert the DataFrame to an RDD for the transformations and then convert it back to a DataFrame using sqlContext.createDataFrame().

Sample.json

{"employee":"Michale","Address":"NY"}
{"employee":"Michale","Address":"NJ"}
{"employee":"Sam","Address":"NY"}
{"employee":"Max","Address":"NJ"}

Spark Application

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

val df = sqlContext.read.json("sample.json")

// Printing the original Df
df.show()

// Defining the schema for the aggregated DataFrame
val dataSchema = new StructType(
  Array(
    StructField("employee", StringType, nullable = true),
    StructField("Address", ArrayType(StringType, containsNull = true), nullable = true)
  )
)
// Converting the DataFrame to an RDD and performing the groupBy operation
val aggregatedRdd: RDD[Row] = df.rdd
  .groupBy(r => r.getAs[String]("employee"))
  .map { case (employee, rows) =>
    // Mapping the grouped values to a new Row object
    Row(employee, rows.map(_.getAs[String]("Address")).toArray)
  }

// Creating a DataFrame from the aggregatedRdd with the defined Schema (dataSchema)
val aggregatedDf = sqlContext.createDataFrame(aggregatedRdd, dataSchema)

// Printing the aggregated Df
aggregatedDf.show()

Output :

+-------+--------+
|Address|employee|
+-------+--------+
|     NY| Michale|
|     NJ| Michale|
|     NY|     Sam|
|     NJ|     Max|
+-------+--------+

+--------+--------+
|employee| Address|
+--------+--------+
|     Sam|    [NY]|
| Michale|[NY, NJ]|
|     Max|    [NJ]|
+--------+--------+
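
As a quick sanity check (a minimal sketch; printSchema is standard Spark API), you can confirm that Address was aggregated into an array column:

// Inspect the schema of the aggregated DataFrame
aggregatedDf.printSchema()
// root
//  |-- employee: string (nullable = true)
//  |-- Address: array (nullable = true)
//  |    |-- element: string (containsNull = true)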



Answer 2:


If you're using Spark 2.0+, you can use collect_list or collect_set. Your query will be something like the following (assuming your DataFrame is called input):

import org.apache.spark.sql.functions._

input.groupBy('employee).agg(collect_list('Address))

If you are ok with duplicates, use collect_list. If you're not ok with duplicates and only need unique items in the list, use collect_set.
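
For example, a minimal variant of the same query that drops duplicate addresses (assuming the same input DataFrame) would be:

// collect_set keeps only the distinct addresses per employee
input.groupBy('employee).agg(collect_set('Address).alias("Address"))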

Hope this helps!



Source: https://stackoverflow.com/questions/36332232/converting-row-values-into-a-column-array-in-spark-dataframe
