How to pivot Spark DataFrame?


Question:

I am starting to use Spark DataFrames and I need to be able to pivot the data to create multiple columns out of one column with multiple rows. There is built-in functionality for that in Scalding, and I believe in Pandas in Python, but I can't find anything for the new Spark DataFrame.

I assume I can write a custom function of some sort that will do this, but I'm not even sure how to start, especially since I am a novice with Spark. If anyone knows how to do this with built-in functionality, or has suggestions for how to write something in Scala, it would be greatly appreciated.

Answer 1:

As mentioned by @user2000823, Spark has provided a pivot function since version 1.6. The general syntax looks as follows:

df
  .groupBy(grouping_columns)
  .pivot(pivot_column, [values])
  .agg(aggregate_expressions)

Usage examples, using the nycflights13 data in csv format:

Python:

from pyspark.sql.functions import avg

flights = (sqlContext
    .read
    .format("csv")
    .options(inferSchema="true", header="true")
    .load("flights.csv")
    .na.drop())

flights.registerTempTable("flights")
sqlContext.cacheTable("flights")

gexprs = ("origin", "dest", "carrier")
aggexpr = avg("arr_delay")

flights.count()
## 336776

%timeit -n10 flights.groupBy(*gexprs).pivot("hour").agg(aggexpr).count()
## 10 loops, best of 3: 1.03 s per loop

Scala:

val flights = sqlContext
  .read
  .format("csv")
  .options(Map("inferSchema" -> "true", "header" -> "true"))
  .load("flights.csv")

flights
  .groupBy($"origin", $"dest", $"carrier")
  .pivot("hour")
  .agg(avg($"arr_delay"))

Java:

import static org.apache.spark.sql.functions.*;
import org.apache.spark.sql.*;

Dataset<Row> df = spark.read().format("csv")
        .option("inferSchema", "true")
        .option("header", "true")
        .load("flights.csv");

df.groupBy(col("origin"), col("dest"), col("carrier"))
        .pivot("hour")
        .agg(avg(col("arr_delay")));

R / SparkR:

library(magrittr)

flights %>%
  groupBy("origin", "dest", "carrier") %>%
  pivot("hour") %>%
  agg(avg(column("arr_delay")))

R / sparklyr:

library(dplyr)

flights <- copy_to(sc, flights)

# define a custom aggregate that calls Spark's avg() on arr_delay via the JVM
avg.arr.delay <- function(gdf) {
  expr <- invoke_static(
    sc,
    "org.apache.spark.sql.functions",
    "avg",
    "arr_delay"
  )
  gdf %>% invoke("agg", expr, list())
}

flights %>%
  sdf_pivot(origin + dest + carrier ~ hour, fun.aggregate = avg.arr.delay)

Example data:

"year","month","day","dep_time","sched_dep_time","dep_delay","arr_time","sched_arr_time","arr_delay","carrier","flight","tailnum","origin","dest","air_time","distance","hour","minute","time_hour" 2013,1,1,517,515,2,830,819,11,"UA",1545,"N14228","EWR","IAH",227,1400,5,15,2013-01-01 05:00:00 2013,1,1,533,529,4,850,830,20,"UA",1714,"N24211","LGA","IAH",227,1416,5,29,2013-01-01 05:00:00 2013,1,1,542,540,2,923,850,33,"AA",1141,"N619AA","JFK","MIA",160,1089,5,40,2013-01-01 05:00:00 2013,1,1,544,545,-1,1004,1022,-18,"B6",725,"N804JB","JFK","BQN",183,1576,5,45,2013-01-01 05:00:00 2013,1,1,554,600,-6,812,837,-25,"DL",461,"N668DN","LGA","ATL",116,762,6,0,2013-01-01 06:00:00 2013,1,1,554,558,-4,740,728,12,"UA",1696,"N39463","EWR","ORD",150,719,5,58,2013-01-01 05:00:00 2013,1,1,555,600,-5,913,854,19,"B6",507,"N516JB","EWR","FLL",158,1065,6,0,2013-01-01 06:00:00 2013,1,1,557,600,-3,709,723,-14,"EV",5708,"N829AS","LGA","IAD",53,229,6,0,2013-01-01 06:00:00 2013,1,1,557,600,-3,838,846,-8,"B6",79,"N593JB","JFK","MCO",140,944,6,0,2013-01-01 06:00:00 2013,1,1,558,600,-2,753,745,8,"AA",301,"N3ALAA","LGA","ORD",138,733,6,0,2013-01-01 06:00:00 

Performance considerations:

Generally speaking, pivoting is an expensive operation.
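One common mitigation, hinted at by the optional values argument in the general syntax above, is to pass the list of pivot values explicitly so that Spark does not have to run an extra job just to collect the distinct values of the pivot column. A minimal Scala sketch, reusing the flights DataFrame from the examples above (the hour range is an assumption about the data):

// Passing the pivot values up front skips the extra pass that would otherwise
// compute the distinct values of "hour" before the pivot itself runs.
val hours = 0 to 23  // assumed to cover every hour present in the data

flights
  .groupBy($"origin", $"dest", $"carrier")
  .pivot("hour", hours)
  .agg(avg($"arr_delay"))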



Answer 2:

I overcame this by writing a for loop to dynamically create a SQL query. Say I have:

id  tag  value
1   US    50
1   UK    100
1   Can   125
2   US    75
2   UK    150
2   Can   175

and I want:

id  US  UK   Can
1   50  100  125
2   75  150  175

I can create a list with the values I want to pivot on and then create a string containing the SQL query I need.

val countries = List("US", "UK", "Can")
val numCountries = countries.length - 1

var query = "select *, "
for (i

I can create a similar query to then do the aggregation. It's not a very elegant solution, but it works and is flexible for any list of values, which can also be passed in as an argument when your code is called.
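For completeness, here is a minimal sketch of how both queries could be assembled; the DataFrame df, the sqlContext, and the table names myTable and pivoted are assumptions for illustration rather than the original code:

// Hedged sketch: one "case when" projection per country, then a sum per id.
val countries = List("US", "UK", "Can")

val projections = countries
  .map(c => s"case when tag = '$c' then value else 0 end as $c")
  .mkString(", ")

df.registerTempTable("myTable")                       // df holds id, tag, value
val pivoted = sqlContext.sql(s"select id, $projections from myTable")
pivoted.registerTempTable("pivoted")

// The follow-up aggregation collapses the one-row-per-tag layout to one row per id.
val sums = countries.map(c => s"sum($c) as $c").mkString(", ")
sqlContext.sql(s"select id, $sums from pivoted group by id").show()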



Answer 3:

A pivot operator has been added to the Spark DataFrame API, and is part of Spark 1.6.

See https://github.com/apache/spark/pull/7841 for details.
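For reference, a minimal sketch of the built-in operator applied to the id/tag/value example used elsewhere in this thread (assuming that data is already loaded in a DataFrame named df):

import org.apache.spark.sql.functions.sum

// Pivot the distinct values of "tag" into columns, summing "value" per (id, tag).
df.groupBy("id")
  .pivot("tag")
  .agg(sum("value"))
  .show()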



Answer 4:

I have solved a similar problem using dataframes with the following steps:

Create columns for all your countries, with 'value' as the value:

import org.apache.spark.sql.functions._

val countries = List("US", "UK", "Can")
val countryValue = udf { (countryToCheck: String, countryInRow: String, value: Long) =>
  if (countryToCheck == countryInRow) value else 0
}
val countryFuncs = countries.map { country =>
  (dataFrame: DataFrame) => dataFrame.withColumn(country, countryValue(lit(country), dataFrame("tag"), dataFrame("value")))
}
val dfWithCountries = Function.chain(countryFuncs)(df).drop("tag").drop("value")

Your dataframe 'dfWithCountries' will look like this:

+--+--+---+---+
|id|US| UK|Can|
+--+--+---+---+
| 1|50|  0|  0|
| 1| 0|100|  0|
| 1| 0|  0|125|
| 2|75|  0|  0|
| 2| 0|150|  0|
| 2| 0|  0|175|
+--+--+---+---+

Now you can sum together all the values for your desired result:

dfWithCountries.groupBy("id").sum(countries: _*).show 

Result:

+--+-------+-------+--------+
|id|SUM(US)|SUM(UK)|SUM(Can)|
+--+-------+-------+--------+
| 1|     50|    100|     125|
| 2|     75|    150|     175|
+--+-------+-------+--------+

It's not a very elegant solution, though. I had to create a chain of functions to add all the columns. Also, if I have lots of countries, I will expand my temporary data set into a very wide set with lots of zeroes.



Answer 5:

Initially I adopted Al M's solution. Later, I took the same approach and rewrote it as a transpose function.

This method transposes the rows of any DataFrame to columns, for any data format, using a key column and a value column.

For the input CSV:

id,tag,value
1,US,50a
1,UK,100
1,Can,125
2,US,75
2,UK,150
2,Can,175

Output:

+--+---+---+---+
|id| UK| US|Can|
+--+---+---+---+
| 2|150| 75|175|
| 1|100|50a|125|
+--+---+---+---+

Transpose method:

def transpose(hc: HiveContext, df: DataFrame, compositeId: List[String], key: String, value: String) = {

  val distinctCols = df.select(key).distinct.map { r => r(0) }.collect().toList

  val rdd = df.map { row =>
    (compositeId.collect { case id => row.getAs(id).asInstanceOf[Any] },
     scala.collection.mutable.Map(row.getAs(key).asInstanceOf[Any] -> row.getAs(value).asInstanceOf[Any]))
  }
  val pairRdd = rdd.reduceByKey(_ ++ _)
  val rowRdd = pairRdd.map(r => dynamicRow(r, distinctCols))
  hc.createDataFrame(rowRdd, getSchema(df.schema, compositeId, (key, distinctCols)))

}

private def dynamicRow(r: (List[Any], scala.collection.mutable.Map[Any, Any]), colNames: List[Any]) = {
  val cols = colNames.collect { case col => r._2.getOrElse(col.toString(), null) }
  val array = r._1 ++ cols
  Row(array: _*)
}

private def getSchema(srcSchema: StructType, idCols: List[String], distinctCols: (String, List[Any])): StructType = {
  val idSchema = idCols.map { idCol => srcSchema.apply(idCol) }
  val colSchema = srcSchema.apply(distinctCols._1)
  val colsSchema = distinctCols._2.map { col => StructField(col.asInstanceOf[String], colSchema.dataType, colSchema.nullable) }
  StructType(idSchema ++ colsSchema)
}

Main snippet:

import java.util.Date
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.types.StructField

...
...

def main(args: Array[String]): Unit = {

    val sc = new SparkContext(conf)
    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    val dfdata1 = sqlContext.read.format("com.databricks.spark.csv")
      .option("header", "true")
      .option("inferSchema", "true")
      .load("data.csv")
    dfdata1.show()

    val dfOutput = transpose(new HiveContext(sc), dfdata1, List("id"), "tag", "value")
    dfOutput.show

}


Answer 6:

There is a simple and elegant solution.

scala> spark.sql("select * from k_tags limit 10").show()
+---------------+-------------+------+
|           imsi|         name| value|
+---------------+-------------+------+
|246021000000000|          age|    37|
|246021000000000|       gender|Female|
|246021000000000|         arpu|    22|
|246021000000000|   DeviceType| Phone|
|246021000000000|DataAllowance|   6GB|
+---------------+-------------+------+

scala> spark.sql("select * from k_tags limit 10").groupBy($"imsi").pivot("name").agg(min($"value")).show()
+---------------+-------------+----------+---+----+------+
|           imsi|DataAllowance|DeviceType|age|arpu|gender|
+---------------+-------------+----------+---+----+------+
|246021000000000|          6GB|     Phone| 37|  22|Female|
|246021000000001|          1GB|     Phone| 72|  10|  Male|
+---------------+-------------+----------+---+----+------+

