I'm trying to filter an RDD of tuples to return the largest N tuples based on key values. I need the return format to be an RDD.
So the RDD:
[(4, \'
With RDD
A quick but not particularly efficient solution is to call sortByKey, then use zipWithIndex and filter:
n = 3
rdd = sc.parallelize([(4, 'a'), (12, 'e'), (2, 'u'), (49, 'y'), (6, 'p')])
# sort descending so the largest keys get the lowest indices
rdd.sortByKey(ascending=False).zipWithIndex().filter(lambda xi: xi[1] < n).keys()
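To see what each stage of that pipeline produces, here is a plain-Python sketch that mimics it on a local list. It is not Spark code: `sorted` and `enumerate` stand in for sortByKey and zipWithIndex, and the sketch assumes a descending sort, since the goal is the largest n keys.

```python
# Plain-Python stand-in for the sort -> index -> filter -> keys pipeline
pairs = [(4, 'a'), (12, 'e'), (2, 'u'), (49, 'y'), (6, 'p')]
n = 3

# sortByKey (descending): order the tuples by key, largest first
ordered = sorted(pairs, key=lambda kv: kv[0], reverse=True)

# zipWithIndex: pair every element with its position -> ((key, value), index)
indexed = [(kv, i) for i, kv in enumerate(ordered)]

# filter on the index; keeping only the first component mirrors keys()
top_n = [kv for kv, i in indexed if i < n]
print(top_n)  # [(49, 'y'), (12, 'e'), (6, 'p')]
```

Note that the final keys() call drops the index, not the original key: after zipWithIndex each element is ((key, value), index), so its "key" is the original tuple.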
If n is small relative to the size of the RDD, a somewhat more efficient approach is to avoid a full sort:
import heapq
def key(kv):
    return kv[0]
# keep at most n candidates per partition, then sort only those candidates
top_per_partition = rdd.mapPartitions(lambda it: heapq.nlargest(n, it, key=key))
top_per_partition.sortByKey(ascending=False).zipWithIndex().filter(lambda xi: xi[1] < n).keys()
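The reason this works is that any tuple in the global top n must also be in the top n of its own partition, so only a handful of candidates per partition need to be sorted. A minimal plain-Python sketch of that idea follows, with ordinary lists standing in for partitions; the function name `top_n_by_key` and the partition layout are hypothetical, chosen only for illustration.

```python
import heapq

def top_n_by_key(partitions, n):
    """Return the n tuples with the largest keys, largest first."""
    key = lambda kv: kv[0]
    # Step 1 (the mapPartitions step): keep at most n candidates per partition
    candidates = []
    for part in partitions:
        candidates.extend(heapq.nlargest(n, part, key=key))
    # Step 2: at most len(partitions) * n candidates remain, so the
    # final selection is cheap compared to sorting everything
    return heapq.nlargest(n, candidates, key=key)

# Two lists standing in for two RDD partitions
partitions = [[(4, 'a'), (12, 'e')], [(2, 'u'), (49, 'y'), (6, 'p')]]
result = top_n_by_key(partitions, 3)
print(result)  # [(49, 'y'), (12, 'e'), (6, 'p')]
```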
If keys are much smaller than values and the order of the final output doesn't matter, then a filter-based approach can work just fine:
keys = rdd.keys()
identity = lambda x: x
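# offset is the n-th largest key in the RDD
offset = (keys
    .mapPartitions(lambda it: heapq.nlargest(n, it))
    .sortBy(identity, ascending=False)
    .zipWithIndex()
    .filter(lambda xi: xi[1] < n)
    .keys()
    .min())
# keep every tuple whose key is at least as large as the n-th largest key
rdd.filter(lambda kv: kv[0] >= offset)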
Note that this won't return exactly n values if several tuples share the key at the threshold.
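A small plain-Python sketch of the tie behaviour, assuming the goal is the largest n keys and that the threshold is the n-th largest key. The data here is made up for illustration and contains a duplicate key.

```python
import heapq

# Sample data with a duplicated key (6) sitting exactly at the cut-off
pairs = [(4, 'a'), (12, 'e'), (6, 'u'), (49, 'y'), (6, 'p')]
n = 3

# The n-th largest key acts as the threshold
threshold = min(heapq.nlargest(n, (k for k, _ in pairs)))

# Filtering on the threshold keeps every tuple tied at the cut-off
kept = [kv for kv in pairs if kv[0] >= threshold]
print(threshold)  # 6
print(kept)       # [(12, 'e'), (6, 'u'), (49, 'y'), (6, 'p')]
```

Four tuples come back for n = 3 because both tuples with key 6 pass the filter. If exactly n results are required, one of the index-based variants is the safer choice.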
With DataFrames
You can just use orderBy and limit:
from pyspark.sql.functions import col
rdd.toDF().orderBy(col("_1").desc()).limit(n)
A lower-effort approach, since you only want to convert the take(N) results to a new RDD, is to parallelize them again:
sc.parallelize(yourSortedRdd.take(Nth))