How to update pyspark dataframe metadata on Spark 2.1?


Question


I'm facing an issue with the OneHotEncoder of SparkML, since it reads dataframe metadata in order to determine the value range it should assign for the sparse vector object it is creating.

More specifically, I'm encoding an "hour" field using a training set containing all individual values between 0 and 23.

Now I'm scoring a single-row data frame using the "transform" method of the Pipeline.

Unfortunately, this leads to a differently encoded sparse vector object from the OneHotEncoder:

(24,[5],[1.0]) vs. (11,[10],[1.0])
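For illustration, here is a minimal sketch of how such a mismatch can arise (the direct OneHotEncoder-on-"hour" setup and dropLast=False are assumptions on my side, not necessarily the original pipeline): when the frame being transformed carries no ML attribute metadata, the encoder infers the value range from the data it actually sees.

from pyspark.ml import Pipeline
from pyspark.ml.feature import OneHotEncoder

# hypothetical training frame covering every hour 0..23
train = spark.createDataFrame([(h,) for h in range(24)], ["hour"])
pipeline = Pipeline(stages=[OneHotEncoder(inputCol="hour", outputCol="hour_vec",
                                          dropLast=False)])
model = pipeline.fit(train)

# the full 0..23 range is visible here, so the vector gets 24 slots
model.transform(train).filter("hour = 5").select("hour_vec").first()
# e.g. SparseVector(24, {5: 1.0})

# a single-row scoring frame without metadata: the encoder only sees the value 10
single = spark.createDataFrame([(10,)], ["hour"])
model.transform(single).select("hour_vec").first()
# e.g. SparseVector(11, {10: 1.0}) -- a differently sized vector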

I've documented this here, but it was identified as a duplicate. In that thread a solution is posted that updates the dataframe's metadata to reflect the real range of the "hour" field:

from pyspark.sql.functions import col

meta = {"ml_attr": {
    "vals": [str(x) for x in range(6)],   # Provide a set of levels
    "type": "nominal", 
    "name": "class"}}

loaded.transform(
    df.withColumn("class", col("class").alias("class", metadata=meta)) )

Unfortunately, I get this error:

TypeError: alias() got an unexpected keyword argument 'metadata'


Answer 1:


In PySpark 2.1, the alias method has no metadata argument (docs) - this became available in Spark 2.2; nevertheless, it is still possible to modify column metadata in PySpark < 2.2, thanks to the incredible Spark Gotchas, maintained by @eliasah and @zero323:

import json

from pyspark import SparkContext
from pyspark.sql import Column
from pyspark.sql.functions import col

spark.version
# u'2.1.1'

df = sc.parallelize((
        (0, "x", 2.0),
        (1, "y", 3.0),
        (2, "x", -1.0)
        )).toDF(["label", "x1", "x2"])

df.show()
# +-----+---+----+ 
# |label| x1|  x2|
# +-----+---+----+
# |    0|  x| 2.0|
# |    1|  y| 3.0|
# |    2|  x|-1.0|
# +-----+---+----+

Suppose we want to declare that our label values can range from 0 to 5, even though in our dataframe they only range from 0 to 2; here is how we can modify the column metadata:

def withMeta(self, alias, meta):
    # grab the active JVM gateway and the Java Metadata factory
    sc = SparkContext._active_spark_context
    jmeta = sc._gateway.jvm.org.apache.spark.sql.types.Metadata
    # call the Java Column.as(alias, metadata) overload, which does accept metadata
    return Column(getattr(self._jc, "as")(alias, jmeta.fromJson(json.dumps(meta))))

Column.withMeta = withMeta

# new metadata:
meta = {"ml_attr": {"name": "label_with_meta",
                    "type": "nominal",
                    "vals": [str(x) for x in range(6)]}}

df_with_meta = df.withColumn("label_with_meta", col("label").withMeta("", meta))
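To check that the metadata actually landed on the new column (a quick verification step, not part of the original answer), it can be read back from the schema; and, if I have the encoder's behaviour right, a OneHotEncoder will now size its vectors from the six declared levels rather than from the values present in the data:

# the column added by withColumn is the last field of the schema
df_with_meta.schema.fields[-1].metadata
# {'ml_attr': {'name': 'label_with_meta', 'type': 'nominal', 'vals': ['0', '1', '2', '3', '4', '5']}}

from pyspark.ml.feature import OneHotEncoder
encoder = OneHotEncoder(inputCol="label_with_meta", outputCol="label_vec", dropLast=False)
encoder.transform(df_with_meta).select("label_vec").show(truncate=False)
# every vector should now have size 6, regardless of only 0-2 appearing in the data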

Kudos also to this answer by zero323!



Source: https://stackoverflow.com/questions/46667810/how-to-update-pyspark-dataframe-metadata-on-spark-2-1
