Lately I'm planning to migrate my standalone Python ML code to Spark. The ML pipeline in spark.ml turns out quite handy, with a streamlined API for chaining up algorithm stages.
So I ran into the same problem, and the way I solved it was to implement my own PipelineStage that caches the input Dataset and returns it as-is.
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType

// Identity transformer: caches the incoming Dataset and passes it through unchanged.
class Cacher(val uid: String) extends Transformer with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("CacherTransformer"))

  // Cache the data and return it as a DataFrame without touching any columns.
  override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF.cache()

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  // The schema is passed through unchanged.
  override def transformSchema(schema: StructType): StructType = schema
}
To use it, you would then do something like this:
new Pipeline().setStages(Array(stage1, new Cacher(), stage2))
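As a fuller illustration, here is a minimal sketch of the Cacher dropped into an end-to-end pipeline. The Tokenizer/HashingTF/LogisticRegression stages and the toy data are assumptions made up for this example, not part of the original answer; the only point is where the Cacher sits.

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.sql.SparkSession

object CacherExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("CacherExample")
      .master("local[*]")
      .getOrCreate()

    // Toy training data: text documents with binary labels (purely illustrative).
    val training = spark.createDataFrame(Seq(
      ("spark is great", 1.0),
      ("hadoop map reduce", 0.0),
      ("spark ml pipelines", 1.0),
      ("plain old batch job", 0.0)
    )).toDF("text", "label")

    val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
    val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features")
    val lr = new LogisticRegression().setMaxIter(10)

    // Place the Cacher right before the iterative estimator so that the
    // featurized DataFrame is cached once and reused on each iteration.
    val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, new Cacher(), lr))

    val model = pipeline.fit(training)
    model.transform(training).select("text", "prediction").show()

    spark.stop()
  }
}

The design point is that caching happens as an ordinary pipeline stage: putting the Cacher just before an iterative estimator means its repeated passes over the data hit the cached DataFrame instead of recomputing the upstream stages every time.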