How to split a list to multiple columns in Pyspark?

余生分开走 2020-11-27 19:21

I have:

key   value
a    [1,2,3]
b    [2,3,4]

I want:

key value1 value2 value3
a     1      2      3
b     2      3      4
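
For reference, a minimal sketch to reproduce this input as a PySpark DataFrame (the column names and an existing SparkSession called spark are assumptions):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# "key" becomes a string column and "value" an array column.
df = spark.createDataFrame([("a", [1, 2, 3]), ("b", [2, 3, 4])], ["key", "value"])
df.show()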
         


        
3 Answers
  • 2020-11-27 20:01

    I'd like to add the case of fixed-length lists (arrays) to pault's answer.

    If the column contains medium-sized arrays (or even large ones), it is still possible to split them into columns.

    from pyspark.sql.types import *          # Needed to define DataFrame Schema.
    from pyspark.sql.functions import expr   
    
    # Define schema to create DataFrame with an array typed column.
    mySchema = StructType([StructField("V1", StringType(), True),
                           StructField("V2", ArrayType(IntegerType(),True))])
    
    df = spark.createDataFrame([['A', [1, 2, 3, 4, 5, 6, 7]], 
                                ['B', [8, 7, 6, 5, 4, 3, 2]]], schema= mySchema)
    
    # Split the array into columns using expr() in a list comprehension.
    arr_size = 7
    df = df.select(['V1', 'V2']+[expr('V2[' + str(x) + ']') for x in range(0, arr_size)])
    
    # It is possible to define new column names.
    new_colnames = ['V1', 'V2'] + ['val_' + str(i) for i in range(0, arr_size)] 
    df = df.toDF(*new_colnames)
    

    The result is:

    df.show(truncate=False)
    
    +---+---------------------+-----+-----+-----+-----+-----+-----+-----+
    |V1 |V2                   |val_0|val_1|val_2|val_3|val_4|val_5|val_6|
    +---+---------------------+-----+-----+-----+-----+-----+-----+-----+
    |A  |[1, 2, 3, 4, 5, 6, 7]|1    |2    |3    |4    |5    |6    |7    |
    |B  |[8, 7, 6, 5, 4, 3, 2]|8    |7    |6    |5    |4    |3    |2    |
    +---+---------------------+-----+-----+-----+-----+-----+-----+-----+
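
    If the array length varies or is not known in advance, arr_size does not have to be hard-coded. A rough sketch (an assumption on my part, applied to the original two-column df before the extra columns and the rename above) derives it from the longest array in the data:

    from pyspark.sql import functions as F

    # Derive the number of columns to generate from the longest array in 'V2'.
    arr_size = df.agg(F.max(F.size("V2"))).collect()[0][0]

    df = df.select(
        ["V1", "V2"]
        + [F.expr("V2[" + str(i) + "]").alias("val_" + str(i)) for i in range(arr_size)]
    )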
    
  • 2020-11-27 20:05

    It depends on the type of your "list" (a quick way to check which case you have is sketched after the two cases below):

    • If it is of type ArrayType():

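      # 'hc' is presumably a SQLContext/HiveContext and 'sc' a SparkContext;
      # on Spark 2+, spark.createDataFrame(...) works the same way.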
      df = hc.createDataFrame(sc.parallelize([['a', [1,2,3]], ['b', [2,3,4]]]), ["key", "value"])
      df.printSchema()
      df.show()
      root
       |-- key: string (nullable = true)
       |-- value: array (nullable = true)
       |    |-- element: long (containsNull = true)
      
      +---+-------+
      |key|  value|
      +---+-------+
      |  a|[1,2,3]|
      |  b|[2,3,4]|
      +---+-------+
      

      you can access the values like you would in Python, using []:

      df.select("key", df.value[0], df.value[1], df.value[2]).show()
      +---+--------+--------+--------+
      |key|value[0]|value[1]|value[2]|
      +---+--------+--------+--------+
      |  a|       1|       2|       3|
      |  b|       2|       3|       4|
      +---+--------+--------+--------+
      
    • If it is of type StructType() (for example, if you built your dataframe by reading JSON):

      import pyspark.sql.functions as psf

      df2 = df.select("key", psf.struct(
              df.value[0].alias("value1"), 
              df.value[1].alias("value2"), 
              df.value[2].alias("value3")
          ).alias("value"))
      df2.printSchema()
      df2.show()
      root
       |-- key: string (nullable = true)
       |-- value: struct (nullable = false)
       |    |-- value1: long (nullable = true)
       |    |-- value2: long (nullable = true)
       |    |-- value3: long (nullable = true)
      
      +---+-------+
      |key|  value|
      +---+-------+
      |  a|[1,2,3]|
      |  b|[2,3,4]|
      +---+-------+
      

      you can directly 'split' the column using *:

      df2.select('key', 'value.*').show()
      +---+------+------+------+
      |key|value1|value2|value3|
      +---+------+------+------+
      |  a|     1|     2|     3|
      |  b|     2|     3|     4|
      +---+------+------+------+
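
    Since the right approach differs between the two cases, a quick way to check which one you have (a minimal sketch, assuming the DataFrames defined above) is to inspect the schema:

      from pyspark.sql.types import ArrayType, StructType

      # Look up the field's data type in the DataFrame schema.
      dtype = df.schema["value"].dataType
      if isinstance(dtype, ArrayType):
          print("array column: index it with value[i]")
      elif isinstance(dtype, StructType):
          print("struct column: expand it with 'value.*'")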
      
  • 2020-11-27 20:06

    @jordi Aceiton, thanks for the solution. I tried to make it more concise: instead of renaming the newly created columns in a separate loop, I alias them while creating them, and I use df.columns to fetch the existing column names rather than building the list manually.

        from pyspark.sql.types import *
        from pyspark.sql.functions import *
        from pyspark.sql import Row
    
        df = spark.createDataFrame([Row(index=1, finalArray=[1.1, 2.3, 7.5], c=4),
                                    Row(index=2, finalArray=[9.6, 4.1, 5.4], c=4)])
        # Collect all the existing column names as a list.
        dlist = df.columns
        # Append the new element columns to the dataframe.
        df.select(dlist + [col("finalArray")[x].alias("Value" + str(x+1)) for x in range(0, 3)]).show()
    

    Output:

         +---------------+-----+------+------+------+
         |  finalArray   |index|Value1|Value2|Value3|
         +---------------+-----+------+------+------+
         |[1.1, 2.3, 7.5]|  1  |   1.1|   2.3|   7.5|
         |[9.6, 4.1, 5.4]|  2  |   9.6|   4.1|   5.4|
         +---------------+-----+------+------+------+
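
    The same select can also be wrapped in a small helper (a hypothetical function name; it assumes the number of elements n is known) so it can be reused for any array column:

        from pyspark.sql.functions import col

        def split_array_col(df, array_col, n, prefix="Value"):
            # Append n element columns extracted from 'array_col' to the
            # existing columns, aliased prefix1 .. prefixN.
            return df.select(
                df.columns
                + [col(array_col)[i].alias(prefix + str(i + 1)) for i in range(n)]
            )

        split_array_col(df, "finalArray", 3).show()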
    