get first N elements from dataframe ArrayType column in pyspark

2020-12-09 20:26

I have a Spark DataFrame with rows like:

1   |   [a, b, c]
2   |   [d, e, f]
3   |   [g, h, i]

Now I want to keep only the first 2 elements of the array column.

2 Answers
  • 2020-12-09 20:40

    Here's how to do it with the API functions.

    Suppose your DataFrame were the following:

    df.show()
    #+---+---------+
    #| id|  letters|
    #+---+---------+
    #|  1|[a, b, c]|
    #|  2|[d, e, f]|
    #|  3|[g, h, i]|
    #+---+---------+
    
    df.printSchema()
    #root
    # |-- id: long (nullable = true)
    # |-- letters: array (nullable = true)
    # |    |-- element: string (containsNull = true)
    

    You can use square brackets to access elements in the letters column by index, and wrap that in a call to pyspark.sql.functions.array() to create a new ArrayType column.

    import pyspark.sql.functions as f
    
    df.withColumn("first_two", f.array([f.col("letters")[0], f.col("letters")[1]])).show()
    #+---+---------+---------+
    #| id|  letters|first_two|
    #+---+---------+---------+
    #|  1|[a, b, c]|   [a, b]|
    #|  2|[d, e, f]|   [d, e]|
    #|  3|[g, h, i]|   [g, h]|
    #+---+---------+---------+
    

    Or, if you have too many indices to list out, you can use a list comprehension:

    df.withColumn("first_two", f.array([f.col("letters")[i] for i in range(2)])).show()
    #+---+---------+---------+
    #| id|  letters|first_two|
    #+---+---------+---------+
    #|  1|[a, b, c]|   [a, b]|
    #|  2|[d, e, f]|   [d, e]|
    #|  3|[g, h, i]|   [g, h]|
    #+---+---------+---------+
    

    For pyspark versions 2.4+, you can also use pyspark.sql.functions.slice():

    df.withColumn("first_two",f.slice("letters",start=1,length=2)).show()
    #+---+---------+---------+
    #| id|  letters|first_two|
    #+---+---------+---------+
    #|  1|[a, b, c]|   [a, b]|
    #|  2|[d, e, f]|   [d, e]|
    #|  3|[g, h, i]|   [g, h]|
    #+---+---------+---------+
    

    slice may have better performance for large arrays (note that its start index is 1, not 0).
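
    Since the title asks for the first N elements in general, the list-comprehension version can also be wrapped in a small helper. This is just an illustrative sketch (the first_n name is made up here, and with bracket indexing any positions past the end of a shorter array should come back as null in the default configuration, whereas slice simply returns the elements that exist):

    import pyspark.sql.functions as f

    # Illustrative helper (not part of the answer above): build a "first n
    # elements" column from an ArrayType column using the same bracket-indexing
    # plus f.array() idea.
    def first_n(col_name, n):
        return f.array([f.col(col_name)[i] for i in range(n)])

    df.withColumn("first_two", first_n("letters", 2)).show()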

  • 2020-12-09 20:44

    Either my pyspark skills have gone rusty (I confess I don't hone them much these days), or this is a tough nut indeed... The only way I managed to do it is with SQL statements:

    spark.version
    #  u'2.3.1'
    
    # dummy data:
    
    from pyspark.sql import Row
    x = [Row(col1="xx", col2="yy", col3="zz", col4=[123,234, 456])]
    rdd = sc.parallelize(x)
    df = spark.createDataFrame(rdd)
    df.show()
    # result:
    +----+----+----+---------------+
    |col1|col2|col3|           col4|
    +----+----+----+---------------+
    |  xx|  yy|  zz|[123, 234, 456]|
    +----+----+----+---------------+
    
    df.createOrReplaceTempView("df")
    df2 = spark.sql("SELECT col1, col2, col3, (col4[0], col4[1]) as col5 FROM df")
    df2.show()
    # result:
    +----+----+----+----------+ 
    |col1|col2|col3|      col5|
    +----+----+----+----------+ 
    |  xx|  yy|  zz|[123, 234]|
    +----+----+----+----------+
    

    For future questions, it would be good to follow the suggested guidelines on How to make good reproducible Apache Spark Dataframe examples.
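
    If you specifically need col5 to be an ArrayType column (depending on the Spark version, the parenthesised (col4[0], col4[1]) expression may be parsed as a struct rather than an array), a variant using the built-in SQL array() function should work; a minimal sketch:

    # variant using SQL's array() so that col5 is an ArrayType column
    df3 = spark.sql("SELECT col1, col2, col3, array(col4[0], col4[1]) AS col5 FROM df")
    df3.show()
    # col5 now holds the array [123, 234]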
