How to calculate mean and standard deviation given a PySpark DataFrame?

Backend · Open · 3 replies · 1043 views
礼貌的吻别 2021-02-07 14:29

I have a PySpark DataFrame (not pandas) called df that is too large to call collect() on. Therefore the code given below is not efficient.

3 Answers
  •  灰色年华
    2021-02-07 15:15

    You can use the built-in aggregate functions to get summary statistics. Here's how to get the mean and standard deviation.

    from pyspark.sql.functions import mean as _mean, stddev as _stddev, col
    
    df_stats = df.select(
        _mean(col('columnName')).alias('mean'),
        _stddev(col('columnName')).alias('std')
    ).collect()
    
    mean = df_stats[0]['mean']
    std = df_stats[0]['std']
    

    Note that there are three different standard deviation functions: stddev (an alias for stddev_samp), stddev_samp, and stddev_pop. From the docs, the one I used (stddev) returns the following:

    Aggregate function: returns the unbiased sample standard deviation of the expression in a group
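
    To see how the unbiased (sample) estimate differs from the population estimate that stddev_pop computes, here is a NumPy sketch using the same six scores that appear later in this answer; the ddof argument selects the divisor (n - 1 vs n):

```python
import numpy as np

vals = [1, 5, 2, 2, 1, 3]

# Unbiased sample standard deviation (divides by n - 1),
# matching pyspark's stddev / stddev_samp
sample_std = np.std(vals, ddof=1)

# Population standard deviation (divides by n),
# matching pyspark's stddev_pop
population_std = np.std(vals, ddof=0)

print(sample_std)      # ~1.5055
print(population_std)  # ~1.3744
```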

    You could use the describe() method as well:

    df.describe().show()
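
    For reference, describe() reports count, mean, stddev, min, and max for each numeric column (all returned as strings). As a plain-Python sketch of those same five statistics, computed for the six hypothetical scores used later in this answer:

```python
import statistics

vals = [1, 5, 2, 2, 1, 3]  # hypothetical column values

# The five statistics that describe() reports per numeric column
summary = {
    "count": len(vals),
    "mean": statistics.mean(vals),
    "stddev": statistics.stdev(vals),  # sample stddev, like describe()
    "min": min(vals),
    "max": max(vals),
}
print(summary)
```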
    

    Refer to this link for more info: pyspark.sql.functions

    UPDATE: This is how you can work through the nested data.

    Use explode to extract the values into separate rows, then call mean and stddev as shown above.

    Here's a MWE:

    from pyspark.sql import SparkSession
    from pyspark.sql.types import IntegerType
    from pyspark.sql.functions import explode, col, udf, mean as _mean, stddev as _stddev
    
    spark = SparkSession.builder.getOrCreate()
    
    # mock up sample dataframe
    df = spark.createDataFrame(
        [(680, [[691,1], [692,5]]), (685, [[691,2], [692,2]]), (684, [[691,1], [692,3]])],
        ["product_PK", "products"]
    )
    
    # udf to get the "score" value - returns the item at index 1
    get_score = udf(lambda x: x[1], IntegerType())
    
    # explode column and get stats
    df_stats = df.withColumn('exploded', explode(col('products')))\
        .withColumn('score', get_score(col('exploded')))\
        .select(
            _mean(col('score')).alias('mean'),
            _stddev(col('score')).alias('std')
        )\
        .collect()
    
    mean = df_stats[0]['mean']
    std = df_stats[0]['std']
    
    print([mean, std])
    

    Which outputs:

    [2.3333333333333335, 1.505545305418162]
    

    You can verify that these values are correct using numpy:

    import numpy as np
    
    vals = [1, 5, 2, 2, 1, 3]
    print([np.mean(vals), np.std(vals, ddof=1)])
    

    Explanation: Your "products" column is a list of lists. Calling explode will make a new row for each element of the outer list. Then grab the "score" value from each of the exploded rows, which you have defined as the second element in a 2-element list. Finally, call the aggregate functions on this new column.
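
    Conceptually, explode plus the get_score UDF do the equivalent of this plain-Python flattening (a sketch over the same mock data, not Spark code):

```python
rows = [
    (680, [[691, 1], [692, 5]]),
    (685, [[691, 2], [692, 2]]),
    (684, [[691, 1], [692, 3]]),
]

# explode: one output row per element of the outer "products" list
exploded = [(pk, item) for pk, products in rows for item in products]

# get_score: take the element at index 1 of each 2-element inner list
scores = [item[1] for _, item in exploded]

print(scores)  # [1, 5, 2, 2, 1, 3]
```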
