Combine PySpark DataFrame ArrayType fields into single ArrayType field

Asked by 既然无缘 on 2020-12-05 14:34

I have a PySpark DataFrame with 2 ArrayType fields:

>>> df
DataFrame[id: string, tokens: array<string>, bigrams: array<string>]

How can I combine them into a single ArrayType field?

2 Answers
  • 2020-12-05 15:11

    In Spark 2.4.0 (2.3 on the Databricks platform) you can do this natively in the DataFrame API using the concat function, which also works on array columns. In your example:

    from pyspark.sql.functions import col, concat
    
    df.withColumn('tokens_bigrams', concat(col('tokens'), col('bigrams')))
    

    The related jira is SPARK-23736.
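
    Not part of the original answer, but a minimal end-to-end sketch under the same assumptions (string arrays, an active SparkSession named spark; the sample row is made up for illustration):

    from pyspark.sql.functions import col, concat

    # Hypothetical sample data matching the question's schema
    df = spark.createDataFrame(
        [("1", ["one", "two"], ["one two"])],
        ("id", "tokens", "bigrams")
    )

    # concat appends the bigrams array to the tokens array, row by row
    df.withColumn("tokens_bigrams", concat(col("tokens"), col("bigrams"))).show(truncate=False)

    # +---+----------+---------+-------------------+
    # |id |tokens    |bigrams  |tokens_bigrams     |
    # +---+----------+---------+-------------------+
    # |1  |[one, two]|[one two]|[one, two, one two]|
    # +---+----------+---------+-------------------+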

  • 2020-12-05 15:22

    Spark >= 2.4

    You can use the concat function (SPARK-23736), which accepts array columns:

    from pyspark.sql.functions import col, concat 
    
    df.select(concat(col("tokens"), col("tokens_bigrams"))).show(truncate=False)
    
    # +---------------------------------+                                             
    # |concat(tokens, tokens_bigrams)   |
    # +---------------------------------+
    # |[one, two, two, one two, two two]|
    # |null                             |
    # +---------------------------------+
    

    concat returns NULL as soon as any of its arguments is NULL. To keep the data when one of the values is NULL, you can coalesce each column with an empty array:

    from pyspark.sql.functions import array, coalesce      
    
    df.select(concat(
        coalesce(col("tokens"), array()),
        coalesce(col("tokens_bigrams"), array())
    )).show(truncate=False)
    
    # +--------------------------------------------------------------------+
    # |concat(coalesce(tokens, array()), coalesce(tokens_bigrams, array()))|
    # +--------------------------------------------------------------------+
    # |[one, two, two, one two, two two]                                   |
    # |[three]                                                             |
    # +--------------------------------------------------------------------+
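
    For reference, the same NULL-safe variant can be written as a single SQL expression with expr; this is just a sketch of an equivalent form, not part of the original answer:

    from pyspark.sql.functions import expr

    # Same concat/coalesce logic, expressed as one SQL string
    df.select(
        expr("concat(coalesce(tokens, array()), coalesce(tokens_bigrams, array()))")
    ).show(truncate=False)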
    

    Spark < 2.4

    Unfortunately, to concatenate array columns in the general case you'll need a UDF, for example like this:

    from itertools import chain
    from pyspark.sql.functions import udf
    from pyspark.sql.types import ArrayType, StringType
    
    
    def concat(type):
        # Build a UDF that chains its array arguments into a single list,
        # substituting an empty list for NULL (None) inputs.
        # Note: the name shadows pyspark.sql.functions.concat if you imported it.
        def concat_(*args):
            return list(chain.from_iterable((arg if arg else [] for arg in args)))
        return udf(concat_, ArrayType(type))
    

    which can be used as:

    df = spark.createDataFrame(
        [(["one", "two", "two"], ["one two", "two two"]), (["three"], None)], 
        ("tokens", "tokens_bigrams")
    )
    
    concat_string_arrays = concat(StringType())
    df.select(concat_string_arrays("tokens", "tokens_bigrams")).show(truncate=False)
    
    # +---------------------------------+
    # |concat_(tokens, tokens_bigrams)  |
    # +---------------------------------+
    # |[one, two, two, one two, two two]|
    # |[three]                          |
    # +---------------------------------+
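
    If you also want to call the helper from Spark SQL, a minimal sketch (assuming an active SparkSession named spark; the view name df_view is hypothetical):

    from itertools import chain
    from pyspark.sql.types import ArrayType, StringType

    # Register a plain Python function with an explicit return type
    spark.udf.register(
        "concat_string_arrays",
        lambda *args: list(chain.from_iterable(arg if arg else [] for arg in args)),
        ArrayType(StringType())
    )

    df.createOrReplaceTempView("df_view")  # hypothetical view name
    spark.sql(
        "SELECT concat_string_arrays(tokens, tokens_bigrams) FROM df_view"
    ).show(truncate=False)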
    