Scala - Spark: in a DataFrame, retrieve for each row the column name that holds the max value


Question:

I have a DataFrame:

name     column1  column2  column3  column4
first    2        1        2.1      5.4
test     1.5      0.5      0.9      3.7
choose   7        2.9      9.1      2.5

I want a new DataFrame with a column containing, for each row, the name of the column that holds the max value:

| name   | max_column |
|--------|------------|
| first  | column4    |
| test   | column4    |
| choose | column3    |

Thank you very much for your support.

Answer 1:

There might be a better way of writing the UDF, but this is a working solution:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf

val spark: SparkSession = SparkSession.builder.master("local").getOrCreate

// implicits for magic functions like .toDF
import spark.implicits._

// We have to hard-code the number of parameters, as a UDF does not support a variable number of arguments
val maxval = udf((c1: Double, c2: Double, c3: Double, c4: Double) =>
  if (c1 >= c2 && c1 >= c3 && c1 >= c4)
    "column1"
  else if (c2 >= c1 && c2 >= c3 && c2 >= c4)
    "column2"
  else if (c3 >= c1 && c3 >= c2 && c3 >= c4)
    "column3"
  else
    "column4"
)

// case class describing the schema
case class Record(name: String,
                  column1: Double,
                  column2: Double,
                  column3: Double,
                  column4: Double)

val df = Seq(
  Record("first", 2.0, 1.0, 2.1, 5.4),
  Record("test", 1.5, 0.5, 0.9, 3.7),
  Record("choose", 7.0, 2.9, 9.1, 2.5)
).toDF()

df.withColumn("max_column", maxval($"column1", $"column2", $"column3", $"column4"))
  .select("name", "max_column")
  .show()

Output

+------+----------+
|  name|max_column|
+------+----------+
| first|   column4|
|  test|   column4|
|choose|   column3|
+------+----------+
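
For a fixed set of columns, the same result can also be obtained without a UDF by combining Spark's built-in greatest and when functions. A minimal sketch, assuming the df and imports from the snippet above:

import org.apache.spark.sql.functions.{greatest, when}

// Row-wise maximum of the four value columns
val rowMax = greatest($"column1", $"column2", $"column3", $"column4")

df.withColumn("max_column",
    when($"column1" === rowMax, "column1")
      .when($"column2" === rowMax, "column2")
      .when($"column3" === rowMax, "column3")
      .otherwise("column4"))
  .select("name", "max_column")
  .show()

Like the UDF above, ties are resolved in favor of the leftmost matching column.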


Answer 2:

You can get the job done by making a detour to an RDD and using getValuesMap.

val dfIn = Seq(
  ("first", 2.0, 1.0, 2.1, 5.4),
  ("test", 1.5, 0.5, 0.9, 3.7),
  ("choose", 7.0, 2.9, 9.1, 2.5)
).toDF("name", "column1", "column2", "column3", "column4")

The simple solution is

val dfOut = dfIn.rdd
  .map(r => (
    r.getString(0),
    r.getValuesMap[Double](r.schema.fieldNames.filter(_ != "name"))
  ))
  .map { case (n, m) => (n, m.maxBy(_._2)._1) }
  .toDF("name", "max_column")

But if you want to keep all the columns from the original DataFrame (as in Scala/Spark dataframes: find the column name corresponding to the max), you have to play a bit with merging rows and extending the schema:

import org.apache.spark.sql.types.{StructType, StructField, StringType}
import org.apache.spark.sql.Row

val dfOut = sqlContext.createDataFrame(
  dfIn.rdd
    // pair each row with a map from column name to value (all columns except the first, "name")
    .map(r => (r, r.getValuesMap[Double](r.schema.fieldNames.drop(1))))
    // append the name of the max-valued column to the original row
    .map { case (r, m) => Row.merge(r, Row(m.maxBy(_._2)._1)) },
  // extend the original schema with the new string column so createDataFrame matches the merged rows
  dfIn.schema.add(StructField("max_column", StringType))
)


Answer 3:

I want to post my final solution:

val maxValAsMap = udf((keys: Seq[String], values: Seq[Any]) => {
  // keep only the numeric (Double) values, keyed by column name
  val valueMap: Map[String, Double] = (keys zip values)
    .filter(_._2.isInstanceOf[Double])
    .map { case (x, y) => (x, y.asInstanceOf[Double]) }
    .toMap

  if (valueMap.isEmpty) "not computed" else valueMap.maxBy(_._2)._1
})

val finalDf = originalDf
  .withColumn("max_column", maxValAsMap(keys, values))
  .select("cookie_id", "max_column")

It works very fast.
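
The answer does not show how keys and values are built. Assuming they are array columns holding the value-column names (as string literals) and the corresponding values of each row, one way they could be constructed is sketched below; the column list is hypothetical and would have to match originalDf:

import org.apache.spark.sql.functions.{array, col, lit}

// Hypothetical list of value columns to compare (adjust to the actual DataFrame)
val valueCols = Seq("column1", "column2", "column3", "column4")

// keys: the column names as an array of string literals
val keys = array(valueCols.map(lit): _*)
// values: the corresponding column values as an array
val values = array(valueCols.map(col): _*)

With keys and values defined this way, they can be passed to maxValAsMap exactly as in the snippet above.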


