Spark XGBoost 0.81 Java 例子

旧时模样 提交于 2020-05-06 10:18:02

第一个xgb的例子,还算顺利

/**
 * Minimal XGBoost4J-Spark example: loads a CSV from HDFS, converts the four
 * feature columns to doubles, indexes the string label, trains an
 * {@code XGBoostClassifier} on a 90% split and prints predictions for the
 * remaining 10%.
 *
 * <p>The class is {@link Serializable} because the UDF registered in
 * {@link #udf(SparkSession)} is an anonymous inner class and therefore carries
 * a reference to its enclosing instance.
 */
public class Xgb2 implements Serializable {

    private static final long serialVersionUID = 1L;

    /**
     * Entry point: creates the Spark session, registers the UDF, runs the
     * training/prediction pipeline and reports the elapsed wall-clock seconds
     * through {@code HttpClientUtil.sendDingMessage}.
     *
     * @param args command-line arguments (unused)
     * @throws Exception if any step of the pipeline fails
     */
    public static void main(String[] args) throws Exception {
        // Primitive long: no reason to box the timestamp.
        long start = System.currentTimeMillis();
        SparkSession s = SparkSession.builder().appName("XgbTrain").getOrCreate();
        try {
            Xgb2 main = new Xgb2();
            main.udf(s);
            main.read(s);
            HttpClientUtil.sendDingMessage("耗时:" + (System.currentTimeMillis() - start) / 1000 + "s");
        } finally {
            // Always release the session, including when the pipeline throws.
            s.stop();
        }
    }

    /**
     * Registers the SQL function {@code strToDouble}, which trims a string and
     * parses it as a double. A null input yields null instead of raising a
     * {@link NullPointerException} inside the task.
     *
     * @param s session on which the UDF is registered
     */
    public void udf(SparkSession s) {
        s.udf().register("strToDouble", new UDF1<String, Double>() {
            private static final long serialVersionUID = 1L;

            @Override
            public Double call(String org) throws Exception {
                if (org == null) {
                    return null;
                }
                return Double.parseDouble(org.trim());
            }
        }, DataTypes.DoubleType);
    }

    /**
     * Reads the CSV, builds the feature vector and label index, trains the
     * XGBoost classifier and shows the first rows of the test predictions.
     *
     * @param s active Spark session; {@code strToDouble} must already be
     *          registered on it (see {@link #udf(SparkSession)})
     * @throws Exception if reading, training or prediction fails
     */
    public void read(SparkSession s) throws Exception {
        // No header/schema options are set, so columns are addressed as _c0.._c4.
        Dataset<Row> rs = s.read().csv("hdfs://*/iris.data");
        rs.createOrReplaceTempView("temp");

        // Convert the four feature columns to doubles; _c4 stays a string.
        rs = s.sql("select strToDouble(_c0) as _c0,  strToDouble(_c1) as _c1, strToDouble(_c2) as _c2, strToDouble(_c3) as _c3, _c4 from temp");
        rs.show(10, false);

        // Index the string class column into "classIndex", then drop the original column.
        StringIndexerModel stringIndexer = new StringIndexer()
                .setInputCol("_c4")
                .setOutputCol("classIndex")
                .fit(rs);
        rs = stringIndexer.transform(rs).drop("_c4");

        // Assemble the four numeric columns into a single "features" vector column.
        VectorAssembler vectorAssembler = new VectorAssembler()
                .setInputCols(new String[] {"_c0", "_c1", "_c2", "_c3"})
                .setOutputCol("features");
        rs = vectorAssembler.transform(rs).select("features", "classIndex");

        // Split into training (90%) and test (10%) sets.
        Dataset<Row>[] splitXgbInput = rs.randomSplit(new double[] {0.9, 0.1});
        Dataset<Row> trainXgbInput = splitXgbInput[0];
        Dataset<Row> testXgbInput = splitXgbInput[1];

        // XGBoost training parameters.
        Map<String, Object> javaMap = new HashMap<>();
        javaMap.put("objective", "multi:softprob");
        javaMap.put("eta", 0.1);
        javaMap.put("max_depth", 2);
        javaMap.put("num_round", "20");
        javaMap.put("num_class", 3);

        // The classifier constructor takes a Scala immutable map, hence the conversion helper.
        XGBoostClassifier xgBoostEstimator =
                new XGBoostClassifier(ConvertUtil.<String, Object>toScalaImmutableMap(javaMap))
                        .setFeaturesCol("features")
                        .setLabelCol("classIndex");

        // Train on the training split.
        XGBoostClassificationModel model = xgBoostEstimator.fit(trainXgbInput);

        // Predict on the test split and display the first rows.
        Dataset<Row> result = model.transform(testXgbInput);
        result.show(10, false);
    }
}
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!