Flatten Nested Struct in PySpark Array

Submitted by 跟風遠走 on 2021-02-04 16:37:26


Given a schema like:

|-- first_name: string
|-- last_name: string
|-- degrees: array
|    |-- element: struct
|    |    |-- school: string
|    |    |-- advisors: struct
|    |    |    |-- advisor1: string
|    |    |    |-- advisor2: string

How can I get a schema like:

|-- first_name: string
|-- last_name: string
|-- degrees: array
|    |-- element: struct
|    |    |-- school: string
|    |    |-- advisor1: string
|    |    |-- advisor2: string
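To pin down the target shape, here is a plain-Python sketch of the same reshaping applied to one record. It assumes the record is a dict with the fields from the schemas above. The helper name flatten_record and the sample values are made up for illustration.

```python
def flatten_record(record):
    """Return a copy of record whose degrees elements have advisors merged in."""
    degrees = record.get("degrees")
    flat_degrees = None
    if degrees is not None:
        flat_degrees = []
        for degree in degrees:
            # A missing or null advisors struct yields null advisor fields
            advisors = degree.get("advisors") or {}
            flat_degrees.append({
                "school": degree.get("school"),
                "advisor1": advisors.get("advisor1"),
                "advisor2": advisors.get("advisor2"),
            })
    # Top-level fields are copied unchanged; only degrees is rebuilt
    return {**record, "degrees": flat_degrees}


# Hypothetical sample record matching the input schema
alice = {
    "first_name": "Alice",
    "last_name": "Smith",
    "degrees": [
        {"school": "State U", "advisors": {"advisor1": "Lee", "advisor2": "Kim"}},
    ],
}
print(flatten_record(alice)["degrees"])
# [{'school': 'State U', 'advisor1': 'Lee', 'advisor2': 'Kim'}]
```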

Currently, I explode the array, flatten the structure by selecting advisors.*, then group by first_name and last_name and rebuild the array with collect_list. I'm hoping there's a cleaner or shorter way to do this; the current approach also involves a lot of painful field renaming that I don't want to get into here. Thanks!
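For reference, the explode / flatten / collect_list pipeline described in the question can be emulated in plain Python on a list of dicts. This is an illustration only: the helper names (explode, explode_flatten_regroup) are hypothetical and merely mimic the Spark steps. It makes two side effects of that approach visible. Spark's explode emits no rows for a null or empty array, so such people vanish from the result (explode_outer keeps them). Grouping by first and last name also merges different people who share a name. In Spark, the order of elements rebuilt by collect_list is not guaranteed after a shuffle either, which this emulation does not model.

```python
from collections import defaultdict


def explode(rows, column):
    """Mimic Spark's explode: one output row per array element.

    Rows whose array is None or empty produce no output at all.
    """
    for row in rows:
        for element in row.get(column) or []:
            yield {**row, column: element}


def explode_flatten_regroup(rows):
    # Step 1: one row per degree
    exploded = explode(rows, "degrees")
    # Step 2: flatten the struct (school plus the fields of advisors)
    flat = (
        {
            "first_name": r["first_name"],
            "last_name": r["last_name"],
            "school": r["degrees"]["school"],
            **r["degrees"]["advisors"],
        }
        for r in exploded
    )
    # Step 3: group by name and rebuild the array, like collect_list
    groups = defaultdict(list)
    for r in flat:
        key = (r.pop("first_name"), r.pop("last_name"))
        groups[key].append(r)
    return [
        {"first_name": first, "last_name": last, "degrees": degrees}
        for (first, last), degrees in groups.items()
    ]


rows = [
    {"first_name": "Ann", "last_name": "Lee",
     "degrees": [{"school": "U1", "advisors": {"advisor1": "X", "advisor2": "Y"}}]},
    {"first_name": "Bob", "last_name": "Ng", "degrees": []},  # dropped by explode
]
result = explode_flatten_regroup(rows)
```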


You can use a UDF to change the data type of nested columns in a DataFrame. Suppose you have read the DataFrame as df1:

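from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StructType, StructField, StringType

def flatten_degrees(data):
    # A null array stays null instead of raising a TypeError
    if data is None:
        return None
    # Each element is a Row; emit a (school, advisor1, advisor2) tuple
    return [(x["school"],
             x["advisors"]["advisor1"] if x["advisors"] else None,
             x["advisors"]["advisor2"] if x["advisors"] else None)
            for x in data]

# Declared return type: an array of flat structs
flat_degrees_type = ArrayType(StructType([StructField("school", StringType()),
                                          StructField("advisor1", StringType()),
                                          StructField("advisor2", StringType())]))
udf_flatten_degrees = udf(flatten_degrees, flat_degrees_type)

df2 = df1.withColumn("degrees", udf_flatten_degrees("degrees"))
df2.printSchema()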


 |-- degrees: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- school: string (nullable = true)
 |    |    |-- advisor1: string (nullable = true)
 |    |    |-- advisor2: string (nullable = true)
 |-- first_name: string (nullable = true)
 |-- last_name: string (nullable = true)
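Because a UDF body is ordinary Python, it can also be written without hardcoding the field names and tested without Spark. The sketch below assumes each element has first been converted to a dict (for example with Row.asDict(recursive=True)). It lifts the fields of any nested struct up one level. The helper names are hypothetical, and the schema passed to udf would still have to list the resulting fields explicitly.

```python
def lift_nested_fields(element):
    """Merge the fields of each nested dict into its parent (one level only)."""
    flat = {}
    for key, value in element.items():
        if isinstance(value, dict):
            # Nested struct: copy its fields up under their own names.
            # A field that collides with an existing key overwrites it.
            flat.update(value)
        else:
            flat[key] = value
    return flat


def flatten_array(data):
    # Mirror the UDF above: a null array stays null
    if data is None:
        return None
    return [lift_nested_fields(element) for element in data]


degrees = [{"school": "U1", "advisors": {"advisor1": "X", "advisor2": "Y"}}]
print(flatten_array(degrees))
# [{'school': 'U1', 'advisor1': 'X', 'advisor2': 'Y'}]
```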


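Here's a more generic solution that can flatten multiple layers of nested structs. Note that it only expands struct columns: an array column such as degrees has an array<...> type, so it is kept as-is rather than flattened.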

from pyspark.sql.functions import col

def flatten_df(nested_df, layers):
    flat_cols = []
    nested_cols = []
    flat_df = []

    # Split columns into non-struct (kept as-is) and struct (to be expanded)
    flat_cols.append([c[0] for c in nested_df.dtypes if c[1][:6] != 'struct'])
    nested_cols.append([c[0] for c in nested_df.dtypes if c[1][:6] == 'struct'])

    # Promote each struct field to a top-level column named <parent>_<field>
    flat_df.append(nested_df.select(flat_cols[0] +
                                    [col(nc + '.' + c).alias(nc + '_' + c)
                                     for nc in nested_cols[0]
                                     for c in nested_df.select(nc + '.*').columns]))
    # Repeat on the previous result, once per additional layer of nesting
    for i in range(1, layers):
        flat_cols.append([c[0] for c in flat_df[i-1].dtypes if c[1][:6] != 'struct'])
        nested_cols.append([c[0] for c in flat_df[i-1].dtypes if c[1][:6] == 'struct'])

        flat_df.append(flat_df[i-1].select(flat_cols[i] +
                                           [col(nc + '.' + c).alias(nc + '_' + c)
                                            for nc in nested_cols[i]
                                            for c in flat_df[i-1].select(nc + '.*').columns]))

    return flat_df[-1]
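To see what the layers parameter controls, here is a plain-Python emulation of the same idea on a single record. A dict stands in for a row, and the helper names are hypothetical. Each pass promotes one level of nested fields to the top under a parent_child name. Lists are left alone, just as array columns are, and passes beyond the actual nesting depth change nothing.

```python
def flatten_once(row):
    """One pass on a dict: expand every top-level nested dict."""
    # Non-struct values come first, like flat_cols
    flat = {k: v for k, v in row.items() if not isinstance(v, dict)}
    # Then each nested field, renamed <parent>_<field>
    for parent, value in row.items():
        if isinstance(value, dict):
            for child, child_value in value.items():
                flat[parent + "_" + child] = child_value
    return flat


def flatten_layers(row, layers):
    for _ in range(layers):
        row = flatten_once(row)
    return row


record = {"id": 1, "info": {"name": "A", "address": {"city": "B"}}, "tags": [{"k": "v"}]}
print(flatten_layers(record, 2))
# {'id': 1, 'tags': [{'k': 'v'}], 'info_name': 'A', 'info_address_city': 'B'}
```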

Call it with:

my_flattened_df = flatten_df(my_df_having_structs, 3)

(The second parameter is the number of nested layers to flatten; in my case it is 3.)
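Rather than hardcoding the layer count, it could be derived from the nesting depth. Here is a plain-Python sketch on a dict record; the helper names are hypothetical. In Spark one would walk df.schema instead, recursing into StructType fields, and that variant is not shown here. As with flatten_df, lists are not descended into.

```python
def struct_depth(value):
    """Number of nested-dict levels in value; lists are not descended into."""
    if not isinstance(value, dict):
        return 0
    return 1 + max((struct_depth(v) for v in value.values()), default=0)


def layers_needed(row):
    # The row itself is the top level, so only the nested dicts inside it count
    return max((struct_depth(v) for v in row.values()), default=0)


record = {"id": 1, "info": {"name": "A", "address": {"city": "B"}}, "tags": [{"k": "v"}]}
print(layers_needed(record))
# 2
```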

