Duplicating records between date gaps within a selected time interval in a PySpark dataframe

守給你的承諾、 submitted on 2021-02-08 09:45:10


I have a PySpark dataframe that keeps track of changes that occur in a product's price and status over months. This means that a new row is created only when a change occurred (in either status or price) compared to the previous month, like in the dummy data below

    |product_id| status    | price| month  |
    |1         | available | 5    | 2019-10|
    |1         | available | 8    | 2020-08|
    |1         | limited   | 8    | 2020-10|
    |2         | limited   | 1    | 2020-09|
    |2         | limited   | 3    | 2020-10|

I would like to create a dataframe that shows the values for each of the last 6 months. This means that I need to duplicate the records whenever there is a gap in the above dataframe. For example, if the last 6 months are 2020-07, 2020-08, ... 2020-12, then the result for the above dataframe should be

    |product_id| status    | price| month  |
    |1         | available | 5    | 2020-07|
    |1         | available | 8    | 2020-08|
    |1         | available | 8    | 2020-09|
    |1         | limited   | 8    | 2020-10|
    |1         | limited   | 8    | 2020-11|
    |1         | limited   | 8    | 2020-12|
    |2         | limited   | 1    | 2020-09|
    |2         | limited   | 3    | 2020-10|
    |2         | limited   | 3    | 2020-11|
    |2         | limited   | 3    | 2020-12|

Notice that for product_id = 1 there was an older record from 2019-10 that was propagated until 2020-08 and then trimmed, whereas for product_id = 2 there were no records prior to 2020-09 and thus the months 2020-07, 2020-08 were not filled for it (as the product did not exist prior to 2020-09).
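
To pin down the expected semantics, here is a minimal plain-Python sketch (no Spark, not meant for millions of rows) that forward-fills each product over a month window. The helper names `month_to_index`, `index_to_month` and `fill_window` are made up for this illustration.

```python
def month_to_index(month):
    """Convert 'YYYY-MM' to a running month count."""
    year, mon = month.split("-")
    return int(year) * 12 + int(mon) - 1


def index_to_month(index):
    """Inverse of month_to_index."""
    return f"{index // 12:04d}-{index % 12 + 1:02d}"


def fill_window(rows, first_month, last_month):
    """Forward-fill (status, price) per product over [first_month, last_month].

    rows: iterable of (product_id, status, price, month) change records.
    """
    lo, hi = month_to_index(first_month), month_to_index(last_month)
    by_product = {}
    for product_id, status, price, month in rows:
        by_product.setdefault(product_id, {})[month_to_index(month)] = (status, price)

    result = []
    for product_id in sorted(by_product):
        changes = by_product[product_id]
        # Latest change at or before the window start, if there is one.
        earlier = [m for m in changes if m <= lo]
        current = changes[max(earlier)] if earlier else None
        for m in range(lo, hi + 1):
            current = changes.get(m, current)
            if current is not None:  # skip months before the product existed
                result.append((product_id, *current, index_to_month(m)))
    return result


rows = [
    (1, "available", 5, "2019-10"),
    (1, "available", 8, "2020-08"),
    (1, "limited", 8, "2020-10"),
    (2, "limited", 1, "2020-09"),
    (2, "limited", 3, "2020-10"),
]
filled = fill_window(rows, "2020-07", "2020-12")
```

For the dummy data this gives six rows for product 1 and four rows for product 2, as in the table above.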

Since the dataframe consists of millions of records, a "brute-force" solution using for loops and checking for each product_id is rather slow. It seems that it should be possible to solve this using window functions, by creating another column next_month and then filling in the gaps based on that column, but I don't know how to achieve that.


Following @jxc's comment, I have prepared an answer for this use case.

The code is walked through step by step below.

  1. Import the Spark SQL functions and Window

    from pyspark.sql import functions as F, Window

  2. Prepare the sample data

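    # rows match the df0 output shown in the date-column step
    simpleData = ((1,"Available",5,"2020-07"), (1,"Available",8,"2020-08"), (1,"Limited",8,"2020-12"),
                  (2,"Limited",1,"2020-09"), (2,"Limited",3,"2020-12"))
    columns = ["product_id", "status", "price", "month"]

  3. Create a dataframe from the sample data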

    df = spark.createDataFrame(data = simpleData, schema = columns)

  4. Add a date column to the dataframe so that month becomes a properly formatted date

    df0 = df.withColumn("date",F.to_date('month','yyyy-MM'))


    |product_id|   status|price|  month|      date|                                               
    |         1|Available|    5|2020-07|2020-07-01|                                                 
    |         1|Available|    8|2020-08|2020-08-01|                                                
    |         1|  Limited|    8|2020-12|2020-12-01|                                                
    |         2|  Limited|    1|2020-09|2020-09-01|                                                
    |         2|  Limited|    3|2020-12|2020-12-01|                                                

  5. Create the window spec w1 and use the window function lead to find the next date over w1, then move it back one month to get the end of each date sequence (the last row of a product keeps its own date):

    w1 = Window.partitionBy('product_id').orderBy('date')
    df1 = df0.withColumn('end_date',F.coalesce(F.add_months(F.lead('date').over(w1),-1),'date'))

    |product_id|   status|price|  month|      date|  end_date|                                                      
    |         1|Available|    5|2020-07|2020-07-01|2020-07-01|                                                      
    |         1|Available|    8|2020-08|2020-08-01|2020-11-01|                                                            
    |         1|  Limited|    8|2020-12|2020-12-01|2020-12-01|                                                                     
    |         2|  Limited|    1|2020-09|2020-09-01|2020-11-01|                                                                            
    |         2|  Limited|    3|2020-12|2020-12-01|2020-12-01|                                                                                   
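
What lead plus add_months(..., -1) plus coalesce computes can be sketched in plain Python (no Spark) as follows. The names `add_months` and `with_end_month` are hypothetical helpers for this illustration, and months are kept as 'YYYY-MM' strings rather than dates.

```python
from itertools import groupby


def add_months(month, n):
    """Shift a 'YYYY-MM' string by n months."""
    year, mon = map(int, month.split("-"))
    index = year * 12 + (mon - 1) + n
    return f"{index // 12:04d}-{index % 12 + 1:02d}"


def with_end_month(rows):
    """Append end_month to each (product_id, status, price, month) tuple.

    end_month is one month before the product's next record, or the
    record's own month when it is the product's last record.
    """
    out = []
    ordered = sorted(rows, key=lambda r: (r[0], r[3]))
    for _, group in groupby(ordered, key=lambda r: r[0]):
        group = list(group)
        # Pair every record with its successor (None for the last one).
        for current, nxt in zip(group, group[1:] + [None]):
            end = add_months(nxt[3], -1) if nxt else current[3]
            out.append(current + (end,))
    return out


sample = [
    (1, "Available", 5, "2020-07"),
    (1, "Available", 8, "2020-08"),
    (1, "Limited", 8, "2020-12"),
    (2, "Limited", 1, "2020-09"),
    (2, "Limited", 3, "2020-12"),
]
with_end = with_end_month(sample)
```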

  6. Use months_between(end_date, date) to calculate the number of months between the two dates, use transform to iterate through sequence(0, #months) and build a struct with date=add_months(date,i) and the unchanged price, then use inline_outer to explode the array of structs. The SQL expression has to be passed to selectExpr as a string:

    df2 = df1.selectExpr("product_id", "status", """
        inline_outer(transform(
            sequence(0, int(months_between(end_date, date)), 1),
            i -> (add_months(date, i) as date, price as price)
        ))
    """)


    |product_id|   status|      date|price|                                                             
    |         1|Available|2020-07-01|    5|                                                              
    |         1|Available|2020-08-01|    8|                                                  
    |         1|Available|2020-09-01|    8|                                                           
    |         1|Available|2020-10-01|    8|                                                             
    |         1|Available|2020-11-01|    8|                                                                 
    |         1|  Limited|2020-12-01|    8|                                                                
    |         2|  Limited|2020-09-01|    1|                                                                                 
    |         2|  Limited|2020-10-01|    1|                                                    
    |         2|  Limited|2020-11-01|    1|                                                                          
    |         2|  Limited|2020-12-01|    3|                                                          
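
The sequence/transform/inline_outer step boils down to "emit one row per month from date to end_date". A rough plain-Python equivalent (no Spark) is below. `months_between`, `add_months` and `expand` are hypothetical helpers written for this sketch, working on 'YYYY-MM' strings.

```python
def to_index(month):
    """Convert 'YYYY-MM' to a running month count."""
    year, mon = map(int, month.split("-"))
    return year * 12 + mon - 1


def months_between(start, end):
    """Whole months from start to end (both 'YYYY-MM')."""
    return to_index(end) - to_index(start)


def add_months(month, n):
    """Shift a 'YYYY-MM' string by n months."""
    index = to_index(month) + n
    return f"{index // 12:04d}-{index % 12 + 1:02d}"


def expand(rows):
    """Repeat each (product_id, status, price, month, end_month) row once
    per month in [month, end_month], carrying status and price forward."""
    out = []
    for product_id, status, price, month, end_month in rows:
        for i in range(months_between(month, end_month) + 1):
            out.append((product_id, status, add_months(month, i), price))
    return out


with_end = [
    (1, "Available", 5, "2020-07", "2020-07"),
    (1, "Available", 8, "2020-08", "2020-11"),
    (1, "Limited", 8, "2020-12", "2020-12"),
    (2, "Limited", 1, "2020-09", "2020-11"),
    (2, "Limited", 3, "2020-12", "2020-12"),
]
expanded = expand(with_end)
```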

  7. Partition the dataframe by product_id and add a rank column (row number ordered by date) to get df3. Then compute the maximum rank for each product_id as a new column max_rank and store the result in df4.

    w2 = Window.partitionBy('product_id').orderBy('date')                                                            
    df3 = df2.withColumn('rank',F.row_number().over(w2))                                                                 
    Schema: DataFrame[product_id: bigint, status: string, date: date, price: bigint, rank: int]
    |product_id|   status|      date|price|rank|
    |         1|Available|2020-07-01|    5|   1|
    |         1|Available|2020-08-01|    8|   2|
    |         1|Available|2020-09-01|    8|   3|
    |         1|Available|2020-10-01|    8|   4|
    |         1|Available|2020-11-01|    8|   5|
    |         1|  Limited|2020-12-01|    8|   6|
    |         2|  Limited|2020-09-01|    1|   1|
    |         2|  Limited|2020-10-01|    1|   2|
    |         2|  Limited|2020-11-01|    1|   3|
    |         2|  Limited|2020-12-01|    3|   4|

    df4 = df3.groupBy("product_id").agg(F.max('rank').alias('max_rank'))                                                           
    Schema: DataFrame[product_id: bigint, max_rank: int]
    |product_id|max_rank|
    |         1|       6|
    |         2|       4|

  8. Join df3 and df4 on product_id to attach max_rank to every row

    df5 = df3.join(df4, on="product_id", how="inner")
    Schema: DataFrame[product_id: bigint, status: string, date: date, price: bigint, rank: int, max_rank: int]
    |product_id|   status|      date|price|rank|max_rank|
    |         1|Available|2020-07-01|    5|   1|       6|
    |         1|Available|2020-08-01|    8|   2|       6|
    |         1|Available|2020-09-01|    8|   3|       6|
    |         1|Available|2020-10-01|    8|   4|       6|
    |         1|Available|2020-11-01|    8|   5|       6|
    |         1|  Limited|2020-12-01|    8|   6|       6|
    |         2|  Limited|2020-09-01|    1|   1|       4|
    |         2|  Limited|2020-10-01|    1|   2|       4|
    |         2|  Limited|2020-11-01|    1|   3|       4|
    |         2|  Limited|2020-12-01|    3|   4|       4|

  9. Finally, filter df5 on rank to keep the latest 6 months of each product, then drop the helper columns

    # between() is inclusive on both ends, so the last 6 ranks are max_rank-5 .. max_rank
    FinalResultDF = df5.filter(F.col('rank').between(F.col('max_rank') - 5, F.col('max_rank'))) \
                       .drop('rank', 'max_rank')

    |product_id|status   |date      |price|                                                
    |1         |Available|2020-07-01|5    |                                                                                
    |1         |Available|2020-08-01|8    |                                                                                          
    |1         |Available|2020-09-01|8    |                                                                                                           
    |1         |Available|2020-10-01|8    |                                                                                                             
    |1         |Available|2020-11-01|8    |                                                                                                               
    |1         |Limited  |2020-12-01|8    |                                                                                                                     
    |2         |Limited  |2020-09-01|1    |                                                                                                                     
    |2         |Limited  |2020-10-01|1    |                                                                                                                        
    |2         |Limited  |2020-11-01|1    |                                                                                                                      
    |2         |Limited  |2020-12-01|3    |                                                                                                         
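
Because between includes both endpoints, a window of the last n ranks starts at max_rank - n + 1. A tiny plain-Python check of that arithmetic (`keep_last_n` is a hypothetical helper, not part of the answer's code):

```python
def keep_last_n(ranks, n=6):
    """Keep the ranks in the inclusive range [max_rank - n + 1, max_rank]."""
    ranks = list(ranks)
    max_rank = max(ranks)
    return [r for r in ranks if max_rank - n + 1 <= r <= max_rank]


long_history = keep_last_n(range(1, 9))   # product with 8 monthly rows
short_history = keep_last_n(range(1, 5))  # product with only 4 monthly rows
```

A product with fewer than six rows keeps all of them, since the lower bound then falls at or below rank 1.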


Using spark-sql:

Given the input dataframe (the snippets in this answer use the Scala API), registered as a temp view so that the later queries can refer to it as df:

val df = spark.sql(""" with t1 as (
 select  1 c1,   'available' c2, 5 c3,   '2019-10' c4  union all
 select  1 c1,   'available' c2, 8 c3,   '2020-08' c4  union all
 select  1 c1,   'limited' c2, 8 c3,   '2020-10' c4  union all
 select  2 c1,   'limited' c2, 1 c3,   '2020-09' c4  union all
 select  2 c1,   'limited' c2, 3 c3,   '2020-10' c4 
  )  select   c1  product_id,   c2   status    ,   c3   price,   c4  month      from t1
""")
df.createOrReplaceTempView("df")


|product_id|status   |price|month  |
|1         |available|5    |2019-10|
|1         |available|8    |2020-08|
|1         |limited  |8    |2020-10|
|2         |limited  |1    |2020-09|
|2         |limited  |3    |2020-10|

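Filter on the date window, i.e. the 6 months from 2020-07 to 2020-12, keeping only the records strictly inside it (the two boundary months are handled separately below), and store them in df1

val df1 = spark.sql("""
select * from df where month > '2020-07' and month < '2020-12' 
""")
df1.createOrReplaceTempView("df1")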

|product_id|status   |price|month  |
|1         |available|8    |2020-08|
|1         |limited  |8    |2020-10|
|2         |limited  |1    |2020-09|
|2         |limited  |3    |2020-10|

Lower boundary: for each product, get the record with the maximum month where month <= '2020-07', and overwrite its month with '2020-07'

val df2 = spark.sql("""
select product_id, status, price, '2020-07' month from df  where (product_id,month) in 
( select product_id, max(month) from df where month <= '2020-07' group by 1 ) 
""")
df2.createOrReplaceTempView("df2")

|product_id|status   |price|month  |
|1         |available|5    |2020-07|

Upper boundary: for each product, get the record with the maximum month where month <= '2020-12', and overwrite its month with '2020-12'

val df3 = spark.sql("""
select product_id, status, price, '2020-12' month from df where (product_id, month) in  
( select product_id, max(month) from df where month <= '2020-12' group by 1 ) 
""")
df3.createOrReplaceTempView("df3")

|product_id|status |price|month  |
|1         |limited|8    |2020-12|
|2         |limited|3    |2020-12|

Now union all three and store the result in df4

val df4 = spark.sql("""
select  product_id, status, price,  month from df1  union all 
select  product_id, status, price,  month from df2  union all 
select  product_id, status, price,  month from df3
order by product_id, month
""")
df4.createOrReplaceTempView("df4")

|product_id|status   |price|month  |
|1         |available|5    |2020-07|
|1         |available|8    |2020-08|
|1         |limited  |8    |2020-10|
|1         |limited  |8    |2020-12|
|2         |limited  |1    |2020-09|
|2         |limited  |3    |2020-10|
|2         |limited  |3    |2020-12|
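
The boundary logic (take each product's latest record at or before a cutoff, relabel it with the cutoff month, then union with the in-window records) can be sketched in plain Python without Spark. `latest_at_or_before` is a hypothetical helper for this illustration. It relies on 'YYYY-MM' strings comparing correctly as text.

```python
def latest_at_or_before(rows, cutoff):
    """For each product, take the record with the greatest month <= cutoff
    and relabel its month as cutoff."""
    latest = {}
    for product_id, status, price, month in rows:
        if month <= cutoff and (
            product_id not in latest or month > latest[product_id][3]
        ):
            latest[product_id] = (product_id, status, price, month)
    return [
        (p, s, pr, cutoff)
        for p, s, pr, _ in (latest[k] for k in sorted(latest))
    ]


rows = [
    (1, "available", 5, "2019-10"),
    (1, "available", 8, "2020-08"),
    (1, "limited", 8, "2020-10"),
    (2, "limited", 1, "2020-09"),
    (2, "limited", 3, "2020-10"),
]
inside = [r for r in rows if "2020-07" < r[3] < "2020-12"]
lower = latest_at_or_before(rows, "2020-07")
upper = latest_at_or_before(rows, "2020-12")
combined = sorted(inside + lower + upper, key=lambda r: (r[0], r[3]))
```

Product 2 gets no lower-boundary row because it has no record at or before 2020-07.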

Result: use sequence(date1, date2, interval 1 month) to generate an array of dates covering the missing months. Explode the array and you get the result.

spark.sql("""
select product_id, status, price, month, explode(dt) res_month from (
select t1.*, 
case when months_between(lm||'-01',month||'-01')=1.0 then array(month||'-01')
     when month='2020-12' then array(month||'-01')
     else sequence(to_date(month||'-01'), add_months(to_date(lm||'-01'),-1), interval 1 month ) 
end dt 
     from (
            select product_id, status, price, month, 
            lead(month) over(partition by product_id order by month) lm 
            from df4 
          ) t1 
    ) t2 
  order by product_id, res_month
""").show(false)

|product_id|status   |price|month  |res_month |
|1         |available|5    |2020-07|2020-07-01|
|1         |available|8    |2020-08|2020-08-01|
|1         |available|8    |2020-08|2020-09-01|
|1         |limited  |8    |2020-10|2020-10-01|
|1         |limited  |8    |2020-10|2020-11-01|
|1         |limited  |8    |2020-12|2020-12-01|
|2         |limited  |1    |2020-09|2020-09-01|
|2         |limited  |3    |2020-10|2020-10-01|
|2         |limited  |3    |2020-10|2020-11-01|
|2         |limited  |3    |2020-12|2020-12-01|

