Caching intermediate results in Spark ML pipeline

Asked by 我在风中等你 · 2021-02-05 13:21

Lately I'm planning to migrate my standalone Python ML code to Spark. The ML pipeline in spark.ml turns out quite handy, with a streamlined API for chaining up algorithm stages. However, I can't find an obvious way to cache the intermediate result of a stage so that downstream stages (or repeated fits) don't recompute it. How can I cache intermediate results in a Spark ML pipeline?

1 Answer

    夕颜 (OP) · 2021-02-05 14:17

    I ran into the same problem, and the way I solved it was to implement my own PipelineStage that caches the input Dataset and returns it as is.

    import org.apache.spark.ml.Transformer
    import org.apache.spark.ml.param.ParamMap
    import org.apache.spark.ml.util.{DefaultParamsWritable, Identifiable}
    import org.apache.spark.sql.{DataFrame, Dataset}
    import org.apache.spark.sql.types.StructType

    // A no-op pipeline stage: it marks the incoming DataFrame for caching
    // and passes it through unchanged.
    class Cacher(val uid: String) extends Transformer with DefaultParamsWritable {
      override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF.cache()

      override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

      // The schema is passed through untouched; no columns are added or removed.
      override def transformSchema(schema: StructType): StructType = schema

      def this() = this(Identifiable.randomUID("CacherTransformer"))
    }
    

    To use it, you would then do something like this:

    import org.apache.spark.ml.Pipeline

    new Pipeline().setStages(Array(stage1, new Cacher(), stage2))
    
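    For a more concrete picture, here is a minimal sketch of where a Cacher typically helps: right before an iterative estimator, so the output of the feature stages is kept in memory for the estimator's repeated passes over the data instead of being recomputed. The Tokenizer/HashingTF/LogisticRegression stages, the "text" and "label" columns, and the `training` dataset are all hypothetical, not part of the original question.

    import org.apache.spark.ml.{Pipeline, PipelineStage}
    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.feature.{HashingTF, Tokenizer}

    // Hypothetical text-classification pipeline over "text"/"label" columns.
    val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
    val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features")
    val lr = new LogisticRegression().setMaxIter(10)

    // Cache the featurized DataFrame before the estimator so its iterations
    // read cached data rather than re-running the upstream transformers.
    val stages: Array[PipelineStage] = Array(tokenizer, hashingTF, new Cacher(), lr)
    val pipeline = new Pipeline().setStages(stages)

    // val model = pipeline.fit(training)   // `training` is an assumed input Dataset

    One thing to keep in mind with this approach: nothing in the Cacher stage ever unpersists the cached data, so it stays cached until Spark evicts it or the application ends.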
