How to use windowing functions efficiently to decide next N number of rows based on N number of previous values

Submitted by 爷,独闯天下 on 2020-12-09 06:14:17

Question


Hi, I have the following data.

+----------+----+-------+-----------------------+
|      date|item|avg_val|conditions             |
+----------+----+-------+-----------------------+
|01-10-2020|   x|     10|                      0|
|02-10-2020|   x|     10|                      0|
|03-10-2020|   x|     15|                      1|
|04-10-2020|   x|     15|                      1|
|05-10-2020|   x|      5|                      0|
|06-10-2020|   x|     13|                      1|
|07-10-2020|   x|     10|                      1|
|08-10-2020|   x|     10|                      0|
|09-10-2020|   x|     15|                      1|
|01-10-2020|   y|     10|                      0|
|02-10-2020|   y|     18|                      0|
|03-10-2020|   y|      6|                      1|
|04-10-2020|   y|     10|                      0|
|05-10-2020|   y|     20|                      0|
+----------+----+-------+-----------------------+

I am trying to create a new column called flag based on the following rules:

  1. If the conditions value is 0, the new column value will be 0.
  2. If the conditions value is 1, the new column will be 1 and the next N rows will be 0, i.e. there is no need to check the next N values. This process is applied for each item, i.e. partitioned by item.

Here I have used N = 4.

I have used the code below, but the windowing functions are not efficient. Is there a more optimized way?

-- Step 1: rolling max of `conditions` over the current row and the 4 preceding rows
DROP TEMPORARY TABLE IF EXISTS t2;
CREATE TEMPORARY TABLE t2
SELECT *,
       MAX(conditions) OVER (PARTITION BY item ORDER BY `date` ROWS 4 PRECEDING) AS new_row
FROM record
ORDER BY item, `date`;

-- Step 2: number the rows within each (item, new_row) group
DROP TEMPORARY TABLE IF EXISTS t3;
CREATE TEMPORARY TABLE t3
SELECT *,
       ROW_NUMBER() OVER (PARTITION BY item, new_row ORDER BY `date`) AS e
FROM t2;

-- Step 3: flag only the first row of every 5-row block where new_row = 1
SELECT *,
       CASE WHEN new_row = 1 AND e % 5 = 1 THEN 1 ELSE 0 END AS flag
FROM t3;

The expected output is as follows:

+----------+----+-------+-----------------------+-----+
|      date|item|avg_val|conditions             |flag |
+----------+----+-------+-----------------------+-----+
|01-10-2020|   x|     10|                      0|    0|
|02-10-2020|   x|     10|                      0|    0|
|03-10-2020|   x|     15|                      1|    1|
|04-10-2020|   x|     15|                      1|    0|
|05-10-2020|   x|      5|                      0|    0|
|06-10-2020|   x|     13|                      1|    0|
|07-10-2020|   x|     10|                      1|    0|
|08-10-2020|   x|     10|                      0|    0|
|09-10-2020|   x|     15|                      1|    1|
|01-10-2020|   y|     10|                      0|    0|
|02-10-2020|   y|     18|                      0|    0|
|03-10-2020|   y|      6|                      1|    1|
|04-10-2020|   y|     10|                      0|    0|
|05-10-2020|   y|     20|                      0|    0|
+----------+----+-------+-----------------------+-----+

But I am unable to get this output, even though I have tried a lot.


Answer 1:


As suggested in the comments (by @nbk and @Akina), you will need some sort of iterator to implement this logic. With Spark SQL and Spark 2.4+, we can use the builtin function aggregate with an array of structs plus a counter as the accumulator. Below is an example dataframe and a table named record (assume the values in the conditions column are either 0 or 1):

val df = Seq(
    ("01-10-2020", "x", 10, 0), ("02-10-2020", "x", 10, 0), ("03-10-2020", "x", 15, 1),
    ("04-10-2020", "x", 15, 1), ("05-10-2020", "x", 5, 0), ("06-10-2020", "x", 13, 1),
    ("07-10-2020", "x", 10, 1), ("08-10-2020", "x", 10, 0), ("09-10-2020", "x", 15, 1),
    ("01-10-2020", "y", 10, 0), ("02-10-2020", "y", 18, 0), ("03-10-2020", "y", 6, 1),
    ("04-10-2020", "y", 10, 0), ("05-10-2020", "y", 20, 0)
).toDF("date", "item", "avg_val", "conditions")

df.createOrReplaceTempView("record")

SQL:

spark.sql("""
  SELECT t1.item, m.*
  FROM (
    SELECT item,
      sort_array(collect_list(struct(date,avg_val,int(conditions) as conditions,conditions as flag))) as dta
    FROM record
    GROUP BY item
  ) as t1 LATERAL VIEW OUTER inline(
    aggregate(
      /* expr: set up array `dta` from the 2nd element to the last;
       *       note that indices for the slice function are 1-based, while dta[i] is 0-based
       */
      slice(dta,2,size(dta)),
      /* start: set up and initialize `acc` to a struct containing two fields:
       * - dta: an array of structs with a single element dta[0]
       * - counter: number of rows after flag=1, can be from `0` to `N+1`
       */
      (array(dta[0]) as dta, dta[0].conditions as counter),
      /* merge: iterate through the `expr` using x and update two fields of `acc`
       * - dta: append values from x to acc.dta array using concat + array functions
       *        update flag using `IF(acc.counter IN (0,5) and x.conditions = 1, 1, 0)`
       * - counter: increment by 1 if acc.counter is between 1 and 4
       *            , otherwise set value to x.conditions
       */
      (acc, x) -> named_struct(
          'dta', concat(acc.dta, array(named_struct(
              'date', x.date,
              'avg_val', x.avg_val,
              'conditions', x.conditions,
              'flag', IF(acc.counter IN (0,5) and x.conditions = 1, 1, 0)
            ))),
          'counter', IF(acc.counter > 0 and acc.counter < 5, acc.counter+1, x.conditions)
        ),
      /* finish: retrieve acc.dta only and discard acc.counter */
      acc -> acc.dta
    )
  ) m
""").show(50)

Result:

+----+----------+-------+----------+----+
|item|      date|avg_val|conditions|flag|
+----+----------+-------+----------+----+
|   x|01-10-2020|     10|         0|   0|
|   x|02-10-2020|     10|         0|   0|
|   x|03-10-2020|     15|         1|   1|
|   x|04-10-2020|     15|         1|   0|
|   x|05-10-2020|      5|         0|   0|
|   x|06-10-2020|     13|         1|   0|
|   x|07-10-2020|     10|         1|   0|
|   x|08-10-2020|     10|         0|   0|
|   x|09-10-2020|     15|         1|   1|
|   y|01-10-2020|     10|         0|   0|
|   y|02-10-2020|     18|         0|   0|
|   y|03-10-2020|      6|         1|   1|
|   y|04-10-2020|     10|         0|   0|
|   y|05-10-2020|     20|         0|   0|
+----+----------+-------+----------+----+

Where:

  1. GROUP BY item collects rows of the same item into an array of structs (the dta column) with 4 fields: date, avg_val, conditions and flag, sorted by date.
  2. The aggregate function iterates through this array of structs, updating the flag field based on counter and conditions (details in the SQL comments above); a plain-Scala sketch of the same counter logic follows this list.
  3. LATERAL VIEW with the inline function explodes the resulting array of structs returned by aggregate.
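
As a side illustration (not from the original answer), here is a minimal plain-Scala sketch of the same counter logic, run over one item's date-sorted conditions values; the names n, conditions and flags are illustrative only:

val n = 4
// conditions of item `x`, sorted by date
val conditions = Seq(0, 0, 1, 1, 0, 1, 1, 0, 1)

// fold over the rows, carrying (flags so far, counter);
// counter counts rows since the last flag=1 (0 means "no active window")
val (flags, _) = conditions.foldLeft((Vector.empty[Int], 0)) {
  case ((acc, counter), c) =>
    // flag=1 only when no suppression window is active (counter 0 or N+1) and conditions=1
    val flag = if ((counter == 0 || counter == n + 1) && c == 1) 1 else 0
    // advance the counter while inside the window, otherwise restart from conditions
    val next = if (counter > 0 && counter < n + 1) counter + 1 else c
    (acc :+ flag, next)
}

println(flags.mkString(","))   // 0,0,1,0,0,0,0,0,1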

Notes:

(1) the proposed SQL is for N=4, where we have acc.counter IN (0,5) and acc.counter < 5 in the SQL. For any N, adjust these to acc.counter IN (0,N+1) and acc.counter < N+1. The table below shows the result for N=2 with the same sample data (a sketch of parameterizing N from the Scala side follows the table):

+----+----------+-------+----------+----+
|item|      date|avg_val|conditions|flag|
+----+----------+-------+----------+----+
|   x|01-10-2020|     10|         0|   0|
|   x|02-10-2020|     10|         0|   0|
|   x|03-10-2020|     15|         1|   1|
|   x|04-10-2020|     15|         1|   0|
|   x|05-10-2020|      5|         0|   0|
|   x|06-10-2020|     13|         1|   1|
|   x|07-10-2020|     10|         1|   0|
|   x|08-10-2020|     10|         0|   0|
|   x|09-10-2020|     15|         1|   1|
|   y|01-10-2020|     10|         0|   0|
|   y|02-10-2020|     18|         0|   0|
|   y|03-10-2020|      6|         1|   1|
|   y|04-10-2020|     10|         0|   0|
|   y|05-10-2020|     20|         0|   0|
+----+----------+-------+----------+----+
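
To avoid editing the two constants by hand, one option (not shown in the original answer) is to build the query string on the Scala side and interpolate N; below is the same query with the inline comments stripped, wrapped in a hypothetical helper flagQuery (it assumes the record view registered above):

def flagQuery(n: Int): String = s"""
  SELECT t1.item, m.*
  FROM (
    SELECT item,
      sort_array(collect_list(struct(date, avg_val, int(conditions) as conditions, conditions as flag))) as dta
    FROM record
    GROUP BY item
  ) as t1 LATERAL VIEW OUTER inline(
    aggregate(
      slice(dta, 2, size(dta)),
      (array(dta[0]) as dta, dta[0].conditions as counter),
      (acc, x) -> named_struct(
          'dta', concat(acc.dta, array(named_struct(
              'date', x.date,
              'avg_val', x.avg_val,
              'conditions', x.conditions,
              'flag', IF(acc.counter IN (0, ${n + 1}) and x.conditions = 1, 1, 0)
            ))),
          'counter', IF(acc.counter > 0 and acc.counter < ${n + 1}, acc.counter + 1, x.conditions)
        ),
      acc -> acc.dta
    )
  ) m
"""

spark.sql(flagQuery(2)).show(50)   // N=2 reproduces the table above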

(2) we use dta[0] to initialize acc, which includes both the values and the datatypes of its fields. Ideally, we should make sure the data types of these fields are right so that all calculations are conducted correctly. For example, when calculating acc.counter, if conditions is StringType, acc.counter+1 will return a StringType holding a DoubleType value:

spark.sql("select '2'+1").show()
+---------------------------------------+
|(CAST(2 AS DOUBLE) + CAST(1 AS DOUBLE))|
+---------------------------------------+
|                                    3.0|
+---------------------------------------+

This could yield floating-point errors when comparing the value with integers using acc.counter IN (0,5) or acc.counter < 5. Based on the OP's feedback, this produced an incorrect result without any WARNING/ERROR message.

  • One workaround is to specify the exact field types using CAST when setting up the 2nd argument of the aggregate function, so that Spark reports an ERROR on any type mismatch, see below:

    CAST((array(dta[0]), dta[0].conditions) as struct<dta:array<struct<date:string,avg_val:string,conditions:int,flag:int>>,counter:int>),
    
  • Another solution is to force the types when creating the dta column; in this example, see int(conditions) as conditions in the code below:

    SELECT item,
      sort_array(collect_list(struct(date,avg_val,int(conditions) as conditions,conditions as flag))) as dta
    FROM record
    GROUP BY item
    
  • We can also force the datatype inside the calculation, for example, see int(acc.counter+1) below:

    IF(acc.counter > 0 and acc.counter < 5, int(acc.counter+1), x.conditions)      
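
A quick way to verify which type the arithmetic produces (an assumed check, not from the original post) is to inspect the schema; with an explicit int cast, + stays integer arithmetic instead of being promoted to double:

spark.sql("SELECT int('2') + 1 AS counter").printSchema()
// root
//  |-- counter: integer (nullable = true)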
    


Source: https://stackoverflow.com/questions/64660047/how-to-use-windowing-functions-efficiently-to-decide-next-n-number-of-rows-based
