I'm trying to get word counts from a CSV while grouping on another column. My CSV has three columns: id, message and user_id. I read this in and then split the message into a list of words.
A natural approach is to group the words into one list per id, and then use Python's Counter() to generate the word counts. For both steps we'll use udfs. First, one that flattens the nested list resulting from collect_list() of multiple arrays:
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType

unpack_udf = udf(
    lambda l: [item for sublist in l for item in sublist],
    ArrayType(StringType())  # declare the return type so the result stays an array of words
)
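To see what that flattening lambda does on its own, here is a quick plain-Python sketch (the flatten name is just for illustration, it is not part of the Spark job):
flatten = lambda l: [item for sublist in l for item in sublist]
flatten([["what", "sad", "to", "me"], ["what", "what", "does", "the"]])
# ['what', 'sad', 'to', 'me', 'what', 'what', 'does', 'the']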
Second, one that generates the word-count tuples, or in our case structs:
from pyspark.sql.types import *
from collections import Counter
# We need to specify the schema of the return object
schema_count = ArrayType(StructType([
StructField("word", StringType(), False),
StructField("count", IntegerType(), False)
]))
count_udf = udf(
lambda s: Counter(s).most_common(),
schema_count
)
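On a plain Python list, Counter(...).most_common() already yields (word, count) pairs in the shape of that schema, e.g.:
from collections import Counter
Counter(["what", "what", "does", "the"]).most_common()
# [('what', 2), ('does', 1), ('the', 1)]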
Putting it all together:
from pyspark.sql.functions import collect_list
(df.groupBy("id")
 .agg(collect_list("message").alias("message"))  # one list of word lists per id
 .withColumn("message", unpack_udf("message"))   # flatten into a single word list
 .withColumn("message", count_udf("message"))    # replace with (word, count) structs
 ).show(truncate=False)
+-----------------+------------------------------------------------------+
|id |message |
+-----------------+------------------------------------------------------+
|10100718890699676|[[oecd,1], [the,1], [with,1], [at,1]] |
|10100720363468236|[[what,3], [me,1], [sad,1], [to,1], [does,1], [the,1]]|
+-----------------+------------------------------------------------------+
Data:
df = sc.parallelize([(10100720363468236,["what", "sad", "to", "me"]),
(10100720363468236,["what", "what", "does", "the"]),
(10100718890699676,["at", "the", "oecd", "with"])]).toDF(["id", "message"])