Question
I am trying to remove an element from a Python list of lists:
+---------------+
|        sources|
+---------------+
|           [62]|
|        [7, 32]|
|           [62]|
|   [18, 36, 62]|
|[7, 31, 36, 62]|
|    [7, 32, 62]|
+---------------+
I want to remove an element, rm, from each of the lists in the column above. I wrote a function that does this for a plain list of lists:
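def asdf(df, rm):
    # Build a new list of filtered rows. The earlier "temp = df" only aliased
    # the input, so the caller's list was modified in place as a side effect.
    return [[x for x in row if x != rm] for row in df]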
It works for rm = 1:
a = [[1,2,3],[1,2,3,4],[1,2,3,4,5]]
In: asdf(a,1)
Out: [[2, 3], [2, 3, 4], [2, 3, 4, 5]]
But I can't get it to work for a DataFrame:
asdfUDF = udf(asdf, ArrayType(IntegerType()))
In: df.withColumn("src_ex", asdfUDF("sources", 32))
Out: Py4JError: An error occurred while calling z:org.apache.spark.sql.functions.col. Trace:
py4j.Py4JException: Method col([class java.lang.Integer]) does not exist
Desired behavior:
In: df.withColumn("src_ex", asdfUDF("sources", 32))
Out:
+---------------+
|         src_ex|
+---------------+
|           [62]|
|            [7]|
|           [62]|
|   [18, 36, 62]|
|[7, 31, 36, 62]|
|        [7, 62]|
+---------------+
(except with the new column above appended to the PySpark DataFrame df)
Any suggestions or ideas?
Answer 1:
Spark >= 2.4
You can use array_remove:
from pyspark.sql.functions import array_remove
df.withColumn("src_ex", array_remove("sources", 32)).show()
+---------------+---------------+
|        sources|         src_ex|
+---------------+---------------+
|           [62]|           [62]|
|        [7, 32]|            [7]|
|           [62]|           [62]|
|   [18, 36, 62]|   [18, 36, 62]|
|[7, 31, 36, 62]|[7, 31, 36, 62]|
|    [7, 32, 62]|        [7, 62]|
+---------------+---------------+
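To make the semantics of array_remove concrete, here is a plain-Python model of what it is assumed to do: every occurrence of the value is removed, NULL entries inside the array are kept (Spark's own documentation example turns array(1, 2, 3, null, 3) minus 3 into [1, 2, null]), and a NULL array or NULL value yields NULL, following the usual null propagation of Spark binary functions. array_remove_model is a hypothetical name used only for illustration; it does not call Spark.

```python
from typing import List, Optional

def array_remove_model(arr: Optional[List[Optional[int]]],
                       element: Optional[int]) -> Optional[List[Optional[int]]]:
    """Hypothetical plain-Python model of Spark's array_remove(arr, element)."""
    # Assumed: a NULL array or a NULL element makes the whole result NULL.
    if arr is None or element is None:
        return None
    # All matching elements are removed; None (NULL) entries are never equal
    # to the element, so they are kept.
    return [x for x in arr if x is None or x != element]

sources = [[62], [7, 32], [62], [18, 36, 62], [7, 31, 36, 62], [7, 32, 62]]
print([array_remove_model(s, 32) for s in sources])
# [[62], [7], [62], [18, 36, 62], [7, 31, 36, 62], [7, 62]]
```

The printed list matches the src_ex column in the table above.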
Alternatively, use the SQL higher-order function filter through expr:
from pyspark.sql.functions import expr
df.withColumn("src_ex", expr("filter(sources, x -> not(x <=> 32))")).show()
+---------------+---------------+
|        sources|         src_ex|
+---------------+---------------+
|           [62]|           [62]|
|        [7, 32]|            [7]|
|           [62]|           [62]|
|   [18, 36, 62]|   [18, 36, 62]|
|[7, 31, 36, 62]|[7, 31, 36, 62]|
|    [7, 32, 62]|        [7, 62]|
+---------------+---------------+
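The filter expression uses not(x <=> 32) rather than x != 32 because of NULLs. In Spark SQL, NULL != 32 evaluates to NULL, and filter is assumed to keep an element only when the lambda returns exactly true, so NULL entries would be silently dropped. The null-safe operator <=> never returns NULL (NULL <=> 32 is false, NULL <=> NULL is true), so NULL entries survive. A plain-Python sketch of that three-valued logic, with None standing in for NULL; all function names here are hypothetical:

```python
from typing import Optional

def sql_neq(a, b) -> Optional[bool]:
    # SQL three-valued logic: any comparison involving NULL yields NULL.
    if a is None or b is None:
        return None
    return a != b

def null_safe_eq(a, b) -> bool:
    # Model of `<=>`: two NULLs are equal, NULL vs. a value is false.
    if a is None and b is None:
        return True
    if a is None or b is None:
        return False
    return a == b

def sql_filter(arr, pred):
    # Keep an element only when the predicate is exactly True;
    # both False and None (NULL) drop it. A NULL array stays NULL.
    if arr is None:
        return None
    return [x for x in arr if pred(x) is True]

row = [7, None, 32, 62]
print(sql_filter(row, lambda x: sql_neq(x, 32)))           # [7, 62]
print(sql_filter(row, lambda x: not null_safe_eq(x, 32)))  # [7, None, 62]
```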
Spark < 2.4
A number of things:
- DataFrame is not a list of lists. In practice it is not even a plain Python object: it has no len and it is not Iterable.
- The column you have looks like a plain array type.
- You cannot reference a DataFrame (or any other distributed data structure) inside a UDF.
- Every argument passed directly to a UDF call has to be a str (column name) or a Column object. To pass a literal, use the lit function.
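The first and third points together mean a UDF never sees the column as a whole: Spark calls the wrapped Python function once per row and passes that row's array as a Python list. The list-of-lists function from the question is therefore just a per-row function mapped over all rows. A plain-Python sketch of the difference, where rows is a hypothetical stand-in for the sources column:

```python
from typing import List

def drop_from_list_of_lists(lists: List[List[int]], rm: int) -> List[List[int]]:
    # What the question's asdf() does: loops over *all* rows itself.
    return [[x for x in row if x != rm] for row in lists]

def drop_from_row(row: List[int], rm: int) -> List[int]:
    # What a UDF body has to look like: it receives ONE row's array.
    return [x for x in row if x != rm]

rows = [[62], [7, 32], [62], [18, 36, 62], [7, 31, 36, 62], [7, 32, 62]]

# Spark performs the "for row in rows" part itself (distributed across
# executors) and calls the UDF body once per row.
per_row = [drop_from_row(row, 32) for row in rows]
assert per_row == drop_from_list_of_lists(rows, 32)
```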
What remains is just a list comprehension wrapped in a UDF:
from pyspark.sql.functions import lit, udf
from pyspark.sql.types import ArrayType, IntegerType
def drop_from_array_(arr, item):
    # Called once per row with that row's array as a Python list
    return [x for x in arr if x != item]
drop_from_array = udf(drop_from_array_, ArrayType(IntegerType()))
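drop_from_array_ assumes every row holds a non-NULL array. PySpark passes a NULL cell to a Python UDF as None, and iterating None raises TypeError, which makes the Spark task fail. If sources can contain NULLs, a guarded variant can be wrapped with udf(..., ArrayType(IntegerType())) and called with lit(32) in exactly the same way. A minimal sketch; the name drop_from_array_null_safe is illustrative only:

```python
from typing import List, Optional

def drop_from_array_null_safe(arr: Optional[List[Optional[int]]],
                              item: int) -> Optional[List[Optional[int]]]:
    # A NULL array arrives as None; return None (NULL) instead of
    # raising TypeError while trying to iterate it.
    if arr is None:
        return None
    return [x for x in arr if x != item]

sources = [[62], [7, 32], None, [7, 32, 62]]
print([drop_from_array_null_safe(s, 32) for s in sources])
# [[62], [7], None, [7, 62]]
```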
Example usage:
df = sc.parallelize([
[62], [7, 32], [62], [18, 36, 62], [7, 31, 36, 62], [7, 32, 62]
]).map(lambda x: (x, )).toDF(["sources"])
df.withColumn("src_ex", drop_from_array("sources", lit(32))).show()
The result:
+---------------+---------------+
|        sources|         src_ex|
+---------------+---------------+
|           [62]|           [62]|
|        [7, 32]|            [7]|
|           [62]|           [62]|
|   [18, 36, 62]|   [18, 36, 62]|
|[7, 31, 36, 62]|[7, 31, 36, 62]|
|    [7, 32, 62]|        [7, 62]|
+---------------+---------------+
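Instead of passing the value with lit, it can also be captured in a closure so that the UDF takes only the column argument. A minimal sketch; make_remover is a hypothetical helper, and the PySpark wrapping appears only in comments because it needs a running Spark session:

```python
from typing import Callable, List

def make_remover(item: int) -> Callable[[List[int]], List[int]]:
    """Hypothetical helper: returns a one-argument function with `item` captured."""
    def remove(arr: List[int]) -> List[int]:
        # `item` comes from the enclosing scope, not from a column.
        return [x for x in arr if x != item]
    return remove

remove_32 = make_remover(32)
print(remove_32([7, 32, 62]))  # [7, 62]

# With PySpark available, the closure could be wrapped like this
# (the value 32 is fixed when the UDF is created):
#   from pyspark.sql.functions import udf
#   from pyspark.sql.types import ArrayType, IntegerType
#   drop_32 = udf(make_remover(32), ArrayType(IntegerType()))
#   df.withColumn("src_ex", drop_32("sources"))
```

The trade-off: with lit, one UDF handles any value chosen at call time, while the closure bakes the value in and needs a new UDF per value.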
Source: https://stackoverflow.com/questions/41624567/remove-an-element-from-a-python-list-of-lists-in-pyspark-dataframe