How to evaluate a spark.ml model without DataFrames/SparkContext?

时光说笑 2021-01-22 22:45

With Spark MLlib, I'd build a model (like RandomForest), and then it was possible to evaluate it outside of Spark by loading the model and using predict o

3 Answers
  •  清酒与你
    2021-01-22 23:21

    Here is my solution for using Spark models outside of a Spark context (via PMML):

    1. Create the model with a pipeline and export it to a PMML file, like this:

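    import java.io.FileOutputStream;
    import java.io.OutputStream;
    import java.util.Properties;
    import javax.xml.transform.stream.StreamResult;
    import org.apache.spark.SparkConf;
    import org.apache.spark.ml.Pipeline;
    import org.apache.spark.ml.PipelineModel;
    import org.apache.spark.ml.PipelineStage;
    import org.apache.spark.ml.feature.StringIndexer;
    import org.apache.spark.ml.feature.VectorAssembler;
    import org.apache.spark.ml.regression.GBTRegressor;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;
    import org.dmg.pmml.PMML;
    import org.jpmml.model.JAXBUtil;
    import org.jpmml.sparkml.ConverterUtil;

    SparkConf sparkConf = new SparkConf();
    SparkSession session = SparkSession.builder().enableHiveSupport().config(sparkConf).getOrCreate();

    // Read the training data from Impala over JDBC (vKey and password are defined elsewhere)
    Properties dbProperties = new Properties();
    dbProperties.setProperty("user", vKey);
    dbProperties.setProperty("password", password);
    dbProperties.setProperty("AuthMech", "3");
    dbProperties.setProperty("source", "jdbc");
    dbProperties.setProperty("driver", "com.cloudera.impala.jdbc41.Driver");
    String tableName = "schema.table";
    String simpleUrl = "jdbc:impala://host:21050/schema";
    Dataset<Row> data = session.read().jdbc(simpleUrl, tableName, dbProperties);

    // The StringIndexer is fitted inside the pipeline; transforming `data` with it
    // beforehand would make the pipeline fail because indexed_column1 already exists
    StringIndexer indexer = new StringIndexer().setInputCol("column1").setOutputCol("indexed_column1");
    // VectorAssembler does not accept string columns, so assemble the indexed column
    String[] inputCols = {"indexed_column1"};
    VectorAssembler assembler = new VectorAssembler().setInputCols(inputCols).setOutputCol("features");
    // Use the typed setters instead of the string-based set("maxIter", ...) calls
    GBTRegressor gbt = new GBTRegressor()
        .setMaxIter(20)
        .setMaxDepth(2)
        .setMaxBins(204)
        .setLabelCol("faktor")
        .setFeaturesCol("features");

    Pipeline pipeline = new Pipeline().setStages(new PipelineStage[]{indexer, assembler, gbt});
    PipelineModel pmodel = pipeline.fit(data);

    // Convert with jpmml-sparkml, passing the schema of the raw input data, and write the file
    PMML pmml = ConverterUtil.toPMML(data.schema(), pmodel);
    try (OutputStream fos = new FileOutputStream("model.pmml")) {
        JAXBUtil.marshalPMML(pmml, new StreamResult(fos));
    }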
    
    2. Use the PMML file for predictions locally, without a Spark context (the evaluator is applied to a Map of arguments, not to a DataFrame):

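      import java.io.FileInputStream;
      import java.util.HashMap;
      import java.util.Map;
      import org.dmg.pmml.FieldName;
      import org.dmg.pmml.PMML;
      import org.jpmml.evaluator.*;
      import org.jpmml.model.PMMLUtil;

      // jpmml-evaluator 1.3/1.4 API; pmmlFile is the path of the file written in step 1
      PMML pmml = PMMLUtil.unmarshal(new FileInputStream(pmmlFile));
      ModelEvaluator<?> evaluator = ModelEvaluatorFactory.newInstance().newModelEvaluator(pmml);
      // Input fields are the raw columns (here column1); wrap raw values with InputField.prepare()
      Map<FieldName, FieldValue> args = new HashMap<>();
      InputField curField = evaluator.getInputFields().get(0);
      args.put(curField.getName(), curField.prepare("1.0"));
      Map<FieldName, ?> result = evaluator.evaluate(args);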
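    To make the "Map of arguments instead of a DataFrame" idea concrete, here is a dependency-free Java sketch of roughly what happens when a boosted-tree regressor is scored locally: each tree is walked using the values in the argument map, and the tree outputs are combined as a weighted sum (in Spark's GBTRegressionModel the first tree has weight 1.0 and later trees use the step size). Everything in it is hypothetical: the class MapScoringSketch, the two hand-written depth-1 trees, their thresholds and leaf values, and the 0.1 step size are made up for illustration. The sketch also takes the already-indexed value (indexed_column1) directly, whereas the real PMML model applies the StringIndexer mapping to the raw column1 value itself.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch of scoring a boosted tree ensemble from a Map of arguments.
// All trees, thresholds, leaf values and weights below are made up for illustration.
class MapScoringSketch {

    // A depth-1 regression tree: goes left when the feature value is <= threshold.
    static final class Stump {
        final String feature;
        final double threshold;
        final double left;
        final double right;

        Stump(String feature, double threshold, double left, double right) {
            this.feature = feature;
            this.threshold = threshold;
            this.left = left;
            this.right = right;
        }

        double predict(Map<String, Double> args) {
            Double x = args.get(feature);
            if (x == null) {
                throw new IllegalArgumentException("Missing input field: " + feature);
            }
            return x <= threshold ? left : right;
        }
    }

    // Weighted sum of the tree outputs, the way a GBT regressor combines its trees.
    static double predict(List<Stump> trees, double[] weights, Map<String, Double> args) {
        double sum = 0.0;
        for (int i = 0; i < trees.size(); i++) {
            sum += weights[i] * trees.get(i).predict(args);
        }
        return sum;
    }

    public static void main(String[] unused) {
        List<Stump> trees = Arrays.asList(
                new Stump("indexed_column1", 0.5, 10.0, 20.0),
                new Stump("indexed_column1", 1.5, -2.0, 4.0));
        double[] weights = {1.0, 0.1}; // first tree 1.0, later trees the step size (assumed 0.1)
        Map<String, Double> args = new HashMap<>();
        args.put("indexed_column1", 1.0);
        System.out.println(predict(trees, weights, args));
    }
}
```

    With indexed_column1 = 1.0, the first tree returns 20.0 (1.0 > 0.5) and the second returns -2.0 (1.0 <= 1.5), so the prediction is 20.0 + 0.1 * (-2.0) = 19.8. Nothing in this path needs a SparkContext or a DataFrame, which is the point of exporting the model.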
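    In step 1, the StringIndexer stage is what turns the string column1 into a number that VectorAssembler can use. The following plain-Java sketch shows the idea behind its default "frequencyDesc" ordering: the most frequent label gets index 0.0, the next one 1.0, and so on. StringIndexerSketch, fitLabels and transform are hypothetical names, not Spark API. Breaking ties alphabetically is an assumption made here for determinism (depending on the Spark version, ties may be ordered differently), and throwing on an unseen label mirrors the default handleInvalid = "error".

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch of StringIndexer's default "frequencyDesc" fitting, not Spark code.
class StringIndexerSketch {

    // Learns label -> index: most frequent label first; ties broken alphabetically (assumption).
    static Map<String, Double> fitLabels(List<String> column) {
        Map<String, Integer> counts = new HashMap<>();
        for (String v : column) {
            counts.merge(v, 1, Integer::sum);
        }
        List<String> labels = new ArrayList<>(counts.keySet());
        labels.sort((x, y) -> {
            int byCount = Integer.compare(counts.get(y), counts.get(x)); // descending count
            return byCount != 0 ? byCount : x.compareTo(y);
        });
        Map<String, Double> index = new LinkedHashMap<>();
        for (int i = 0; i < labels.size(); i++) {
            index.put(labels.get(i), (double) i);
        }
        return index;
    }

    // Applies the mapping; an unseen label is an error, like handleInvalid = "error".
    static double transform(Map<String, Double> index, String label) {
        Double idx = index.get(label);
        if (idx == null) {
            throw new IllegalArgumentException("Unseen label: " + label);
        }
        return idx;
    }

    public static void main(String[] args) {
        List<String> column1 = Arrays.asList("b", "a", "b", "c", "b", "a");
        Map<String, Double> index = fitLabels(column1);
        System.out.println(index);                 // {b=0.0, a=1.0, c=2.0}
        System.out.println(transform(index, "c")); // 2.0
    }
}
```

    This is also why the pipeline's VectorAssembler must read indexed_column1 rather than column1: only the indexed column is numeric.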
      
