How to get the number of elements in partition?

清歌不尽 2020-12-05 11:14

Is there any way to get the number of elements in a spark RDD partition, given the partition ID? Without scanning the entire partition.

Something like this:
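(a hypothetical sketch of the desired call; no such API actually exists in Spark:)

    // hypothetical – Spark does not expose a per-partition element count
    rdd.partitions(partitionId).size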

3 Answers
  • 2020-12-05 11:38

    PySpark:

    num_partitions = 20000
    a = sc.parallelize(range(int(1e6)), num_partitions)
    l = a.glom().map(len).collect()  # get length of each partition
    print(min(l), max(l), sum(l)/len(l), len(l))  # check if skewed
    

    Spark/Scala:

    val numPartitions = 20000
    val a = sc.parallelize(0 until 1e6.toInt, numPartitions)
    val l = a.glom().map(_.length).collect()  // get length of each partition
    println((l.min, l.max, l.sum.toDouble / l.length, l.length))  // check if skewed
    

    The same is possible for a DataFrame, not just an RDD: just insert df.rdd.glom... into the code above, as in the sketch below.
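
    For example, a minimal Scala sketch (assuming an existing SparkSession spark; the DataFrame here is only an illustration):

    val df = spark.range(0L, 1000000L)  // any DataFrame or Dataset
    val sizes = df.rdd.glom().map(_.length).collect()  // length of each partition
    println((sizes.min, sizes.max, sizes.sum.toDouble / sizes.length, sizes.length))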

    Note that glom() materializes the elements of each partition into a list, so it's memory-intensive. A less memory-intensive alternative (PySpark only):

    import statistics 
    
    def get_table_partition_distribution(table_name: str):
    
        def get_partition_len(iterator):
            yield sum(1 for _ in iterator)
    
        l = spark.table(table_name).rdd.mapPartitions(get_partition_len, True).collect()  # get length of each partition
        num_partitions = len(l)
        min_count = min(l)
        max_count = max(l)
        avg_count = sum(l)/num_partitions
        stddev = statistics.stdev(l)
        print(f"{table_name}: {num_partitions} partitions; counts min={min_count:,} avg±stddev={avg_count:,.1f} ±{stddev:,.1f} max={max_count:,}")
    
    
    get_table_partition_distribution('someTable')
    
    

    outputs something like

    someTable: 1445 partitions; counts min=1,201,201 avg±stddev=1,202,811.6 ±21,783.4 max=2,030,137

  • 2020-12-05 11:47

    pzecevic's answer works, but conceptually there's no need to construct an array and then convert it to an iterator. I would just construct the iterator directly and then get the counts with a collect call.

    rdd.mapPartitions(iter => Iterator(iter.size), true).collect()
    

    P.S. Not sure if his answer is actually doing more work since Iterator.apply will likely convert its arguments into an array.
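
    If you only need the count for one partition ID, as the question asks, here's a sketch using SparkContext.runJob, which runs the function on just the requested partition (rdd: RDD[Int] and partitionId are assumed to exist):

    // only the requested partition is computed; the others are untouched
    val Array(count) = sc.runJob(rdd, (it: Iterator[Int]) => it.size, Seq(partitionId))
    println(s"partition $partitionId has $count elements")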

  • 2020-12-05 11:52

    The following gives you a new RDD whose elements are the sizes of each partition (a variant that also records each partition's ID is sketched after the code):

    rdd.mapPartitions(iter => Array(iter.size).iterator, true) 
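
    To also record which partition each size belongs to, a sketch with mapPartitionsWithIndex (the first argument of the function is the partition ID):

    // yields (partitionId, size) pairs
    rdd.mapPartitionsWithIndex((i, it) => Iterator((i, it.size)), true).collect()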
    