As the title describes, I have a PySpark DataFrame and I need to cast two columns into a new column that is a list of tuples, based on the value of a third column.
Assuming your DataFrame is called df:
from pyspark.sql.functions import struct, collect_list

gdf = (df.select("product_id", "category", struct("purchase_date", "days_warranty").alias("pd_wd"))
       .groupBy("product_id")
       .pivot("category")
       .agg(collect_list("pd_wd")))
Essentially, you have to group purchase_date and days_warranty into a single column using struct(). Then you group by product_id, pivot by category, and aggregate with collect_list().
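If you want to try this end to end, here is a minimal sketch with made-up sample data (it assumes an active SparkSession named spark; the values and product ids are hypothetical):

from pyspark.sql.functions import struct, collect_list

# hypothetical sample data
data = [
    ("p1", "CATEGORY_A", "2017-06-14 00:00:00", 30),
    ("p1", "CATEGORY_B", "2017-04-16 00:00:00", 90),
    ("p2", "CATEGORY_A", "2018-01-11 00:00:00", 30),
]
df = spark.createDataFrame(data, ["product_id", "category", "purchase_date", "days_warranty"])

gdf = (df.select("product_id", "category", struct("purchase_date", "days_warranty").alias("pd_wd"))
       .groupBy("product_id")
       .pivot("category")
       .agg(collect_list("pd_wd")))

gdf.show(truncate=False)
# one row per product_id, one column per category, each cell a list of structs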
In case you run into performance issues with pivot(), the approach below is another solution to the same problem, although it gives you more control by splitting the job into one phase per category with a for loop. On every iteration it appends the data for category_x to acc_df, which holds the accumulated results.
from pyspark.sql import functions as F
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StructType, StructField, StringType

# schema of the aggregated column: an array of (p_date, d_warranty) structs
schema = ArrayType(
    StructType([
        StructField("p_date", StringType(), False),
        StructField("d_warranty", StringType(), False)
    ])
)

tuple_list_udf = udf(tuple_list, schema)  # tuple_list is defined below

buf_size = 5  # if you get an OOM error, decrease this to persist more often
categories = df.select("category").distinct().collect()
acc_df = spark.createDataFrame(sc.emptyRDD(), df.schema)  # empty df holding the accumulated results across categories
for idx, c in enumerate(categories):
    col_name = c[0].replace(" ", "_")  # Spark complains about column names containing spaces
    cat_df = df.where(df["category"] == c[0]) \
        .groupBy("product_id") \
        .agg(
            F.collect_list(F.col("purchase_date")).alias("p_date"),
            F.collect_list(F.col("days_warranty")).alias("d_warranty")) \
        .withColumn(col_name, tuple_list_udf(F.col("p_date"), F.col("d_warranty"))) \
        .drop("p_date", "d_warranty")

    if idx == 0:
        acc_df = cat_df
    else:
        acc_df = acc_df \
            .join(cat_df.alias("cat_df"), "product_id") \
            .drop(F.col("cat_df.product_id"))

    # you can persist here every buf_size iterations
    if (idx + 1) % buf_size == 0:
        acc_df = acc_df.persist()
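Keep in mind that persist() is lazy, so nothing is actually cached until an action runs against acc_df. If you want the cache to take effect at that point in the loop, one option (an assumption on my part, not part of the original code) is to trigger a cheap action right after persisting:

if (idx + 1) % buf_size == 0:
    acc_df = acc_df.persist()
    acc_df.count()  # action that forces the persisted data to materialize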
The function tuple_list is responsible for generating the list of tuples from the purchase_date and days_warranty columns:
def tuple_list(pdl, dwl):
    # zip the two collected lists into one list of (purchase_date, days_warranty) tuples
    return list(zip(pdl, dwl))
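Since tuple_list is just a zip over plain Python lists, you can sanity-check it without Spark:

tuple_list(["2017-04-16 00:00:00", "2018-09-16 00:00:00"], [90, 90])
# [('2017-04-16 00:00:00', 90), ('2018-09-16 00:00:00', 90)]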
The output of this will be:
+-----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|product_id |CATEGORY_B |CATEGORY_A |
+-----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|02147465400|[[2017-04-16 00:00:00, 90], [2018-09-16 00:00:00, 90], [2017-10-09 00:00:00, 90], [2018-01-12 00:00:00, 90], [2018-07-11 00:00:00, 90], [2017-01-21 00:00:00, 90], [2018-04-14 00:00:00, 90], [2017-01-05 00:00:00, 90], [2017-07-15 00:00:00, 90]]|[[2017-06-14 00:00:00, 30], [2018-08-14 00:00:00, 30], [2018-01-11 00:00:00, 30], [2018-04-12 00:00:00, 30], [2017-10-11 00:00:00, 30], [2017-05-16 00:00:00, 30], [2018-05-15 00:00:00, 30], [2017-04-15 00:00:00, 30], [2017-02-15 00:00:00, 30], [2018-02-12 00:00:00, 30], [2017-01-21 00:00:00, 30], [2018-07-11 00:00:00, 30], [2018-06-14 00:00:00, 30], [2017-03-16 00:00:00, 30], [2017-07-20 00:00:00, 30], [2018-08-23 00:00:00, 30], [2017-09-12 00:00:00, 30], [2018-03-12 00:00:00, 30], [2017-12-12 00:00:00, 30], [2017-08-14 00:00:00, 30], [2017-11-11 00:00:00, 30]]|
+-----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+