How to retrieve all columns using pyspark collect_list functions

梦如初夏 2021-01-14 05:42

I am using PySpark 2.0.1. I'm trying to group my DataFrame and retrieve the values of all the fields in it. I found that

z=data1.groupby('


        
3 Answers
  • 2021-01-14 05:50

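    Use struct to combine the columns before calling groupBy.

    Suppose you have a DataFrame:

    import pyspark.sql.functions as f

    # Assumes `spark` and `sc` are already defined, as in the pyspark shell.
    df = spark.createDataFrame(sc.parallelize([(0,1,2),(0,4,5),(1,7,8),(1,8,7)])).toDF("a","b","c")
    
    # Pack the non-key columns into a single struct column.
    df = df.select("a", f.struct(["b","c"]).alias("newcol"))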
    df.show()
    +---+------+
    |  a|newcol|
    +---+------+
    |  0| [1,2]|
    |  0| [4,5]|
    |  1| [7,8]|
    |  1| [8,7]|
    +---+------+
    df = df.groupBy("a").agg(f.collect_list("newcol").alias("collected_col"))
    df.show()
    +---+--------------+
    |  a| collected_col|
    +---+--------------+
    |  0|[[1,2], [4,5]]|
    |  1|[[7,8], [8,7]]|
    +---+--------------+
    

    An aggregate function such as collect_list operates on a single column, which is why the columns are packed into one struct column first.

    After aggregation, you can either collect the result and iterate over it to separate the combined columns and build the index dict, or write a UDF to separate the combined columns:

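    from pyspark.sql.functions import col, udf
    from pyspark.sql.types import ArrayType, LongType, StructField, StructType

    def foo(x):
        # x is a list of (b, c) structs; split it into one list per field.
        x1 = [y[0] for y in x]
        x2 = [y[1] for y in x]
        return (x1, x2)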
    
    st = StructType([StructField("b", ArrayType(LongType())), StructField("c", ArrayType(LongType()))])
    udf_foo = udf(foo, st)
    df = (df.withColumn("ncol", udf_foo("collected_col"))
            .select("a",
                    col("ncol").getItem("b").alias("b"),
                    col("ncol").getItem("c").alias("c")))
    df.show()
    
    +---+------+------+
    |  a|     b|     c|
    +---+------+------+
    |  0|[1, 4]|[2, 5]|
    |  1|[7, 8]|[8, 7]|
    +---+------+------+
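
The other option mentioned above, collecting the result and iterating over it on the driver, might look roughly like the sketch below. The helper name `split_collected` and the shape of the resulting dict are assumptions for illustration, not part of the original answer. The `rows` list is a plain-Python stand-in for calling `.collect()` on the grouped DataFrame while it still has the `collected_col` column. PySpark `Row` objects are tuples, so they unpack the same way.

```python
from collections import defaultdict

def split_collected(rows):
    """Build {key: {"b": [...], "c": [...]}} from (key, [(b, c), ...]) rows."""
    index = defaultdict(lambda: {"b": [], "c": []})
    for key, pairs in rows:
        for b, c in pairs:
            index[key]["b"].append(b)
            index[key]["c"].append(c)
    return dict(index)

# Hypothetical stand-in for the collected rows, copied from the
# grouped output shown earlier in this answer.
rows = [(0, [(1, 2), (4, 5)]), (1, [(7, 8), (8, 7)])]
index = split_collected(rows)
print(index[0])  # {'b': [1, 4], 'c': [2, 5]}
```

This only makes sense when the grouped result is small enough to fit in driver memory; for large results the UDF keeps the work on the executors.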
    
  • 2021-01-14 05:53

    Actually, we can do it in PySpark 2.2.

    First, create a constant column ("Temp"), group by that column, and call agg with the unpacked list *exprs, which holds one collect_list expression per column.

    Below is the code:

    import pyspark.sql.functions as ftions
    
    def groupColumnData(df, columns):
        # Group every row under one constant key so each column collapses into a single list.
        df = df.withColumn("Temp", ftions.lit(1))
        exprs = [ftions.collect_list(colName) for colName in columns]
        df = df.groupby('Temp').agg(*exprs)
        df = df.drop("Temp")
        # Rename collect_list(a), collect_list(b), ... back to the original column names.
        df = df.toDF(*columns)
        return df
    

    Input Data:

    df.show()
    +---+---+---+
    |  a|  b|  c|
    +---+---+---+
    |  0|  1|  2|
    |  0|  4|  5|
    |  1|  7|  8|
    |  1|  8|  7|
    +---+---+---+
    

    Output Data:

    groupColumnData(df, ["a", "b", "c"]).show()
    +------------+------------+------------+
    |           a|           b|           c|
    +------------+------------+------------+
    |[0, 0, 1, 1]|[1, 4, 7, 8]|[2, 5, 8, 7]|
    +------------+------------+------------+
    
  • 2021-01-14 06:13

    This works in Spark 2.4.4 with Python 3.7 (I guess it also applies to earlier Spark and Python versions). My suggestion is based on pauli's answer: instead of creating the struct and then calling agg, create the struct inside collect_list:

    from pyspark.sql.functions import collect_list, struct

    df = spark.createDataFrame([(0,1,2),(0,4,5),(1,7,8),(1,8,7)]).toDF("a","b","c")
    df.groupBy("a").agg(collect_list(struct(["b","c"])).alias("res")).show()
    

    Result:

    +---+-----------------+
    |  a|res              |
    +---+-----------------+
    |  0|[[1, 2], [4, 5]] |
    |  1|[[7, 8], [8, 7]] |
    +---+-----------------+
    