Question
Hi, I have the following data:
+----------+----+-------+----------+
|      date|item|avg_val|conditions|
+----------+----+-------+----------+
|01-10-2020|   x|     10|         0|
|02-10-2020|   x|     10|         0|
|03-10-2020|   x|     15|         1|
|04-10-2020|   x|     15|         1|
|05-10-2020|   x|      5|         0|
|06-10-2020|   x|     13|         1|
|07-10-2020|   x|     10|         1|
|08-10-2020|   x|     10|         0|
|09-10-2020|   x|     15|         1|
|01-10-2020|   y|     10|         0|
|02-10-2020|   y|     18|         0|
|03-10-2020|   y|      6|         1|
|04-10-2020|   y|     10|         0|
|05-10-2020|   y|     20|         0|
+----------+----+-------+----------+
I am trying to create a new column called flag based on the following rules:
- if the conditions value is 0, then the new column value will be 0.
- if the conditions value is 1, then the new column will be 1 and the next N rows will be 0, i.e. there is no need to check the next N values. This process is applied for each item, that is, partition by item.
I have used N = 4 here.
I have used the code below, but it does not use the windowing functions efficiently. Is there an optimized way?
-- Step 1: mark every row whose frame (current row + 4 preceding) contains a conditions=1
DROP TEMPORARY TABLE IF EXISTS t2;
CREATE TEMPORARY TABLE t2
SELECT *,
       MAX(conditions) OVER (PARTITION BY item ORDER BY `date` ROWS 4 PRECEDING) AS new_row
FROM record
ORDER BY item, `date`;

-- Step 2: number the rows inside each (item, new_row) group
DROP TEMPORARY TABLE IF EXISTS t3;
CREATE TEMPORARY TABLE t3
SELECT *, ROW_NUMBER() OVER (PARTITION BY item, new_row ORDER BY `date`) AS e
FROM t2;

-- Step 3: flag only the first row of every run of 5 marked rows
SELECT *, CASE WHEN new_row = 1 AND e % 5 > 1 THEN 0
               WHEN new_row = 1 AND e % 5 = 1 THEN 1
               ELSE 0 END AS flag
FROM t3;
The expected output is as follows:
+----------+----+-------+----------+----+
|      date|item|avg_val|conditions|flag|
+----------+----+-------+----------+----+
|01-10-2020|   x|     10|         0|   0|
|02-10-2020|   x|     10|         0|   0|
|03-10-2020|   x|     15|         1|   1|
|04-10-2020|   x|     15|         1|   0|
|05-10-2020|   x|      5|         0|   0|
|06-10-2020|   x|     13|         1|   0|
|07-10-2020|   x|     10|         1|   0|
|08-10-2020|   x|     10|         0|   0|
|09-10-2020|   x|     15|         1|   1|
|01-10-2020|   y|     10|         0|   0|
|02-10-2020|   y|     18|         0|   0|
|03-10-2020|   y|      6|         1|   1|
|04-10-2020|   y|     10|         0|   0|
|05-10-2020|   y|     20|         0|   0|
+----------+----+-------+----------+----+
But I am unable to get this output, although I have tried many approaches.
Answer 1:
As suggested in the comments (by @nbk and @Akina), you will need some sort of iterator to implement the logic: because each row's flag depends on flags already assigned to earlier rows, a single fixed-frame window function cannot "reset" the count correctly. With Spark SQL and Spark version 2.4+, we can use the builtin function aggregate and set an array of structs plus a counter as the accumulator. Below is an example dataframe and a table named record (assume values in the conditions column are either 0 or 1):
val df = Seq(
("01-10-2020", "x", 10, 0), ("02-10-2020", "x", 10, 0), ("03-10-2020", "x", 15, 1),
("04-10-2020", "x", 15, 1), ("05-10-2020", "x", 5, 0), ("06-10-2020", "x", 13, 1),
("07-10-2020", "x", 10, 1), ("08-10-2020", "x", 10, 0), ("09-10-2020", "x", 15, 1),
("01-10-2020", "y", 10, 0), ("02-10-2020", "y", 18, 0), ("03-10-2020", "y", 6, 1),
("04-10-2020", "y", 10, 0), ("05-10-2020", "y", 20, 0)
).toDF("date", "item", "avg_val", "conditions")
df.createOrReplaceTempView("record")
SQL:
spark.sql("""
SELECT t1.item, m.*
FROM (
  SELECT item,
    sort_array(collect_list(struct(date,avg_val,int(conditions) as conditions,conditions as flag))) as dta
  FROM record
  GROUP BY item
) as t1 LATERAL VIEW OUTER inline(
  aggregate(
    /* expr: set up array `dta` from the 2nd element to the last
     * notice that indices for the slice function are 1-based, dta[i] is 0-based
     */
    slice(dta,2,size(dta)),
    /* start: set up and initialize `acc` to a struct containing two fields:
     * - dta: an array of structs with a single element dta[0]
     * - counter: number of rows after flag=1, can be from `0` to `N+1`
     */
    (array(dta[0]) as dta, dta[0].conditions as counter),
    /* merge: iterate through `expr` using x and update the two fields of `acc`
     * - dta: append values from x to the acc.dta array using concat + array functions;
     *   update flag using `IF(acc.counter IN (0,5) and x.conditions = 1, 1, 0)`
     * - counter: increment by 1 if acc.counter is between 1 and 4,
     *   otherwise set the value to x.conditions
     */
    (acc, x) -> named_struct(
      'dta', concat(acc.dta, array(named_struct(
        'date', x.date,
        'avg_val', x.avg_val,
        'conditions', x.conditions,
        'flag', IF(acc.counter IN (0,5) and x.conditions = 1, 1, 0)
      ))),
      'counter', IF(acc.counter > 0 and acc.counter < 5, acc.counter+1, x.conditions)
    ),
    /* finish: retrieve acc.dta only and discard acc.counter */
    acc -> acc.dta
  )
) m
""").show(50)
Result:
+----+----------+-------+----------+----+
|item| date|avg_val|conditions|flag|
+----+----------+-------+----------+----+
| x|01-10-2020| 10| 0| 0|
| x|02-10-2020| 10| 0| 0|
| x|03-10-2020| 15| 1| 1|
| x|04-10-2020| 15| 1| 0|
| x|05-10-2020| 5| 0| 0|
| x|06-10-2020| 13| 1| 0|
| x|07-10-2020| 10| 1| 0|
| x|08-10-2020| 10| 0| 0|
| x|09-10-2020| 15| 1| 1|
| y|01-10-2020| 10| 0| 0|
| y|02-10-2020| 18| 0| 0|
| y|03-10-2020| 6| 1| 1|
| y|04-10-2020| 10| 0| 0|
| y|05-10-2020| 20| 0| 0|
+----+----------+-------+----------+----+
Where:
- use groupby to collect rows for the same item into an array of structs named dta, with 4 fields: date, avg_val, conditions and flag, sorted by date
- use the aggregate function to iterate through the above array of structs and update the flag field based on counter and conditions (details in the SQL code comments above)
- use LATERAL VIEW and the inline function to explode the resulting array of structs from the aggregate function
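If you prefer to keep the per-item iteration in Scala instead of SQL, the same counter logic can be sketched with the Dataset API. This is only an illustration under the same assumptions, not part of the original answer; Rec, Out, N and result are made-up names, and df is the dataframe defined above:
import spark.implicits._

case class Rec(date: String, item: String, avg_val: Int, conditions: Int)
case class Out(date: String, item: String, avg_val: Int, conditions: Int, flag: Int)

val N = 4
val result = df.as[Rec]
  .groupByKey(_.item)
  .flatMapGroups { (item, rows) =>
    // counter plays the same role as acc.counter in the SQL:
    // 0 = idle, 1..N = inside the suppression window, N+1 = window just expired
    var counter = 0
    // sorting the dd-MM-yyyy strings lexicographically works here only because
    // all sample dates share the same month and year, just like sort_array above
    rows.toSeq.sortBy(_.date).map { r =>
      val flag = if ((counter == 0 || counter == N + 1) && r.conditions == 1) 1 else 0
      // same update rule as the SQL merge step
      counter = if (counter > 0 && counter < N + 1) counter + 1 else r.conditions
      Out(r.date, item, r.avg_val, r.conditions, flag)
    }
  }
result.orderBy("item", "date").show(50)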
Notes:
(1) the proposed SQL is for N=4, where we have acc.counter IN (0,5) and acc.counter < 5 in the SQL. For any N, adjust these to acc.counter IN (0,N+1) and acc.counter < N+1. The below shows the result for N=2 with the same sample data:
+----+----------+-------+----------+----+
|item| date|avg_val|conditions|flag|
+----+----------+-------+----------+----+
| x|01-10-2020| 10| 0| 0|
| x|02-10-2020| 10| 0| 0|
| x|03-10-2020| 15| 1| 1|
| x|04-10-2020| 15| 1| 0|
| x|05-10-2020| 5| 0| 0|
| x|06-10-2020| 13| 1| 1|
| x|07-10-2020| 10| 1| 0|
| x|08-10-2020| 10| 0| 0|
| x|09-10-2020| 15| 1| 1|
| y|01-10-2020| 10| 0| 0|
| y|02-10-2020| 18| 0| 0|
| y|03-10-2020| 6| 1| 1|
| y|04-10-2020| 10| 0| 0|
| y|05-10-2020| 20| 0| 0|
+----+----------+-------+----------+----+
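To avoid editing the two N-dependent constants by hand every time N changes, one option (my own sketch, not from the original answer; flagSql is a made-up helper) is to interpolate N+1 into the same query from Scala:
def flagSql(n: Int): String = s"""
SELECT t1.item, m.*
FROM (
  SELECT item,
    sort_array(collect_list(struct(date,avg_val,int(conditions) as conditions,conditions as flag))) as dta
  FROM record
  GROUP BY item
) as t1 LATERAL VIEW OUTER inline(
  aggregate(
    slice(dta,2,size(dta)),
    (array(dta[0]) as dta, dta[0].conditions as counter),
    (acc, x) -> named_struct(
      'dta', concat(acc.dta, array(named_struct(
        'date', x.date,
        'avg_val', x.avg_val,
        'conditions', x.conditions,
        'flag', IF(acc.counter IN (0, ${n + 1}) and x.conditions = 1, 1, 0)
      ))),
      'counter', IF(acc.counter > 0 and acc.counter < ${n + 1}, acc.counter + 1, x.conditions)
    ),
    acc -> acc.dta
  )
) m
"""

spark.sql(flagSql(4)).show(50)   // N=4, the first result table above
spark.sql(flagSql(2)).show(50)   // N=2, the table just above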
(2) we use dta[0] to initialize acc, which includes both the values and the datatypes of its fields. Ideally, we should make sure the data types of these fields are right so that all calculations are conducted correctly. For example, when calculating acc.counter, if conditions is StringType, acc.counter+1 will return a StringType holding a DoubleType value:
spark.sql("select '2'+1").show()
+---------------------------------------+
|(CAST(2 AS DOUBLE) + CAST(1 AS DOUBLE))|
+---------------------------------------+
| 3.0|
+---------------------------------------+
This could yield floating-point errors when comparing the value with integers using acc.counter IN (0,5) or acc.counter < 5. Based on the OP's feedback, this produced an incorrect result without any WARNING/ERROR message.
One workaround is to specify the exact field types using CAST when setting up the 2nd argument of the aggregate function, so that it reports an ERROR on any type mismatch, see below:
CAST((array(dta[0]), dta[0].conditions) as struct<dta:array<struct<date:string,avg_val:string,conditions:int,flag:int>>,counter:int>),
Another solution is to force the types when creating the dta column; in this example, see int(conditions) as conditions in the code below:
SELECT item, sort_array(collect_list(struct(date,avg_val,int(conditions) as conditions,conditions as flag))) as dta FROM record GROUP BY item
We can also force the datatype inside the calculation, for example, see int(acc.counter+1) below:
IF(acc.counter > 0 and acc.counter < 5, int(acc.counter+1), x.conditions)
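As a quick sanity check (my addition, not from the original answer; the alias counter is just an illustrative name), the effect of the int() cast can be verified in spark-shell:
spark.sql("SELECT int('2' + 1) AS counter").show()
+-------+
|counter|
+-------+
|      3|
+-------+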
Source: https://stackoverflow.com/questions/64660047/how-to-use-windowing-functions-efficiently-to-decide-next-n-number-of-rows-based