How to pivot Spark DataFrame?

闹比i 2020-11-21 06:43

I am starting to use Spark DataFrames and I need to be able to pivot the data to create multiple columns out of one column with multiple rows. There is built-in functionality for that in Scalding and I believe in Pandas in Python, but I can't find anything for the new Spark DataFrame.

10 Answers
  •  广开言路
    2020-11-21 06:58

    As mentioned by David Anderson, Spark provides a pivot function since version 1.6. The general syntax looks as follows:

    df
      .groupBy(grouping_columns)
      .pivot(pivot_column, [values]) 
      .agg(aggregate_expressions)
    

    Usage examples using the nycflights13 data in CSV format:

    Python:

    from pyspark.sql.functions import avg
    
    flights = (sqlContext
        .read
        .format("csv")
        .options(inferSchema="true", header="true")
        .load("flights.csv")
        .na.drop())
    
    flights.registerTempTable("flights")
    sqlContext.cacheTable("flights")
    
    gexprs = ("origin", "dest", "carrier")
    aggexpr = avg("arr_delay")
    
    flights.count()
    ## 336776
    
    %timeit -n10 flights.groupBy(*gexprs).pivot("hour").agg(aggexpr).count()
    ## 10 loops, best of 3: 1.03 s per loop
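
    The snippet above uses the Spark 1.x entry points (sqlContext, registerTempTable). On Spark 2.0 and later the same pivot can be written against a SparkSession; a minimal sketch, assuming the same flights.csv file:

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import avg

    spark = SparkSession.builder.getOrCreate()

    flights = (spark.read
        .format("csv")
        .options(inferSchema="true", header="true")
        .load("flights.csv")
        .na.drop())

    # createOrReplaceTempView is the Spark 2.x replacement for registerTempTable
    flights.createOrReplaceTempView("flights")

    pivoted = (flights
        .groupBy("origin", "dest", "carrier")
        .pivot("hour")
        .agg(avg("arr_delay")))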
    

    Scala:

    val flights = sqlContext
      .read
      .format("csv")
      .options(Map("inferSchema" -> "true", "header" -> "true"))
      .load("flights.csv")
    
    flights
      .groupBy($"origin", $"dest", $"carrier")
      .pivot("hour")
      .agg(avg($"arr_delay"))
    

    Java:

    import static org.apache.spark.sql.functions.*;
    import org.apache.spark.sql.*;
    
    Dataset<Row> df = spark.read().format("csv")
            .option("inferSchema", "true")
            .option("header", "true")
            .load("flights.csv");
    
    df.groupBy(col("origin"), col("dest"), col("carrier"))
            .pivot("hour")
            .agg(avg(col("arr_delay")));
    

    R / SparkR:

    library(magrittr)
    
    flights <- read.df("flights.csv", source="csv", header=TRUE, inferSchema=TRUE)
    
    flights %>% 
      groupBy("origin", "dest", "carrier") %>% 
      pivot("hour") %>% 
      agg(avg(column("arr_delay")))
    

    R / sparklyr

    library(dplyr)
    
    flights <- spark_read_csv(sc, "flights", "flights.csv")
    
    avg.arr.delay <- function(gdf) {
      expr <- invoke_static(
        sc,
        "org.apache.spark.sql.functions",
        "avg",
        "arr_delay"
      )
      gdf %>% invoke("agg", expr, list())
    }
    
    flights %>% 
      sdf_pivot(origin + dest + carrier ~ hour, fun.aggregate = avg.arr.delay)
    

    SQL:

    Note that the PIVOT keyword in Spark SQL is supported starting from version 2.4.

    CREATE TEMPORARY VIEW flights 
    USING csv 
    OPTIONS (header 'true', path 'flights.csv', inferSchema 'true') ;
    
     SELECT * FROM (
       SELECT origin, dest, carrier, arr_delay, hour FROM flights
     ) PIVOT (
       avg(arr_delay)
       FOR hour IN (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
                    13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23)
     );
    

    Example data:

    "year","month","day","dep_time","sched_dep_time","dep_delay","arr_time","sched_arr_time","arr_delay","carrier","flight","tailnum","origin","dest","air_time","distance","hour","minute","time_hour"
    2013,1,1,517,515,2,830,819,11,"UA",1545,"N14228","EWR","IAH",227,1400,5,15,2013-01-01 05:00:00
    2013,1,1,533,529,4,850,830,20,"UA",1714,"N24211","LGA","IAH",227,1416,5,29,2013-01-01 05:00:00
    2013,1,1,542,540,2,923,850,33,"AA",1141,"N619AA","JFK","MIA",160,1089,5,40,2013-01-01 05:00:00
    2013,1,1,544,545,-1,1004,1022,-18,"B6",725,"N804JB","JFK","BQN",183,1576,5,45,2013-01-01 05:00:00
    2013,1,1,554,600,-6,812,837,-25,"DL",461,"N668DN","LGA","ATL",116,762,6,0,2013-01-01 06:00:00
    2013,1,1,554,558,-4,740,728,12,"UA",1696,"N39463","EWR","ORD",150,719,5,58,2013-01-01 05:00:00
    2013,1,1,555,600,-5,913,854,19,"B6",507,"N516JB","EWR","FLL",158,1065,6,0,2013-01-01 06:00:00
    2013,1,1,557,600,-3,709,723,-14,"EV",5708,"N829AS","LGA","IAD",53,229,6,0,2013-01-01 06:00:00
    2013,1,1,557,600,-3,838,846,-8,"B6",79,"N593JB","JFK","MCO",140,944,6,0,2013-01-01 06:00:00
    2013,1,1,558,600,-2,753,745,8,"AA",301,"N3ALAA","LGA","ORD",138,733,6,0,2013-01-01 06:00:00
    

    Performance considerations:

    Generally speaking, pivoting is an expensive operation.

    • if you can, try to provide a values list, as this avoids an extra pass over the data to compute the distinct pivot values:

      vs = list(range(25))
      %timeit -n10 flights.groupBy(*gexprs).pivot("hour", vs).agg(aggexpr).count()
      ## 10 loops, best of 3: 392 ms per loop
      
    • in some cases it has proved beneficial (though likely no longer worth the effort in Spark 2.0 or later) to repartition and/or pre-aggregate the data before pivoting; see the first sketch after this list

    • for reshaping only, when each group holds at most one value, you can use first as the aggregate (see Pivot String column on Pyspark Dataframe and the second sketch after this list)
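
    A minimal PySpark sketch of the pre-aggregation idea from the bullet above, assuming the flights DataFrame from the earlier examples; keeping sums and counts (instead of averaging averages) keeps the final average exact:

    from pyspark.sql.functions import count, sum as sum_

    # Reduce to one row per (grouping columns, pivot column) combination,
    # so the subsequent pivot runs over a much smaller intermediate DataFrame.
    pre = (flights
        .groupBy("origin", "dest", "carrier", "hour")
        .agg(sum_("arr_delay").alias("delay_sum"),
             count("arr_delay").alias("delay_cnt")))

    pivoted = (pre
        .groupBy("origin", "dest", "carrier")
        .pivot("hour", list(range(25)))
        .agg((sum_("delay_sum") / sum_("delay_cnt")).alias("avg_arr_delay")))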

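    And a minimal sketch of the first-based reshape from the last bullet, assuming a hypothetical long-format DataFrame long_df with columns id, key and value and at most one value per (id, key) pair:

    from pyspark.sql.functions import first

    # With at most one value per (id, key) pair, first() simply moves the
    # value into the corresponding column -- no real aggregation happens.
    wide_df = (long_df
        .groupBy("id")
        .pivot("key")
        .agg(first("value")))
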
    Related questions:

    • How to melt Spark DataFrame?
    • Unpivot in spark-sql/pyspark
    • Transpose column to row with Spark
