1. 需求
在处理时间序列或者是有序数据时候,经常会越到这样的情形:1、求客户最近一个月的平均消费金额;2、求客户最近一个月的消费次数;3、求与客户在最近一个月内发生大额消费的客户数量
上述问题中,问题1与问题2是典型的求指定客户在指定时间段的行为统计值,可以通过客户进行分组然后过滤统计即可,而问题三则是求解与该客户在指定时间段内发生某种行为的客户数量,即没有具体的聚合Key,从而不同按照问题1与问题2进行分组的方式求解,而通过类似时间序列中求一个序列的移动平均方法来求解
针对上述两类情形,下面分别提供了两个函数,这个两个函数,一个是aggregateByKey,根据key聚合,并对聚合后的每一条记录采用窗口函数获取聚合数据,即该函数是对每一个key对应的value进行移动聚合操作。另一个是aggregateByAll,根据key进行排序,通过窗口函数获取结果进行聚合。这两个函数都是根据窗口函数进行聚合操作,但不同点在于作用范围不同。
窗口函数与聚合函数用户定义可以自由定义,通过定义不同的窗口函数与聚合函数,可以实现不同的移动逻辑以及聚合运算。
2. 实现思路
2.1 aggregateByKey
aggregateByKey实现起来比较简单,因为一般情况先,单个key对应的value不会很大(单个executor的内存是可以装得下的),所有可以通过简单的reduceByKey把所有相同key对应的value混洗到相同的分区,然后对分区内的每条数据通过窗口函数移动,把移动获取到的数据根据聚合函数执行组合操作即可
2.2 aggregateByAll
aggregateByAll由于要对整个RDD内的数据进行移动聚合,所有不能够像aggregateByKey那样把待聚合数据放在一起(因为是所有,executor一般装不下),所以要通过分区分别操作。简要步骤如下:
-
- 对RDD通过RangePartitioner进行分区,使得分区间的数据是有序的
-
- 针对每一个分区,将分区内的数据进行排序,将排序后的数据头部与尾部满足窗口滑动的数据分别shuffle到前后相邻的分区。这一步是为了确保分区内满足移动窗口的数据在当前分区中。
-
- 按照各个分区求指定窗口数据的聚合操作即可
技术难点在于移动窗口是跨分区时候如何解决?即当前数据需要聚合的数据在另外一个分区中。
2. 源码
package com.jiamz.aidp.bigdata.utils
import com.jiamz.aidp.bigdata.SparkHelper
import org.apache.spark.rdd.RDD
import org.apache.spark.{HashPartitioner, RangePartitioner}
import scala.reflect.ClassTag
/**
* Created by zhoujiamu on 2020/3/11.
*/
object MovingAggregate {
private val lastPart: Byte = -1
private val currPart: Byte = 0
private val nextPart: Byte = 1
/**
* 根据key对value进行聚合操作, 与reduceByKey不同的是, 该聚合操作是对同一key对应的value根据窗口函数进行滑动,
* 然后针对窗口移动获取的结果进行聚合操作
* @param rdd 待聚合key-value形式的RDD
* @param winFunc 移动窗口函数
* @param aggFunc 聚合函数
* @tparam K key类型
* @tparam V value类型
* @tparam U 聚合结果类型
* @return 移动聚合结果RDD
*/
def aggregateByKey[K: ClassTag, V: ClassTag, U: ClassTag](rdd: RDD[(K, V)],
winFunc: (V, V) => Boolean,
aggFunc: Seq[V] => U): RDD[(K, (V, U))] ={
val result = rdd.mapPartitions(iter => iter.map{case(k, v) => k -> Seq(v)})
.reduceByKey(_++_)
.mapPartitions(iter =>
iter.flatMap{case(k, seq) => {
val aggResult = seq.flatMap(s1 => {
val aggSeq = seq.filter(s2 => winFunc(s1, s2))
if (aggSeq.nonEmpty)
Iterator(s1 -> aggFunc(aggSeq))
else
Iterator.empty
})
aggResult.map(res => k -> res)
}}
)
result
}
/**
* 根据key对数据进行排序, 通过winFunc函数来滑动截取需要聚合的数据进行聚合操作,
* @param rdd 待聚合key-value形式的RDD
* @param winFunc 移动窗口函数
* @param aggFunc 聚合函数
* @tparam K key类型
* @tparam V value类型
* @tparam U 聚合结果类型
* @return 移动聚合结果RDD
*/
def aggregateByAll[K: Ordering: ClassTag, V: ClassTag, U: ClassTag](rdd: RDD[(K, V)],
winFunc: (K, K) => Boolean,
aggFunc: Seq[V] => U
): RDD[(K, (V, U))] ={
val partitioner = new RangePartitioner(rdd.getNumPartitions, rdd)
val newRdd = rdd.partitionBy(partitioner).cache()
/**
* 获取当前数据近邻的记录
* @param index 当前数据索引
* @param seq 数据列表
* @return 当前数据近邻记录
*/
def getNbrs(index: Int, seq: Seq[(K, V)]): Seq[(K, V)] ={
val center = seq(index)._1
val len = seq.length
var start = index
var end = index
while (end < len && winFunc(center, seq(end)._1)) end += 1
while (start >= 0 && winFunc(center, seq(start)._1)) start -= 1
seq.slice(start+1, end)
}
val numPartition = rdd.getNumPartitions
def getPartTail(pid: Int, seq: Seq[(K, V)]): Seq[(Int, ((K, V), Byte))] ={
val center = seq.last._1
val length = seq.length
var ind = length - 1
while (ind >= 0 && winFunc(center, seq(ind)._1)) ind -= 1
seq.slice(ind+1, length).map(data => (pid+1, (data, lastPart)))
}
def getPartHead(pid: Int, seq: Seq[(K, V)]): Seq[(Int, ((K, V), Byte))] ={
val center = seq.head._1
val length = seq.length
var ind = 1
while (ind < length && winFunc(center, seq(ind)._1)) ind += 1
seq.slice(0, ind).map(data => (pid-1, (data, nextPart)))
}
// 相邻分区的数据中符合窗口函数的需要进行copy到相邻分区中,使得对每一条数据的邻居数据(设定窗口内)都在同一分区
val rddWithShuffle = newRdd.mapPartitionsWithIndex((pid, iter) => {
if (numPartition == 1 || iter.isEmpty){
iter.map(data => (pid, (data, currPart)))
} else {
val seq = iter.toSeq.sortBy(_._1)
val moved = if (pid == 0){ // 第一个分区的尾部数据往后一个分区移动
getPartTail(pid, seq)
} else if (pid == numPartition-1){ // 最后一个分区的尾部数据往前一个分区移动
getPartHead(pid, seq)
} else { // 中间分区前后的数据都往相邻分区移动
getPartHead(pid, seq) ++ getPartTail(pid, seq)
}
val fixed = seq.map(data => (pid, (data, currPart)))
(fixed ++ moved).toIterator
}
}).partitionBy(new HashPartitioner(numPartition))
val aggregateResult = rddWithShuffle.mapPartitions(part => {
val seq = part.toSeq.sortBy(_._1)
val data = seq.map(_._2._1)
seq.zipWithIndex.filter(x => x._1._2._2.equals(currPart))
.map{case(_, i) => {
val center = seq(i)._2._1
val nbrs = getNbrs(i, data)
val aggRes = aggFunc(nbrs.map(_._2))
(center._1, (center._2, aggRes))
}}.toIterator
})
aggregateResult
}
def main(args: Array[String]): Unit = {
SparkHelper.setLogLevel("WARN")
val sc = SparkHelper.getSparkContext("MovingAggregate", "spark.master" -> "local")
val array = Array(
("id1", (1, 10.0)),
("id1", (2, 20.0)),
("id1", (3, 30.0)),
("id1", (7, 40.0)),
("id1", (8, 50.0)),
("id1", (9, 60.0)),
("id2", (1, 70.0)),
("id2", (2, 80.0)),
("id2", (3, 30.0)),
("id2", (4, 20.0)),
("id2", (7, 50.0)),
("id2", (9, 60.0))
)
/** 针对key做滑动平均 */
// 定义移动窗口函数
def winFunc(v1: (Int, Double), v2: (Int, Double)) = {
v2._1 - v1._1 <= 3 && v2._1 >= v1._1
}
// 定义聚合函数, 求平均
def aggFunc(seq: Seq[(Int, Double)]) = {
val values = seq.map(_._2)
values.sum / values.length
}
val rdd = sc.makeRDD(array)
val res = aggregateByKey(rdd, winFunc, aggFunc)
// 打印按照key计算移动平均结果
res.collect().foreach(println)
println("-"*50)
val array1 = Array(
(1, 10.0),
(2, 20.0),
(3, 30.0),
(7, 40.0),
(8, 50.0),
(9, 60.0),
(20, 70.0),
(22, 80.0),
(23, 30.0),
(31, 20.0),
(33, 50.0),
(36, 60.0)
)
val rdd1 = sc.makeRDD(array1, numSlices = 3)
def winFunc1(s1: Int, s2: Int) = {
math.abs(s2 - s1) <= 1
}
def aggFunc1(seq: Seq[Double]) = {
seq.sum / seq.length
}
val res1 = aggregateByAll(rdd1, winFunc1, aggFunc1)
println("-"*30)
// 打印按照key指定的窗口计算移动平均结果
res1.collect().foreach(println)
}
}
执行结果
(id1,((1,10.0),20.0))
(id1,((2,20.0),25.0))
(id1,((3,30.0),30.0))
(id1,((7,40.0),50.0))
(id1,((8,50.0),55.0))
(id1,((9,60.0),60.0))
(id2,((1,70.0),50.0))
(id2,((2,80.0),43.333333333333336))
(id2,((3,30.0),25.0))
(id2,((4,20.0),35.0))
(id2,((7,50.0),55.0))
(id2,((9,60.0),60.0))
--------------------------------------------------
------------------------------
(1,(10.0,15.0))
(2,(20.0,20.0))
(3,(30.0,25.0))
(7,(40.0,45.0))
(8,(50.0,50.0))
(9,(60.0,55.0))
(20,(70.0,70.0))
(22,(80.0,55.0))
(23,(30.0,55.0))
(31,(20.0,20.0))
(33,(50.0,50.0))
(36,(60.0,60.0))
来源:oschina
链接:https://my.oschina.net/u/3780646/blog/3192970