Collect rows as a list with group by in Apache Spark

深忆病人 2020-12-30 08:16

I have a particular use case where I have multiple rows for the same customer, and each row object looks like:

root
 -c1: BigInt
 -c2: String
 -c3: Double
 -c4:         


        
2 Answers
  • 2020-12-30 08:55

    Instead of an array, you can use the struct function to combine the columns, then use groupBy with the collect_list aggregation function:

    import org.apache.spark.sql.functions._
    df.withColumn("combined", struct("c1","c2","c3","c4","c5"))
        .groupBy("c1").agg(collect_list("combined").as("combined_list"))
        .show(false)
    

    This gives you a grouped dataset with the following schema:

    root
     |-- c1: integer (nullable = false)
     |-- combined_list: array (nullable = true)
     |    |-- element: struct (containsNull = true)
     |    |    |-- c1: integer (nullable = false)
     |    |    |-- c2: string (nullable = true)
     |    |    |-- c3: string (nullable = true)
     |    |    |-- c4: string (nullable = true)
     |    |    |-- c5: map (nullable = true)
     |    |    |    |-- key: string
     |    |    |    |-- value: integer (valueContainsNull = false)
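
    The snippet above assumes an existing df. A minimal self-contained sketch of the same struct + collect_list pattern might look like the following. It assumes a local SparkSession and made-up sample data; the names sample and grouped, the three-column layout, and the session settings are illustrative and not taken from the question.

    ```scala
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{collect_list, struct}

    // Local session for experimentation only; master and app name are assumptions.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("collect-rows-sketch")
      .getOrCreate()
    import spark.implicits._

    // Hypothetical sample rows: two for customer 10, one for customer 20.
    val sample = Seq(
      (10L, "x", 1.0),
      (10L, "y", 3.0),
      (20L, "z", 5.0)
    ).toDF("c1", "c2", "c3")

    // Pack each row's columns into one struct, then gather the structs per customer.
    val grouped = sample
      .withColumn("combined", struct("c1", "c2", "c3"))
      .groupBy("c1")
      .agg(collect_list("combined").as("combined_list"))

    grouped.printSchema()
    ```

    Note that collect_list does not guarantee the order of elements within each list after the shuffle, so sort explicitly if the order matters.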
    

    I hope the answer is helpful

  • 2020-12-30 09:02

    If you want the result to consist of collections of Rows, consider transforming to an RDD as follows:

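    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.Row
    import spark.implicits._  // needed for toDF; assumes a SparkSession named `spark`, as in spark-shell
    
    // Sample data; c1 is a BigInt, which Spark stores as a decimal column
    val df = Seq(
        (BigInt(10), "x", 1.0, 2.0, Map("a"->1, "b"->2)),
        (BigInt(10), "y", 3.0, 4.0, Map("c"->3)),
        (BigInt(20), "z", 5.0, 6.0, Map("d"->4, "e"->5))
      ).
      toDF("c1", "c2", "c3", "c4", "c5")
      // optionally append .as[(BigInt, String, Double, Double, Map[String, Int])] for a typed Dataset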
    
    df.rdd.map(r => (r.getDecimal(0), r)).groupByKey.collect
    // res1: Array[(java.math.BigDecimal, Iterable[org.apache.spark.sql.Row])] = Array(
    //   (10,CompactBuffer([10,x,1.0,2.0,Map(a -> 1, b -> 2)], [10,y,3.0,4.0,Map(c -> 3)])),
    //   (20,CompactBuffer([20,z,5.0,6.0,Map(d -> 4, e -> 5)]))
    // )
    

    Or, if you're good with collections of struct-type rows in a DataFrame, here's an alternative approach:

    val cols = df.columns
    
    df.groupBy("c1").agg(collect_list(struct(cols.head, cols.tail: _*)).as("row_list")).
      show(false)
    // +---+----------------------------------------------------------------+
    // |c1 |row_list                                                        |
    // +---+----------------------------------------------------------------+
    // |20 |[[20,z,5.0,6.0,Map(d -> 4, e -> 5)]]                            |
    // |10 |[[10,x,1.0,2.0,Map(a -> 1, b -> 2)], [10,y,3.0,4.0,Map(c -> 3)]]|
    // +---+----------------------------------------------------------------+
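
    To work with the grouped rows on the driver, one option is to collect the result and unpack each struct as a Row. The following is only a sketch: it assumes made-up sample data, a local session, and a grouped result small enough to collect. Names such as sample and c2ByCustomer are hypothetical.

    ```scala
    import org.apache.spark.sql.{Row, SparkSession}
    import org.apache.spark.sql.functions.{collect_list, struct}

    // Local session for experimentation only; master and app name are assumptions.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("read-row-list-sketch")
      .getOrCreate()
    import spark.implicits._

    // Hypothetical sample rows, grouped the same way as in the answer above.
    val sample = Seq((10L, "x", 1.0), (10L, "y", 3.0), (20L, "z", 5.0)).toDF("c1", "c2", "c3")
    val cols = sample.columns
    val grouped = sample
      .groupBy("c1")
      .agg(collect_list(struct(cols.head, cols.tail: _*)).as("row_list"))

    // Each array element comes back as a Row carrying the struct's field names,
    // so individual fields can be read by name.
    val c2ByCustomer: Map[Long, List[String]] = grouped.collect().map { r =>
      val rows = r.getSeq[Row](r.fieldIndex("row_list"))
      r.getAs[Long]("c1") -> rows.map(_.getAs[String]("c2")).toList
    }.toMap
    ```

    Collecting pulls every group into driver memory, so for large data it is usually safer to keep processing inside the DataFrame.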
    