How to write a PySpark UDAF on multiple columns?

余生分开走 2021-01-06 17:20

I have the following data in a PySpark DataFrame called end_stats_df:

values     start    end    cat1   cat2
10          1        2      A               


        
1 Answer

    有刺的猬 2021-01-06 17:59

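    PySpark does not support user-defined aggregate functions (UDAFs) written in plain Python, so we have to do the aggregation manually: collect each group's values into arrays with collect_list and collect_set, then apply an ordinary row-level UDF to those arrays.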

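    from pyspark.sql import SparkSession
    from pyspark.sql import functions as f
    from pyspark.sql.types import IntegerType

    spark = SparkSession.builder.getOrCreate()

    def func(values, cat1, cat2):
        # n = number of distinct categories across cat1 and cat2 combined
        n = len(set(cat1 + cat2))
        # (n - 2)-th smallest value; negative indices count from the end
        return sorted(values)[n - 2]


    df = spark.read.load('file:///home/zht/PycharmProjects/test/text_file.txt', format='csv', sep='\t', header=True)
    # CSV columns are read as strings; cast so that sorting is numeric, not lexicographic
    df = df.withColumn('values', df['values'].cast('int'))
    # one row per (start, end): every value, plus the distinct cat1 and cat2 labels
    df = df.groupBy(df['start'], df['end']).agg(f.collect_list(df['values']).alias('values'),
                                                f.collect_set(df['cat1']).alias('cat1'),
                                                f.collect_set(df['cat2']).alias('cat2'))
    # apply an ordinary row-level UDF to the collected arrays
    func_udf = f.udf(func, IntegerType())
    df = df.select(df['start'], df['end'], func_udf(df['values'], df['cat1'], df['cat2']).alias('result'))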
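To check what the collect-then-UDF pipeline computes without starting Spark, here is a minimal plain-Python sketch of the same steps: group rows by (start, end), gather all values into a list (like collect_list), gather the distinct cat1/cat2 labels into sets (like collect_set), then call func on each group. This is an illustration of the semantics under stated assumptions, not how Spark executes it. The helper name aggregate_groups and the sample rows are hypothetical, since the question's data is truncated.

```python
from collections import defaultdict


def func(values, cat1, cat2):
    # n = number of distinct categories across cat1 and cat2 combined
    n = len(set(cat1 + cat2))
    # (n - 2)-th smallest value; negative indices count from the end
    return sorted(values)[n - 2]


def aggregate_groups(rows):
    """Hypothetical helper: mimics groupBy(start, end) + collect_list/collect_set + UDF."""
    values = defaultdict(list)  # collect_list: keeps duplicates
    cat1 = defaultdict(set)     # collect_set: distinct labels only
    cat2 = defaultdict(set)
    for v, start, end, c1, c2 in rows:
        key = (start, end)
        group_values = values[key]  # touching the key creates the group
        # both collect_list and collect_set skip nulls
        if v is not None:
            group_values.append(v)
        if c1 is not None:
            cat1[key].add(c1)
        if c2 is not None:
            cat2[key].add(c2)
    # collect_set yields arrays, so func receives lists (their order does not matter here)
    return {key: func(vals, list(cat1[key]), list(cat2[key])) for key, vals in values.items()}


# Hypothetical sample rows: (values, start, end, cat1, cat2)
rows = [
    (10, 1, 2, 'A', None),
    (30, 1, 2, 'B', 'A'),
    (20, 1, 2, 'C', None),
    (5, 3, 4, 'A', 'B'),
    (7, 3, 4, 'A', None),
]
print(aggregate_groups(rows))  # {(1, 2): 20, (3, 4): 5}
```

The index n - 2 comes straight from the answer's func: with a single distinct category it wraps to -1 (the largest value), and if n - 2 runs past the end of the list Python raises IndexError, which inside a Spark UDF would fail the task.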
    
