Question
I have the following PySpark DataFrame:
+------+----------------+
| id| data |
+------+----------------+
| 1| [10, 11, 12]|
| 2| [20, 21, 22]|
| 3| [30, 31, 32]|
+------+----------------+
At the end, I want to have the following DataFrame:
+--------+----------------------------------+
| id | data |
+--------+----------------------------------+
| [1,2,3]|[[10,20,30],[11,21,31],[12,22,32]]|
+--------+----------------------------------+
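For reference, a minimal sketch that builds this example DataFrame, assuming a local SparkSession (the names spark and sc here are assumptions used by the code below):
from pyspark.sql import SparkSession

# Assumed setup (not part of the original question): a SparkSession
# plus the example rows shown above.
spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

df_test = spark.createDataFrame(
    [(1, [10, 11, 12]), (2, [20, 21, 22]), (3, [30, 31, 32])],
    ["id", "data"],
)
df = df_test  # the answers below refer to this same DataFrame as df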
In order to do this, I first extract the data arrays as follows:
tmp_array = df_test.select("data").rdd.flatMap(lambda x: x).collect()
a0 = tmp_array[0]
a1 = tmp_array[1]
a2 = tmp_array[2]
samples = list(zip(a0, a1, a2))  # list() needed in Python 3, where zip returns an iterator
samples1 = sc.parallelize(samples)
This way, samples1 is an RDD with the content
[[10,20,30],[11,21,31],[12,22,32]]
Question 1: Is that a good way to do it?
Question 2: How to include that RDD back into the dataframe?
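For Question 2, one direct sketch (assuming a SparkSession named spark, and reusing a0, a1, a2 from above) skips the intermediate RDD and builds the one-row DataFrame from the collected pieces:
ids = df_test.select("id").rdd.flatMap(lambda x: x).collect()
rows = [list(t) for t in zip(a0, a1, a2)]  # lists (not tuples), so Spark infers array columns
result = spark.createDataFrame([(ids, rows)], ["id", "data"])
result.show(truncate=False)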
Answer 1:
Here is a way to get your desired output without serializing to rdd or using a udf. You will need two constants:
- The number of rows in your DataFrame (df.count())
- The length of data (given)
Use pyspark.sql.functions.collect_list() and pyspark.sql.functions.array() in a double list comprehension to pick out the elements of "data" in the order you want, using pyspark.sql.Column.getItem():
import pyspark.sql.functions as f

dataLength = 3
numRows = df.count()

df.select(
    f.collect_list("id").alias("id"),
    f.array(
        [
            # inner array: the i-th element of every row's "data" array
            f.array(
                [f.collect_list("data").getItem(j).getItem(i)
                 for j in range(numRows)]
            )
            for i in range(dataLength)
        ]
    ).alias("data")
).show(truncate=False)
#+---------+------------------------------------------------------------------------------+
#|id |data |
#+---------+------------------------------------------------------------------------------+
#|[1, 2, 3]|[WrappedArray(10, 20, 30), WrappedArray(11, 21, 31), WrappedArray(12, 22, 32)]|
#+---------+------------------------------------------------------------------------------+
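On Spark 2.4+, an alternative sketch (my addition, not part of this answer) uses the SQL higher-order functions transform and sequence to transpose the collected arrays, which avoids hard-coding dataLength; note that collect_list does not guarantee row order:
from pyspark.sql import functions as f

df.agg(
    f.collect_list("id").alias("id"),
    f.collect_list("data").alias("data"),  # [[10,11,12],[20,21,22],[30,31,32]]
).select(
    "id",
    # transpose the array of arrays: for each index i, take the
    # i-th element of every inner array
    f.expr(
        "transform(sequence(0, size(data[0]) - 1), "
        "i -> transform(data, row -> row[i]))"
    ).alias("data"),
).show(truncate=False)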
Answer 2:
You can simply use a udf function for the zip, but before that you will have to use the collect_list function:
from pyspark.sql import functions as f
from pyspark.sql import types as t

def zipUdf(array):
    # list() needed in Python 3, where zip returns an iterator
    return list(zip(*array))

zipping = f.udf(zipUdf, t.ArrayType(t.ArrayType(t.IntegerType())))

df.select(
    f.collect_list(df.id).alias('id'),
    zipping(f.collect_list(df.data)).alias('data')
).show(truncate=False)
which would give you:
+---------+------------------------------------------------------------------------------+
|id |data |
+---------+------------------------------------------------------------------------------+
|[1, 2, 3]|[WrappedArray(10, 20, 30), WrappedArray(11, 21, 31), WrappedArray(12, 22, 32)]|
+---------+------------------------------------------------------------------------------+
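A caveat that applies to both answers: collect_list does not guarantee the order in which rows are collected. A defensive sketch (my addition, reusing the zipping udf defined above) collects (id, data) structs and sorts them by id first:
from pyspark.sql import functions as f

df.select(
    # sort_array orders the structs by their first field, i.e. by id
    f.sort_array(f.collect_list(f.struct("id", "data"))).alias("rows")
).select(
    f.col("rows.id").alias("id"),               # the ids, now in id order
    zipping(f.col("rows.data")).alias("data")   # zip the aligned data arrays
).show(truncate=False)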
Source: https://stackoverflow.com/questions/49800484/pyspark-zip-arrays-in-a-dataframe