Find mean of pyspark array

醉酒成梦 2021-01-17 18:31

In pyspark, I have a variable-length array of doubles for which I would like to find the mean. However, the average function requires a single numeric type.

Is there a way to find the average of an array without exploding the array out?
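
A minimal sketch of the kind of DataFrame in question (the longitude column name and the sample values are assumptions, chosen to line up with the answers below):

    from pyspark.sql import SparkSession
    
    spark = SparkSession.builder.getOrCreate()
    
    # one variable-length array of doubles per row
    df = spark.createDataFrame(
        [([-80.9, -82.9],), ([-82.93, -82.93],)],
        ["longitude"],
    )
    df.printSchema()
    # root
    #  |-- longitude: array (nullable = true)
    #  |    |-- element: double (containsNull = true)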

2 Answers
  • 2021-01-17 19:04

    In recent Spark versions (2.4 or later), the most efficient solution is to use the aggregate higher-order function:

    from pyspark.sql.functions import expr
    
    query = """aggregate(
        `{col}`,
        CAST(0.0 AS double),
        (acc, x) -> acc + x,
        acc -> acc / size(`{col}`)
    ) AS  `avg_{col}`""".format(col="longitude")
    
    df.selectExpr("*", query).show()
    
    +--------------------+------------------+
    |           longitude|     avg_longitude|
    +--------------------+------------------+
    |      [-80.9, -82.9]|             -81.9|
    |[-82.92, -82.93, ...|-82.93166666666667|
    |    [-82.93, -82.93]|            -82.93|
    +--------------------+------------------+
    

    See also: Spark Scala row-wise average by handling null
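
    If you are on Spark 3.1 or later, the same computation can also be written with the Python DataFrame API's aggregate function instead of a SQL expression string. A minimal sketch, assuming the same longitude column:

    from pyspark.sql.functions import aggregate, col, lit, size
    
    df.withColumn(
        "avg_longitude",
        aggregate(
            "longitude",                               # array column to fold over
            lit(0.0),                                  # initial accumulator (double)
            lambda acc, x: acc + x,                    # running sum of the elements
            lambda acc: acc / size(col("longitude")),  # divide by the array length
        ),
    ).show()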

  • 2021-01-17 19:20

    In your case, your options are to use explode or a udf. As you've noted, explode is unnecessarily expensive. Thus, a udf is the way to go.

    You can write your own function to take the mean of a list of numbers, or just piggyback off of numpy.mean. If you use numpy.mean, you'll have to cast the result to a float (because Spark doesn't know how to handle numpy.float64).

    import numpy as np
    from pyspark.sql.functions import udf
    from pyspark.sql.types import FloatType
    
    # cast to float: Spark cannot serialize the numpy.float64 that np.mean returns
    array_mean = udf(lambda x: float(np.mean(x)), FloatType())
    df.select(array_mean("longitude").alias("avg")).show()
    #+---------+
    #|      avg|
    #+---------+
    #|    -81.9|
    #|-82.93166|
    #|   -82.93|
    #+---------+
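
    If you are on Spark 3.0+ with PyArrow installed, a vectorized pandas_udf is another option; it typically has less serialization overhead than a row-at-a-time udf. A minimal sketch, reusing the longitude column from above:

    import numpy as np
    import pandas as pd
    from pyspark.sql.functions import pandas_udf
    from pyspark.sql.types import DoubleType
    
    # each element of the incoming Series is one row's array of doubles
    @pandas_udf(DoubleType())
    def array_mean_pd(arrays: pd.Series) -> pd.Series:
        return arrays.apply(lambda xs: float(np.mean(xs)))
    
    df.select(array_mean_pd("longitude").alias("avg")).show()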
    