How to return a “Tuple type” in a UDF in PySpark?

轻奢々 2020-12-24 04:34

All the data types in pyspark.sql.types are:

__all__ = [
    "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
    ...]

There is no TupleType among them. How do I specify the return type of a UDF that returns an array of tuples?
3 Answers
  • 2020-12-24 04:53

    There is no such thing as a TupleType in Spark. Product types are represented as structs with fields of specific types. For example, if you want to return an array of (integer, string) pairs, you can use a schema like this:

    from pyspark.sql.types import *
    
    schema = ArrayType(StructType([
        StructField("char", StringType(), False),
        StructField("count", IntegerType(), False)
    ]))
    

    Example usage:

    from pyspark.sql.functions import udf
    from collections import Counter
    
    char_count_udf = udf(
        lambda s: Counter(s).most_common(),
        schema
    )
    
    df = sc.parallelize([(1, "foo"), (2, "bar")]).toDF(["id", "value"])
    
    df.select("*", char_count_udf(df["value"])).show(2, False)
    
    ## +---+-----+-------------------------+
    ## |id |value|PythonUDF#<lambda>(value)|
    ## +---+-----+-------------------------+
    ## |1  |foo  |[[o,2], [f,1]]           |
    ## |2  |bar  |[[r,1], [a,1], [b,1]]    |
    ## +---+-----+-------------------------+
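
    Since each element of the returned array is a struct, you can unpack the result afterwards with the built-in explode and dot notation on the struct fields. A minimal sketch building on char_count_udf above (pairs is just an illustrative name):

    from pyspark.sql.functions import explode
    
    # Each array element is a struct, so explode it into one row per pair
    # and pull the fields out with dot notation.
    pairs = df.select("id", explode(char_count_udf(df["value"])).alias("pair"))
    pairs.select("id", "pair.char", "pair.count").show()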
    
  • 2020-12-24 04:53

    Stack Overflow keeps directing me to this question, so I guess I'll add some info here.

    Returning simple types from a UDF:

    from pyspark.sql.types import *
    from pyspark.sql import functions as F
    from pyspark.sql.functions import udf
    
    def get_df():
      d = [(0.0, 0.0), (0.0, 3.0), (1.0, 6.0), (1.0, 9.0)]
      df = sqlContext.createDataFrame(d, ['x', 'y'])
      return df
    
    df = get_df()
    df.show()
    
    # +---+---+
    # |  x|  y|
    # +---+---+
    # |0.0|0.0|
    # |0.0|3.0|
    # |1.0|6.0|
    # |1.0|9.0|
    # +---+---+
    
    func = udf(lambda x: str(x), StringType())
    df = df.withColumn('y_str', func('y'))
    
    func = udf(lambda x: int(x), IntegerType())
    df = df.withColumn('y_int', func('y'))
    
    df.show()
    
    # +---+---+-----+-----+
    # |  x|  y|y_str|y_int|
    # +---+---+-----+-----+
    # |0.0|0.0|  0.0|    0|
    # |0.0|3.0|  3.0|    3|
    # |1.0|6.0|  6.0|    6|
    # |1.0|9.0|  9.0|    9|
    # +---+---+-----+-----+
    
    df.printSchema()
    
    # root
    #  |-- x: double (nullable = true)
    #  |-- y: double (nullable = true)
    #  |-- y_str: string (nullable = true)
    #  |-- y_int: integer (nullable = true)
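
    One caveat: the declared return type is not validated when the UDF is defined. If the Python value cannot be converted to it, Spark silently produces null rather than raising an error. A minimal sketch (bad and y_bad are hypothetical names):

    # Declaring IntegerType while actually returning a Python float:
    # the value cannot be converted, so the column silently becomes null.
    bad = udf(lambda x: x, IntegerType())
    df.withColumn('y_bad', bad('y')).show()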
    

    When integers are not enough:

    df = get_df()
    
    func = udf(lambda x: [0]*int(x), ArrayType(IntegerType()))
    df = df.withColumn('list', func('y'))
    
    func = udf(lambda x: {float(y): str(y) for y in range(int(x))}, 
               MapType(FloatType(), StringType()))
    df = df.withColumn('map', func('y'))
    
    df.show()
    # +---+---+--------------------+--------------------+
    # |  x|  y|                list|                 map|
    # +---+---+--------------------+--------------------+
    # |0.0|0.0|                  []|               Map()|
    # |0.0|3.0|           [0, 0, 0]|Map(2.0 -> 2, 0.0...|
    # |1.0|6.0|  [0, 0, 0, 0, 0, 0]|Map(0.0 -> 0, 5.0...|
    # |1.0|9.0|[0, 0, 0, 0, 0, 0...|Map(0.0 -> 0, 5.0...|
    # +---+---+--------------------+--------------------+
    
    df.printSchema()
    # root
    #  |-- x: double (nullable = true)
    #  |-- y: double (nullable = true)
    #  |-- list: array (nullable = true)
    #  |    |-- element: integer (containsNull = true)
    #  |-- map: map (nullable = true)
    #  |    |-- key: float
    #  |    |-- value: string (valueContainsNull = true)
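
    Arrays and maps produced this way can be flattened back into rows with the built-in F.explode, which works on both column types. A short sketch continuing from the df above (the aliases are just illustrative names):

    # explode produces one row per array element; on a map it produces
    # one row per entry, with separate key and value columns.
    df.select('x', F.explode('list').alias('item')).show()
    df.select('x', F.explode('map').alias('key', 'value')).show()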
    

    Returning complex data types from a UDF:

    df = get_df()
    df = df.groupBy('x').agg(F.collect_list('y').alias('y[]'))
    df.show()
    
    # +---+----------+
    # |  x|       y[]|
    # +---+----------+
    # |0.0|[0.0, 3.0]|
    # |1.0|[9.0, 6.0]|
    # +---+----------+
    
    schema = StructType([
        StructField("min", FloatType(), True),
        StructField("size", IntegerType(), True),
        StructField("edges",  ArrayType(FloatType()), True),
        StructField("val_to_index",  MapType(FloatType(), IntegerType()), True)
        # StructField('insanity', StructType([StructField("min_", FloatType(), True), StructField("size_", IntegerType(), True)]))
    ])
    
    def func(values):
      mn = min(values)
      size = len(values)
      lst = sorted(values)[::-1]
      val_to_index = {x: i for i, x in enumerate(values)}
      return (mn, size, lst, val_to_index)
    
    func = udf(func, schema)
    dff = df.select('*', func('y[]').alias('complex_type'))
    dff.show(10, False)
    
    # +---+----------+------------------------------------------------------+
    # |x  |y[]       |complex_type                                          |
    # +---+----------+------------------------------------------------------+
    # |0.0|[0.0, 3.0]|[0.0,2,WrappedArray(3.0, 0.0),Map(0.0 -> 0, 3.0 -> 1)]|
    # |1.0|[6.0, 9.0]|[6.0,2,WrappedArray(9.0, 6.0),Map(9.0 -> 1, 6.0 -> 0)]|
    # +---+----------+------------------------------------------------------+
    
    dff.printSchema()
    
    # root
    #  |-- x: double (nullable = true)
    #  |-- y[]: array (nullable = true)
    #  |    |-- element: double (containsNull = true)
    #  |-- complex_type: struct (nullable = true)
    #  |    |-- min: float (nullable = true)
    #  |    |-- size: integer (nullable = true)
    #  |    |-- edges: array (nullable = true)
    #  |    |    |-- element: float (containsNull = true)
    #  |    |-- val_to_index: map (nullable = true)
    #  |    |    |-- key: float
    #  |    |    |-- value: integer (valueContainsNull = true)
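
    Fields of the returned struct behave like any other nested column and can be selected with dot notation. A minimal sketch, continuing from dff above:

    # Pick individual struct fields out of complex_type.
    dff.select('x', 'complex_type.min', 'complex_type.size').show()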
    

    Passing multiple arguments to a UDF:

    df = get_df()
    func = udf(lambda arr: arr[0] * arr[1], FloatType())
    df = df.withColumn('x*y', func(F.array('x', 'y')))
    df.show()
    
    # +---+---+---+
    # |  x|  y|x*y|
    # +---+---+---+
    # |0.0|0.0|0.0|
    # |0.0|3.0|0.0|
    # |1.0|6.0|6.0|
    # |1.0|9.0|9.0|
    # +---+---+---+
    

    The code is purely for demo purposes; all of the above transformations are available as built-in Spark functions and would yield much better performance. As @zero323 notes in the comments, UDFs should generally be avoided in PySpark; if you find yourself returning complex types, consider simplifying your logic first.
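
    For instance, most of the complex_type struct above can be computed with built-in aggregates instead of a UDF. A sketch, not a drop-in replacement (the val_to_index map has no one-line built-in equivalent here):

    # Built-ins avoid serializing every row through Python.
    df = get_df()
    df = df.withColumn('x*y', df['x'] * df['y'])   # instead of the array UDF
    df.groupBy('x').agg(
        F.min('y').alias('min'),
        F.count('y').alias('size'),
        F.sort_array(F.collect_list('y'), asc=False).alias('edges')
    ).show()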

  • 2020-12-24 05:09

    The Scala version, instead of Python (Spark 2.4):

    import org.apache.spark.sql.types._
    
    val testschema : StructType = StructType(
        StructField("number", IntegerType) ::
        StructField("Array",  ArrayType(StructType(StructField("cnt_rnk", IntegerType) :: StructField("comp", StringType) :: Nil))) :: 
        StructField("comp", StringType):: Nil)
    

    The tree structure looks like this:

    testschema.printTreeString
    root
     |-- number: integer (nullable = true)
     |-- Array: array (nullable = true)
     |    |-- element: struct (containsNull = true)
     |    |    |-- cnt_rnk: integer (nullable = true)
     |    |    |-- corp_id: string (nullable = true)
     |-- comp: string (nullable = true)
    