PySpark row-wise function composition


As a simplified example, I have a dataframe "df" with columns "col1" and "col2", and I want to compute a row-wise maximum after applying a function to each column:

    def f(x):
        return (x + 1)

    max_udf = udf(lambda x, y: max(x, y), IntegerType())
    f_udf = udf(f, IntegerType())

    df2 = df.withColumn("result", max_udf(f_udf(df.col1), f_udf(df.col2)))

The call above does not work. How can I compose the UDFs so that it does?

3 Answers
  • 2020-12-12 16:03

    I had a similar problem and found the solution in the answer to this Stack Overflow question.

    To pass multiple columns or a whole row to a UDF, use a struct:

    from pyspark.sql.functions import udf, struct
    from pyspark.sql.types import IntegerType

    df = sqlContext.createDataFrame([(None, None), (1, None), (None, 2)], ("a", "b"))

    # The UDF receives the whole row as a struct and counts its null fields
    count_empty_columns = udf(lambda row: len([x for x in row if x is None]), IntegerType())

    new_df = df.withColumn("null_count", count_empty_columns(struct([df[x] for x in df.columns])))

    new_df.show()
    

    returns:

    +----+----+----------+
    |   a|   b|null_count|
    +----+----+----------+
    |null|null|         2|
    |   1|null|         1|
    |null|   2|         1|
    +----+----+----------+
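
    Applied to the original question (a row-wise maximum after adding 1 to each column), the same struct trick would look roughly like this — a minimal sketch of my own, reusing the df/col1/col2 names from the question:

    from pyspark.sql.functions import udf, struct
    from pyspark.sql.types import IntegerType

    df = sqlContext.createDataFrame([(1, 2), (3, 0)], ("col1", "col2"))

    # Apply f(x) = x + 1 to every field of the row inside a single UDF, then take the max
    row_max_udf = udf(lambda row: max(x + 1 for x in row), IntegerType())

    df2 = df.withColumn("result", row_max_udf(struct([df[c] for c in df.columns])))
    df2.show()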
    
  • 2020-12-12 16:05

    Below is a helper function made to create any new column by simply calling a top-level business rule, keeping that rule completely isolated from the technical, Spark-specific plumbing (no need to spend money on, or depend on, Databricks libraries). My advice: in your organization, try to do things simply and cleanly, for the benefit of top-level data users:

    def createColumnFromRule(df, columnName, ruleClass, ruleName, inputColumns=None, inputValues=None, columnType=None):
        from pyspark.sql import functions as F
        from pyspark.sql import types as T
        def _getSparkClassType(shortType):
            # Map short SQL type names to Spark type class names; default to StringType
            defaultSparkClassType = "StringType"
            typesMapping = {
                "bigint"    : "LongType",
                "binary"    : "BinaryType",
                "boolean"   : "BooleanType",
                "byte"      : "ByteType",
                "date"      : "DateType",
                "decimal"   : "DecimalType",
                "double"    : "DoubleType",
                "float"     : "FloatType",
                "int"       : "IntegerType",
                "integer"   : "IntegerType",
                "long"      : "LongType",
                "numeric"   : "NumericType",
                "string"    : defaultSparkClassType,
                "timestamp" : "TimestampType"
            }
            return typesMapping.get(shortType, defaultSparkClassType)
        sparkClassType = _getSparkClassType(columnType) if columnType is not None else "StringType"
        # Wrap the rule method in a UDF with the resolved return type
        aUdf = F.udf(getattr(ruleClass, ruleName), getattr(T, sparkClassType)())
        columns = None
        values = None
        if inputColumns is not None:
            columns = F.struct([df[column] for column in inputColumns])
        if inputValues is not None:
            values = F.struct([F.lit(value) for value in inputValues])
        # Call the rule with whichever of columns/values were provided
        if columns is not None and values is not None:
            df = df.withColumn(columnName, aUdf(columns, values))
        elif columns is not None:
            df = df.withColumn(columnName, aUdf(columns, F.lit(None)))
        elif values is not None:
            df = df.withColumn(columnName, aUdf(F.lit(None), values))
        # Create a null column otherwise
        else:
            if columnType is not None:
                df = df.withColumn(columnName, F.lit(None).cast(columnType))
            else:
                df = df.withColumn(columnName, F.lit(None))
        # Return the resulting dataframe
        return df
    
    

    Usage example:

    # Define your business rule (it can receive both columns and values)
    class CustomerRisk:
        def churnRisk(self, columns=None, values=None):
            isChurnRisk = False
            # ... Rule implementation starts here
            if values is not None:
                if values[0] == "FORCE_CHURN=true": isChurnRisk = True
            if not isChurnRisk and columns is not None:
                if columns["AGE"] <= 25: isChurnRisk = True
            # ...
            return isChurnRisk
    
    # Execute the rule: it creates the new column in a single line of code
    # Note how both columns and values are passed
    df = createColumnFromRule(df, columnName="CHURN_RISK", ruleClass=CustomerRisk(), ruleName="churnRisk", columnType="boolean", inputColumns=["NAME", "AGE", "ADDRESS"], inputValues=["FORCE_CHURN=true", "CHURN_RISK=100%"])
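
    A minimal, hypothetical input dataframe to try this with (the NAME/AGE/ADDRESS column names are assumptions taken from the rule above, and an existing SparkSession named spark is assumed):

    df = spark.createDataFrame(
        [("Alice", 22, "Oslo"), ("Bob", 40, "Paris")],  # hypothetical sample rows
        ("NAME", "AGE", "ADDRESS"))

    Note that because inputValues starts with "FORCE_CHURN=true", the rule returns True for every row in this call; drop inputValues to exercise the AGE-based check instead.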
    
  • 2020-12-12 16:06

    UserDefinedFunction throws an error when a UDF is passed as one of its arguments.

    You can modify max_udf as shown below to make it work.

    from pyspark.sql.functions import udf
    from pyspark.sql.types import IntegerType

    df = sc.parallelize([(1, 2), (3, 0)]).toDF(["col1", "col2"])

    # Do the "+ 1" inside the UDF itself instead of composing two UDFs
    max_udf = udf(lambda x, y: max(x + 1, y + 1), IntegerType())

    df2 = df.withColumn("result", max_udf(df.col1, df.col2))
    

    Or

    def f_udf(x):
        # A plain Python function, not a UDF: applied to a Column it returns the expression (x + 1)
        return (x + 1)

    max_udf = udf(lambda x, y: max(x, y), IntegerType())
    ## f_udf = udf(f, IntegerType())

    df2 = df.withColumn("result", max_udf(f_udf(df.col1), f_udf(df.col2)))
    

    Note:

    The second approach is valid if and only if internal functions (here f_udf) generate valid SQL expressions.

    It works here because f_udf(df.col1) and f_udf(df.col2) are evaluated as Column<b'(col1 + 1)'> and Column<b'(col2 + 1)'> respectively, before being passed to max_udf. It wouldn't work with an arbitrary function.

    It wouldn't work if we tried, for example, something like this:

    from math import exp
    
    df.withColumn("result", max_udf(exp(df.col1), exp(df.col2)))
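
    If the inner transformation is arbitrary Python (like math.exp), one option — a sketch of my own, not part of the original answer — is to do the composition inside a single UDF, so the Python function runs on plain row values rather than on Column objects:

    from math import exp
    from pyspark.sql.functions import udf
    from pyspark.sql.types import DoubleType

    # math.exp is applied inside the UDF to ordinary Python numbers, then max is taken
    max_exp_udf = udf(lambda x, y: max(exp(x), exp(y)), DoubleType())

    df2 = df.withColumn("result", max_exp_udf(df.col1, df.col2))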
    