Question
I'm facing an issue with the OneHotEncoder of Spark ML, since it reads DataFrame metadata in order to determine the value range it should assign for the sparse vector object it's creating.
More specifically, I'm encoding an "hour" field using a training set containing all individual values between 0 and 23.
Now I'm scoring a single-row data frame using the "transform" method of the Pipeline.
Unfortunately, this leads to a differently encoded sparse vector object from the OneHotEncoder:
(24,[5],[1.0]) vs. (11,[10],[1.0])
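Here is a minimal sketch that reproduces this kind of mismatch on Spark 2.x (before 2.3, where OneHotEncoder is a plain Transformer and, if the input column carries no ML metadata, sizes the output vector from the maximum value it finds in the DataFrame it is transforming); the column name and dropLast=False are assumptions, not my exact pipeline:
from pyspark.ml.feature import OneHotEncoder

encoder = OneHotEncoder(inputCol="hour", outputCol="hour_vec", dropLast=False)

# Training data covers hours 0..23 -> vectors of size 24, e.g. (24,[5],[1.0])
train_df = spark.createDataFrame([(float(h),) for h in range(24)], ["hour"])
encoder.transform(train_df).show(truncate=False)

# A single scored row with hour == 10.0 and no metadata -> the encoder only
# sees values up to 10, so the vector shrinks to size 11: (11,[10],[1.0])
score_df = spark.createDataFrame([(10.0,)], ["hour"])
encoder.transform(score_df).show(truncate=False)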
I've documented this here, but it was identified as a duplicate. In that thread a solution is posted that updates the DataFrame's metadata to reflect the real range of the "hour" field:
from pyspark.sql.functions import col

meta = {"ml_attr": {
    "vals": [str(x) for x in range(6)],  # Provide a set of levels
    "type": "nominal",
    "name": "class"}}

loaded.transform(
    df.withColumn("class", col("class").alias("class", metadata=meta)))
Unfortunately, I get this error:
TypeError: alias() got an unexpected keyword argument 'metadata'
Answer 1:
In PySpark 2.1, the alias method has no metadata argument (docs); this became available in Spark 2.2. Nevertheless, it is still possible to modify column metadata in PySpark < 2.2, thanks to the incredible Spark Gotchas, maintained by @eliasah and @zero323:
import json
from pyspark import SparkContext
from pyspark.sql import Column
from pyspark.sql.functions import col

spark.version
# u'2.1.1'

df = sc.parallelize((
    (0, "x", 2.0),
    (1, "y", 3.0),
    (2, "x", -1.0)
)).toDF(["label", "x1", "x2"])
df.show()
# +-----+---+----+
# |label| x1| x2|
# +-----+---+----+
# | 0| x| 2.0|
# | 1| y| 3.0|
# | 2| x|-1.0|
# +-----+---+----+
Suppose we want to declare that our label values can range between 0 and 5, even though in our dataframe they only range between 0 and 2; here is how we should modify the column metadata:
# Attach ML metadata to a column via the JVM Column.as(alias, Metadata) overload,
# which is not exposed through the Python alias() in Spark < 2.2
def withMeta(self, alias, meta):
    sc = SparkContext._active_spark_context
    jmeta = sc._gateway.jvm.org.apache.spark.sql.types.Metadata
    return Column(getattr(self._jc, "as")(alias, jmeta.fromJson(json.dumps(meta))))

# Monkey-patch the helper onto Column so it can be chained like a built-in method
Column.withMeta = withMeta

# new metadata:
meta = {"ml_attr": {"name": "label_with_meta",
                    "type": "nominal",
                    "vals": [str(x) for x in range(6)]}}

df_with_meta = df.withColumn("label_with_meta", col("label").withMeta("", meta))
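To check that the metadata was actually attached, the schema of the new column can be inspected (a quick sketch; the printed dictionary is what I would expect, not output copied from a session):
# Inspect the metadata now stored in the schema of the new column
df_with_meta.select("label_with_meta").schema.fields[0].metadata
# {'ml_attr': {'name': 'label_with_meta', 'type': 'nominal',
#              'vals': ['0', '1', '2', '3', '4', '5']}}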
Kudos also to this answer by zero323!
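For completeness: from Spark 2.2 onward this monkey-patch is unnecessary, since alias accepts the metadata keyword directly, which is exactly what the snippet in the question attempts (a sketch reusing the meta dict defined above):
# Spark >= 2.2 only: pass the metadata straight to alias()
df_with_meta = df.withColumn(
    "label_with_meta", col("label").alias("label_with_meta", metadata=meta))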
Source: https://stackoverflow.com/questions/46667810/how-to-update-pyspark-dataframe-metadata-on-spark-2-1