Calculate the standard deviation of grouped data in a Spark DataFrame

终归单人心 2020-12-05 05:27

I have user logs that I loaded from a CSV file and converted into a DataFrame in order to use the Spark SQL querying features. A single user will create numerous entries

2 answers
  • 2020-12-05 05:41

    Spark 1.6+

    You can use stddev_pop to compute the population standard deviation, and stddev_samp (or stddev, which is an alias for it) to compute the unbiased sample standard deviation:

    import org.apache.spark.sql.functions.{stddev_samp, stddev_pop}

    // Population standard deviation of duration for each user
    selectedData.groupBy($"user").agg(stddev_pop($"duration"))
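To make the difference between the two functions concrete, here is a minimal plain-Scala sketch (no Spark required) of the two definitions. The helper names popStd and sampStd and the sample values are made up for illustration; they are not part of the Spark API.

```scala
// Hypothetical helpers illustrating the two definitions; not Spark functions.

// Population standard deviation: divide the sum of squared deviations by n.
def popStd(xs: Seq[Double]): Double = {
  val mean = xs.sum / xs.size
  math.sqrt(xs.map(x => (x - mean) * (x - mean)).sum / xs.size)
}

// Sample standard deviation: divide by n - 1 (Bessel's correction).
def sampStd(xs: Seq[Double]): Double = {
  require(xs.size > 1, "sample standard deviation needs at least two values")
  val mean = xs.sum / xs.size
  math.sqrt(xs.map(x => (x - mean) * (x - mean)).sum / (xs.size - 1))
}

// Made-up durations for a single user.
val durations = Seq(2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0)
val pop = popStd(durations)
val samp = sampStd(durations)
println(s"population = $pop, sample = $samp")
```

For any group with at least two distinct values the sample figure is the larger of the two, and the gap shrinks as the group grows.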
    

    Spark 1.5 and below (the original answer):

    Not so pretty, and biased (the same as the value returned from describe), but using the formula

    $\sigma = \sqrt{E[X^2] - (E[X])^2}$

    you can do something like this:

    import org.apache.spark.sql.functions.{avg, sqrt}

    selectedData
        .groupBy($"user")
        .agg(sqrt(
            avg($"duration" * $"duration") -
            avg($"duration") * avg($"duration")
        ).alias("duration_sd"))
    

    You can of course create a function to reduce the clutter:

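    import org.apache.spark.sql.Column
    import org.apache.spark.sql.functions.{avg, sqrt}

    // Population standard deviation: sqrt(E[X^2] - (E[X])^2)
    def mySd(col: Column): Column = {
        sqrt(avg(col * col) - avg(col) * avg(col))
    }

    df.groupBy($"user").agg(mySd($"duration").alias("duration_sd"))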
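As a sanity check on the formula used above, the following plain-Scala sketch (no Spark required) compares the mean-of-squares shortcut with the textbook definition of the population standard deviation. The names shortcutStd and definitionStd and the data are hypothetical. The clamp at zero is an addition that the original answer does not have; it is there because rounding can push the difference slightly below zero for near-constant data, which would make sqrt return NaN.

```scala
// Hypothetical helper: population SD via sqrt(E[X^2] - (E[X])^2).
def shortcutStd(xs: Seq[Double]): Double = {
  val n = xs.size.toDouble
  val meanOfSquares = xs.map(x => x * x).sum / n
  val mean = xs.sum / n
  // Assumption: clamp at zero so that rounding error cannot produce NaN.
  math.sqrt(math.max(0.0, meanOfSquares - mean * mean))
}

// Definition-based population SD, used only for comparison.
def definitionStd(xs: Seq[Double]): Double = {
  val mean = xs.sum / xs.size
  math.sqrt(xs.map(x => math.pow(x - mean, 2)).sum / xs.size)
}

// Made-up durations for a single user.
val sample = Seq(2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0)
val viaShortcut = shortcutStd(sample)
val viaDefinition = definitionStd(sample)
println(s"shortcut = $viaShortcut, definition = $viaDefinition")
```

Both routes give the population figure (dividing by n), which is why the result of this approach is described as biased.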
    

    It is also possible to use a Hive UDF:

    df.registerTempTable("df")
    sqlContext.sql("""SELECT user, stddev(duration)
                      FROM df
                      GROUP BY user""")
    

    Source of the formula: https://en.wikipedia.org/wiki/Standard_deviation

  • 2020-12-05 05:55

    The code in the accepted answer originally did not compile because of a typo in a function name (as pointed out by MRez). The snippet below works and is tested.

    For Spark 2.0+:

    import org.apache.spark.sql.functions._

    // Alias the aggregate result, not the input column
    val _avg_std = df.groupBy("user").agg(
        avg(col("duration")).alias("avg"),
        stddev(col("duration")).alias("stdev"),
        stddev_pop(col("duration")).alias("stdev_pop"),
        stddev_samp(col("duration")).alias("stdev_samp")
    )
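For anyone who wants to try this end to end, here is a minimal self-contained sketch for Spark 2.0+. It assumes Spark SQL is on the classpath and that the code runs as a script or in a REPL. The local master setting, the app name, the logs DataFrame and its rows are all made up for illustration and stand in for the real user log.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{avg, col, stddev_pop, stddev_samp}

// Assumption: a local session is good enough for a quick experiment.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("grouped-stddev-sketch")
  .getOrCreate()
import spark.implicits._

// Made-up rows standing in for the user log: (user, duration).
val logs = Seq(
  ("alice", 10.0), ("alice", 20.0), ("alice", 30.0),
  ("bob", 5.0), ("bob", 5.0)
).toDF("user", "duration")

// One output row per user, with the alias applied to each aggregate.
val stats = logs.groupBy("user").agg(
  avg(col("duration")).alias("avg"),
  stddev_pop(col("duration")).alias("stddev_pop"),
  stddev_samp(col("duration")).alias("stddev_samp")
)

stats.show()
```

Every group here has at least two rows on purpose, so that the sample standard deviation is well defined for each user.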
    