PySpark - Adding a Column from a list of values using a UDF

臣服心动 2021-01-05 00:32

I have to add column to a PySpark dataframe based on a list of values.

a= spark.createDataFrame([("Dog", "Cat"), ("Cat", "Dog"), ("Mouse", "Cat"


        
5 answers
  • 2021-01-05 01:10

    You can convert your rating list into an RDD:

    rating = [5,4,1]
    ratingrdd = sc.parallelize(rating)
    

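    Then convert your DataFrame to an RDD, pair each row with the corresponding value of ratingrdd using zip, and convert the zipped RDD back to a DataFrame. Note that RDD.zip assumes both RDDs have the same number of partitions and the same number of elements in each partition; otherwise it fails.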

    sqlContext.createDataFrame(a.rdd.zip(ratingrdd).map(lambda x: (x[0][0], x[0][1], x[1])), ["Animal", "Enemy", "Rating"]).show()
    

    It should give you:

    +------+-----+------+
    |Animal|Enemy|Rating|
    +------+-----+------+
    |   Dog|  Cat|     5|
    |   Cat|  Dog|     4|
    | Mouse|  Cat|     1|
    +------+-----+------+
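RDD.zip pairs elements by position, partition by partition. Its PySpark documentation says it assumes both RDDs have the same number of partitions and the same number of elements in each partition. The plain-Python sketch below models that rule. zip_partitions is a hypothetical helper, and partitions are represented as lists of lists rather than real Spark objects.

```python
def zip_partitions(left, right):
    """Model of RDD.zip: pair elements partition by partition.

    Hypothetical helper: `left` and `right` are lists of partitions (lists).
    It raises ValueError where Spark would fail because the layouts differ.
    """
    if len(left) != len(right):
        raise ValueError("different number of partitions")
    pairs = []
    for left_part, right_part in zip(left, right):
        if len(left_part) != len(right_part):
            raise ValueError("different number of elements in a partition")
        pairs.extend(zip(left_part, right_part))
    return pairs


animals = [[("Dog", "Cat")], [("Cat", "Dog"), ("Mouse", "Cat")]]

# Same layout on both sides: zipping works.
ok = zip_partitions(animals, [[5], [4, 1]])
rows = [(animal, enemy, rating) for (animal, enemy), rating in ok]
print(rows)  # [('Dog', 'Cat', 5), ('Cat', 'Dog', 4), ('Mouse', 'Cat', 1)]

# Same total number of values, split differently: zipping fails.
try:
    zip_partitions(animals, [[5, 4], [1]])
    error = None
except ValueError as err:
    error = str(err)
print("zip failed:", error)
```

In practice, sc.parallelize(rating) picks its own number of slices (sc.defaultParallelism unless you pass numSlices). Nothing guarantees that this matches how a.rdd is partitioned.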
    
  • 2021-01-05 01:16

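    What you are trying to do does not work. The rating list lives in your driver's memory, whereas the rows of the a DataFrame are spread across the executors (and the UDF runs on the executors too). There is therefore no row position the UDF could use to pick the matching value from the list.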

    What you need to do is add the keys to the ratings list, like so:

    ratings = [('Dog', 5), ('Cat', 4), ('Mouse', 1)]
    

    Then create a ratings DataFrame from the list and join the two DataFrames on Animal to add the new column:

    ratings_df = spark.createDataFrame(ratings, ['Animal', 'Rating'])
    new_df = a.join(ratings_df, 'Animal')
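This joins on Animal rather than on row position, so the rating follows the key. Every row with the same animal gets the same rating. With the default inner join, a row whose animal has no entry in ratings is dropped.

The plain-Python sketch below models that behaviour with a dict. The extra ("Dog", "Mouse") and ("Snake", "Mouse") rows are made-up inputs for illustration. The model also keeps the input row order, which a real Spark join does not guarantee.

```python
ratings = [("Dog", 5), ("Cat", 4), ("Mouse", 1)]
rows = [
    ("Dog", "Cat"), ("Cat", "Dog"), ("Mouse", "Cat"),
    ("Dog", "Mouse"),    # hypothetical extra row: a second Dog
    ("Snake", "Mouse"),  # hypothetical extra row: no rating for Snake
]

rating_by_animal = dict(ratings)

# Default inner join on Animal: rows without a matching key are dropped.
inner = [(animal, enemy, rating_by_animal[animal])
         for animal, enemy in rows if animal in rating_by_animal]

# how="left": unmatched rows are kept with a null (None) rating.
left = [(animal, enemy, rating_by_animal.get(animal)) for animal, enemy in rows]

print(inner)  # [('Dog', 'Cat', 5), ('Cat', 'Dog', 4), ('Mouse', 'Cat', 1), ('Dog', 'Mouse', 5)]
print(left[-1])  # ('Snake', 'Mouse', None)
```

If unmatched rows must survive, pass how="left" to DataFrame.join.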
    
  • 2021-01-05 01:17

    Hope this helps!

    from pyspark.sql.functions import monotonically_increasing_id, row_number
    from pyspark.sql import Window
    
    #sample data
    a= sqlContext.createDataFrame([("Dog", "Cat"), ("Cat", "Dog"), ("Mouse", "Cat")],
                                   ["Animal", "Enemy"])
    a.show()
    
    #convert list to a dataframe
    rating = [5,4,1]
    b = sqlContext.createDataFrame([(l,) for l in rating], ['Rating'])
    
    #add 'sequential' index and join both dataframe to get the final result
    a = a.withColumn("row_idx", row_number().over(Window.orderBy(monotonically_increasing_id())))
    b = b.withColumn("row_idx", row_number().over(Window.orderBy(monotonically_increasing_id())))
    
    final_df = a.join(b, "row_idx").drop("row_idx")
    final_df.show()
    

    Input:

    +------+-----+
    |Animal|Enemy|
    +------+-----+
    |   Dog|  Cat|
    |   Cat|  Dog|
    | Mouse|  Cat|
    +------+-----+
    

    Output is:

    +------+-----+------+
    |Animal|Enemy|Rating|
    +------+-----+------+
    |   Cat|  Dog|     4|
    |   Dog|  Cat|     5|
    | Mouse|  Cat|     1|
    +------+-----+------+
    
  • 2021-01-05 01:17

    As mentioned by @Tw UxTLi51Nus, if you can order the DataFrame, let's say, by Animal, without this changing your results, you can then do the following:

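    from pyspark.sql.functions import udf
    from pyspark.sql.types import IntegerType
    
    rating = [5, 4, 1]
    
    def add_labels(indx):
        return rating[indx-1]  # since row_number() begins from 1
    
    labels_udf = udf(add_labels, IntegerType())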
    
    a = spark.createDataFrame([("Dog", "Cat"), ("Cat", "Dog"), ("Mouse", "Cat")],["Animal", "Enemy"])
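    a.createOrReplaceTempView('a')
    # Caution: by default Spark SQL reads "Animal" in double quotes as a string
    # literal, so this orders by a constant and the numbering just follows the
    # current row order (which is why Dog comes first below). To really sort by
    # the column, write: order by Animal
    a = spark.sql('select row_number() over (order by "Animal") as num, * from a')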
    
    a.show()
    
    
    +---+------+-----+
    |num|Animal|Enemy|
    +---+------+-----+
    |  1|   Dog|  Cat|
    |  2|   Cat|  Dog|
    |  3| Mouse|  Cat|
    +---+------+-----+
    
    new_df = a.withColumn('Rating', labels_udf('num'))
    new_df.show()
    +---+------+-----+------+
    |num|Animal|Enemy|Rating|
    +---+------+-----+------+
    |  1|   Dog|  Cat|     5|
    |  2|   Cat|  Dog|     4|
    |  3| Mouse|  Cat|     1|
    +---+------+-----+------+
    

    And then drop the num column:

    new_df.drop('num').show()
    +------+-----+------+
    |Animal|Enemy|Rating|
    +------+-----+------+
    |   Dog|  Cat|     5|
    |   Cat|  Dog|     4|
    | Mouse|  Cat|     1|
    +------+-----+------+
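row_number() starts at 1 while Python lists start at 0, hence the indx-1. If the DataFrame has more rows than rating has values, rating[indx-1] raises IndexError inside the UDF and the Spark task fails.

A defensive variant is sketched below as plain Python. It returns None, which Spark stores as null, for numbers outside the list. add_labels_safe is a hypothetical name, and it would be wrapped exactly like the original: udf(add_labels_safe, IntegerType()).

```python
rating = [5, 4, 1]


def add_labels_safe(indx):
    """Map a 1-based row number to rating[indx - 1], or None if out of range.

    Also rejects indx < 1, because rating[-1] would silently return the
    last element instead of failing.
    """
    if indx is None or indx < 1 or indx > len(rating):
        return None
    return rating[indx - 1]


print([add_labels_safe(n) for n in [1, 2, 3, 4]])  # [5, 4, 1, None]
```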
    

    Edit:

    Another way, if you cannot sort by a column, is to go back to the RDD API and do the following (perhaps ugly and a bit inefficient):

    a = spark.createDataFrame([("Dog", "Cat"), ("Cat", "Dog"), ("Mouse", "Cat")],["Animal", "Enemy"])
    
    # or create the rdd from the start:
    # a = spark.sparkContext.parallelize([("Dog", "Cat"), ("Cat", "Dog"), ("Mouse", "Cat")])
    
    a = a.rdd.zipWithIndex()
    a = a.toDF()
    a.show()
    
    +-----------+---+
    |         _1| _2|
    +-----------+---+
    |  [Dog,Cat]|  0|
    |  [Cat,Dog]|  1|
    |[Mouse,Cat]|  2|
    +-----------+---+
    
    a = a.select(a._1.getField('Animal').alias('Animal'),
                 a._1.getField('Enemy').alias('Enemy'),
                 a._2.alias('num'))
    
    def add_labels(indx):
        return rating[indx] # indx here will start from zero
    
    labels_udf = udf(add_labels, IntegerType())
    
    new_df = a.withColumn('Rating', labels_udf('num'))
    
    new_df.show()
    
    +------+-----+---+------+
    |Animal|Enemy|num|Rating|
    +------+-----+---+------+
    |   Dog|  Cat|  0|     5|
    |   Cat|  Dog|  1|     4|
    | Mouse|  Cat|  2|     1|
    +------+-----+---+------+
    

    (I would not recommend this if you have a lot of data.)
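Unlike monotonically_increasing_id(), zipWithIndex() produces consecutive indices 0, 1, 2, and so on. It orders by partition index and then by position within each partition. To do that it must know how many elements precede each partition, which is why the PySpark documentation notes that it triggers a Spark job when the RDD has more than one partition.

The plain-Python sketch below models that offset calculation. zip_with_index is a hypothetical helper operating on lists of lists, not Spark's implementation.

```python
from itertools import accumulate

rating = [5, 4, 1]


def zip_with_index(partitions):
    """Model of RDD.zipWithIndex(): consecutive 0-based indices across partitions.

    Hypothetical helper; `partitions` is a list of lists standing in for an RDD.
    """
    sizes = [len(part) for part in partitions]    # in Spark: the extra counting job
    starts = [0] + list(accumulate(sizes))[:-1]    # index of each partition's first item
    return [(item, start + pos)
            for part, start in zip(partitions, starts)
            for pos, item in enumerate(part)]


# Three rows spread over three partitions, one of them empty.
parts = [[("Dog", "Cat")], [], [("Cat", "Dog"), ("Mouse", "Cat")]]
indexed = zip_with_index(parts)
print(indexed)  # [(('Dog', 'Cat'), 0), (('Cat', 'Dog'), 1), (('Mouse', 'Cat'), 2)]

# The 0-based index can be used directly as a list position, as add_labels does.
labelled = [(animal, enemy, num, rating[num]) for (animal, enemy), num in indexed]
print(labelled)  # [('Dog', 'Cat', 0, 5), ('Cat', 'Dog', 1, 4), ('Mouse', 'Cat', 2, 1)]
```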

    Hope this helps, good luck!

  • 2021-01-05 01:30

    I might be wrong, but I believe the accepted answer will not work. monotonically_increasing_id() only guarantees that the IDs are unique and monotonically increasing, not that they are consecutive; the values depend on how each DataFrame is partitioned. Using it on two different DataFrames will therefore likely produce two very different ID columns, and a join on them will mostly return no rows.
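The sketch below models why. The PySpark documentation for monotonically_increasing_id() says the current implementation puts the partition ID in the upper 31 bits and the record number within each partition in the lower 33 bits. This is a plain-Python model: mono_ids and row_numbers are hypothetical helpers, and the partition layouts are made up.

In the model, joining on the raw IDs matches a single row. Renumbering each side with row_number() ordered by its IDs, as the code below does, lines all three rows up again. That renumbering only pairs the right rows while both DataFrames still hold their rows in the original order.

```python
def mono_ids(partitions):
    """Model of monotonically_increasing_id(): (partition ID << 33) + record number."""
    return [(pid << 33) + rec
            for pid, part in enumerate(partitions)
            for rec in range(len(part))]


def row_numbers(ids):
    """Model of row_number() over Window.orderBy(id): 1-based rank of each id."""
    ranks = {value: rank for rank, value in enumerate(sorted(ids), start=1)}
    return [ranks[value] for value in ids]


# Same logical rows, partitioned differently on each side (made-up layouts).
a_parts = [[("Dog", "Cat")], [("Cat", "Dog"), ("Mouse", "Cat")]]
b_parts = [[5, 4, 1]]
a_rows = [row for part in a_parts for row in part]
b_rows = [value for part in b_parts for value in part]

a_ids = mono_ids(a_parts)  # [0, 8589934592, 8589934593]
b_ids = mono_ids(b_parts)  # [0, 1, 2]

# Joining on the raw IDs: only id 0 exists on both sides.
raw_matches = sorted(set(a_ids) & set(b_ids))
print(raw_matches)  # [0]

# Renumber both sides 1..n by their IDs, then join on that number.
a_by_idx = dict(zip(row_numbers(a_ids), a_rows))
b_by_idx = dict(zip(row_numbers(b_ids), b_rows))
final = [a_by_idx[i] + (b_by_idx[i],) for i in sorted(a_by_idx) if i in b_by_idx]
print(final)  # [('Dog', 'Cat', 5), ('Cat', 'Dog', 4), ('Mouse', 'Cat', 1)]
```

Also note that a window with orderBy but no partitionBy moves all rows into a single partition. This approach is therefore only practical for modest amounts of data.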

    Taking inspiration from this answer to a similar question, https://stackoverflow.com/a/48211877/7225303, we could change the incorrect answer to:

    from pyspark.sql.window import Window as W
    from pyspark.sql import functions as F
    
    a= sqlContext.createDataFrame([("Dog", "Cat"), ("Cat", "Dog"), ("Mouse", "Cat")],
                                   ["Animal", "Enemy"])
    
    a.show()
    
    +------+-----+
    |Animal|Enemy|
    +------+-----+
    |   Dog|  Cat|
    |   Cat|  Dog|
    | Mouse|  Cat|
    +------+-----+
    
    
    
    #convert list to a dataframe
    rating = [5,4,1]
    b = sqlContext.createDataFrame([(l,) for l in rating], ['Rating'])
    b.show()
    
    +------+
    |Rating|
    +------+
    |     5|
    |     4|
    |     1|
    +------+
    
    
    a = a.withColumn("idx", F.monotonically_increasing_id())
    b = b.withColumn("idx", F.monotonically_increasing_id())
    
    windowSpec = W.orderBy("idx")
    a = a.withColumn("idx", F.row_number().over(windowSpec))
    b = b.withColumn("idx", F.row_number().over(windowSpec))
    
    a.show()
    +------+-----+---+
    |Animal|Enemy|idx|
    +------+-----+---+
    |   Dog|  Cat|  1|
    |   Cat|  Dog|  2|
    | Mouse|  Cat|  3|
    +------+-----+---+
    
    b.show()
    +------+---+
    |Rating|idx|
    +------+---+
    |     5|  1|
    |     4|  2|
    |     1|  3|
    +------+---+
    
    # join does not guarantee row order, so sort by idx before dropping it
    final_df = a.join(b, "idx").orderBy("idx").drop("idx")
    final_df.show()
    
    +------+-----+------+
    |Animal|Enemy|Rating|
    +------+-----+------+
    |   Dog|  Cat|     5|
    |   Cat|  Dog|     4|
    | Mouse|  Cat|     1|
    +------+-----+------+
    