Apply a function to groupBy data with pyspark

故里飘歌 2021-01-04 19:10

I'm trying to get word counts from a CSV when grouping on another column. My CSV has three columns: id, message and user_id. I read this in and then split the message and s…

2 Answers
  •  北海茫月
    2021-01-04 19:46

    A natural approach is to gather each group's words into one list and then use Python's collections.Counter to generate the word counts. We'll use a UDF for each step. First, one that flattens the nested list produced by collect_list() over multiple arrays:

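    from pyspark.sql.functions import udf
    from pyspark.sql.types import ArrayType, StringType
    unpack_udf = udf(
        lambda nested: [item for sublist in nested for item in sublist],
        ArrayType(StringType())  # without this, udf() defaults to StringType
    )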
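    The flattening lambda can be checked in plain Python before it is wrapped in a UDF. This is a minimal sketch: flatten_words is a hypothetical helper, and the nested input assumes what collect_list() might return for id 10100720363468236 in the sample data (Spark does not guarantee the order of the inner arrays). On Spark 2.4 and later, the built-in pyspark.sql.functions.flatten does the same job without a Python UDF.

```python
from typing import List


def flatten_words(nested: List[List[str]]) -> List[str]:
    """Same logic as the lambda in unpack_udf: concatenate the inner lists in order."""
    return [item for sublist in nested for item in sublist]


# Hypothetical collect_list("message") result for one id
nested = [["what", "sad", "to", "me"], ["what", "what", "does", "the"]]
flat = flatten_words(nested)
print(flat)
# ['what', 'sad', 'to', 'me', 'what', 'what', 'does', 'the']
```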
    

    Second, one that generates the word-count tuples (in our case, structs):

    from collections import Counter
    from pyspark.sql.functions import udf
    from pyspark.sql.types import ArrayType, IntegerType, StringType, StructField, StructType
    
    # We need to specify the schema of the return object
    schema_count = ArrayType(StructType([
        StructField("word", StringType(), False),
        StructField("count", IntegerType(), False)
    ]))
    
    count_udf = udf(
        lambda s: Counter(s).most_common(), 
        schema_count
    )
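    To see what count_udf hands back to Spark, the same Counter call can be run locally on an already-flattened list. A minimal sketch, with count_words as a hypothetical helper: most_common() with no argument returns every (word, count) pair sorted by count, highest first, and each tuple maps positionally onto the word and count fields of schema_count.

```python
from collections import Counter
from typing import List, Tuple


def count_words(words: List[str]) -> List[Tuple[str, int]]:
    """Same logic as the lambda in count_udf: (word, count) pairs, most frequent first."""
    return Counter(words).most_common()


words = ["what", "sad", "to", "me", "what", "what", "does", "the"]
counts = count_words(words)
print(counts)
# [('what', 3), ('sad', 1), ('to', 1), ('me', 1), ('does', 1), ('the', 1)]
```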
    

    Putting it all together:

    from pyspark.sql.functions import collect_list
    
    (df.groupBy("id")
     .agg(collect_list("message").alias("message"))
     .withColumn("message", unpack_udf("message"))
     .withColumn("message", count_udf("message"))).show(truncate=False)
    +-----------------+------------------------------------------------------+
    |id               |message                                               |
    +-----------------+------------------------------------------------------+
    |10100718890699676|[[oecd,1], [the,1], [with,1], [at,1]]                 |
    |10100720363468236|[[what,3], [me,1], [sad,1], [to,1], [does,1], [the,1]]|
    +-----------------+------------------------------------------------------+
    

    Data:

    # Assumes a running PySpark shell, where sc (the SparkContext) is predefined
    df = sc.parallelize([(10100720363468236,["what", "sad", "to", "me"]),
                         (10100720363468236,["what", "what", "does", "the"]),
                         (10100718890699676,["at", "the", "oecd", "with"])]).toDF(["id", "message"])
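    To sanity-check the table above without a cluster, the whole pipeline can be mirrored in plain Python on the same three rows. This is a local sketch, not Spark code: word_counts_by_id is a hypothetical helper, and because collect_list() gives no ordering guarantee, words that share a count may appear in a different order than in the Spark output.

```python
from collections import Counter, defaultdict
from typing import Dict, List, Tuple

# The same three rows as the DataFrame above: (id, message as a list of words)
rows = [
    (10100720363468236, ["what", "sad", "to", "me"]),
    (10100720363468236, ["what", "what", "does", "the"]),
    (10100718890699676, ["at", "the", "oecd", "with"]),
]


def word_counts_by_id(data: List[Tuple[int, List[str]]]) -> Dict[int, List[Tuple[str, int]]]:
    """Local equivalent of groupBy("id") + collect_list + unpack_udf + count_udf."""
    grouped: Dict[int, List[List[str]]] = defaultdict(list)
    for row_id, words in data:
        grouped[row_id].append(words)  # like collect_list("message") per id
    return {
        # flatten the collected lists, then count words, most frequent first
        row_id: Counter(w for words in lists for w in words).most_common()
        for row_id, lists in grouped.items()
    }


result = word_counts_by_id(rows)
for row_id, counts in result.items():
    print(row_id, counts)
```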
    
