Preserve index-string correspondence in Spark StringIndexer

天涯浪人 2021-02-04 06:20

Spark's StringIndexer is quite useful, but it's common to need to retrieve the correspondences between the generated index values and the original strings, and it seems like t…

1 Answer
  •  礼貌的吻别
    2021-02-04 07:04

    The label mapping can be extracted from the column metadata:

    meta = [
        f.metadata for f in indexed_df.schema.fields if f.name == "categoryIndex"
    ]
    meta[0]
    ## {'ml_attr': {'name': 'category', 'type': 'nominal', 'vals': ['a', 'c', 'b']}}
    

    where ml_attr.vals provides a mapping between position (the index value) and label:

    dict(enumerate(meta[0]["ml_attr"]["vals"]))
    ## {0: 'a', 1: 'c', 2: 'b'}
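
A minimal pure-Python sketch of the same lookup, assuming the metadata has already been pulled out into a plain dict shaped like the one printed above. The helper name label_mappings is hypothetical (not part of the Spark API); besides the index-to-label map, it also builds the reverse label-to-index map and fails loudly when the column carries no nominal metadata.

```python
from typing import Dict, List, Tuple


def label_mappings(metadata: dict) -> Tuple[Dict[int, str], Dict[str, int]]:
    """Build index->label and label->index maps from StringIndexer-style
    column metadata (hypothetical helper; expects the 'ml_attr'/'vals' layout)."""
    try:
        vals: List[str] = metadata["ml_attr"]["vals"]
    except KeyError as exc:
        raise ValueError("column has no nominal ml_attr metadata") from exc
    index_to_label = dict(enumerate(vals))
    label_to_index = {label: idx for idx, label in index_to_label.items()}
    return index_to_label, label_to_index


# Metadata copied from the output shown above
meta0 = {"ml_attr": {"name": "category", "type": "nominal", "vals": ["a", "c", "b"]}}
idx2lab, lab2idx = label_mappings(meta0)
print(idx2lab)  # {0: 'a', 1: 'c', 2: 'b'}
print(lab2idx)  # {'a': 0, 'c': 1, 'b': 2}
```

In a real job, meta0 would be meta[0] from the list comprehension above.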
    

    Spark 1.6+

    You can convert numeric indices back to labels using IndexToString, which reads the column metadata shown above.

    from pyspark.ml.feature import IndexToString
    
    idx_to_string = IndexToString(
        inputCol="categoryIndex", outputCol="categoryValue")
    
    idx_to_string.transform(indexed_df).drop("id").distinct().show()
    ## +--------+-------------+-------------+
    ## |category|categoryIndex|categoryValue|
    ## +--------+-------------+-------------+
    ## |       b|          2.0|            b|
    ## |       a|          0.0|            a|
    ## |       c|          1.0|            c|
    ## +--------+-------------+-------------+
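
Conceptually, IndexToString looks each double-valued index up in the label list (ml_attr.vals). The sketch below mimics that lookup in plain Python on hypothetical values; it illustrates the idea and is not Spark's implementation, and index_to_string is a made-up name. Raising on non-integral or out-of-range indices is a choice made for this sketch. (PySpark's IndexToString also accepts an optional labels parameter; when it is not provided, the input column's metadata is used.)

```python
from typing import Iterable, List, Sequence


def index_to_string(indices: Iterable[float], labels: Sequence[str]) -> List[str]:
    """Map Spark-style double indices (0.0, 1.0, ...) back to their labels."""
    result = []
    for value in indices:
        idx = int(value)
        # Reject non-integral or out-of-range indices instead of guessing
        if idx != value or not 0 <= idx < len(labels):
            raise ValueError(f"invalid index {value!r} for {len(labels)} labels")
        result.append(labels[idx])
    return result


labels = ["a", "c", "b"]  # the ml_attr.vals list shown above
print(index_to_string([2.0, 0.0, 1.0], labels))  # ['b', 'a', 'c']
```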
    

    Spark <= 1.5

    It is a dirty hack, but you can extract the labels from the underlying Java model as follows:

    from pyspark.ml.feature import StringIndexer, StringIndexerModel
    
    # A simple monkey patch so we don't have to call _call_java directly later
    def labels(self):
        return self._call_java("labels")
    
    StringIndexerModel.labels = labels
    
    # Fit indexer model
    indexer = StringIndexer(inputCol="category", outputCol="categoryIndex").fit(df)
    
    # Extract mapping
    mapping = dict(enumerate(indexer.labels()))
    mapping
    ## {0: 'a', 1: 'c', 2: 'b'}
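
Note that the extracted labels are not in alphabetical order: by default StringIndexer gives index 0 to the most frequent label, 1 to the next most frequent, and so on. A plain-Python sketch of that ordering, assuming a hypothetical category column containing a, b, c, a, a, c (three a, two c, one b), reproduces the mapping above. This sketch breaks ties alphabetically; that is an assumption here, so do not rely on older Spark versions doing the same.

```python
from collections import Counter
from typing import Dict, List


def frequency_labels(values: List[str]) -> List[str]:
    """Order distinct labels by descending count (ties: alphabetical)."""
    counts = Counter(values)
    return sorted(counts, key=lambda label: (-counts[label], label))


# Hypothetical 'category' column values
category = ["a", "b", "c", "a", "a", "c"]
labels = frequency_labels(category)
mapping: Dict[int, str] = dict(enumerate(labels))
print(labels)   # ['a', 'c', 'b']
print(mapping)  # {0: 'a', 1: 'c', 2: 'b'}
```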
    
