第一个 XGBoost 例子，运行还算顺利。
/**
 * End-to-end XGBoost4J-Spark example: reads the iris.data CSV from HDFS, converts the four
 * measurement columns to doubles, indexes the string label column, trains a multi-class
 * {@code XGBoostClassifier} on a 90/10 random split and prints predictions for the held-out rows.
 *
 * <p>The class stays {@link Serializable} for compatibility with existing callers; the UDF it
 * registers no longer captures the enclosing instance.
 */
public class Xgb2 implements Serializable {

    private static final long serialVersionUID = 1L;

    /** Seed for the train/test split so repeated runs train and evaluate on the same rows. */
    private static final long SPLIT_SEED = 42L;

    /**
     * Entry point: registers the UDF and runs the pipeline. It then sends the elapsed time in
     * whole seconds through {@code HttpClientUtil.sendDingMessage}. The Spark session is always
     * stopped, even when training fails.
     *
     * @param args unused
     * @throws Exception propagated from the training pipeline
     */
    public static void main(String[] args) throws Exception {
        long start = System.currentTimeMillis();
        SparkSession s = SparkSession.builder().appName("XgbTrain").getOrCreate();
        try {
            Xgb2 main = new Xgb2();
            main.udf(s);
            main.read(s);
            HttpClientUtil.sendDingMessage("耗时:" + (System.currentTimeMillis() - start) / 1000 + "s");
        } finally {
            // Release the session's resources whether or not the pipeline succeeded.
            s.stop();
        }
    }

    /**
     * Registers the SQL function {@code strToDouble}, which trims a string and parses it as a
     * double.
     *
     * <p>A {@code null} input yields SQL {@code null} instead of failing the task with a
     * NullPointerException. Non-numeric text still throws NumberFormatException.
     *
     * @param s session whose UDF registry receives the function
     */
    public void udf(SparkSession s) {
        // A lambda cast to Spark's (Serializable) UDF1 interface does not capture this Xgb2
        // instance, unlike the anonymous inner class it replaces.
        s.udf().register("strToDouble",
                (UDF1<String, Double>) value -> value == null ? null : Double.valueOf(value.trim()),
                DataTypes.DoubleType);
    }

    /**
     * Runs the training pipeline on the iris.data CSV stored on HDFS and prints predictions for
     * the held-out split.
     *
     * <p>Requires {@link #udf(SparkSession)} to have been called first, because the query below
     * uses {@code strToDouble}.
     *
     * @param s active Spark session
     * @throws Exception propagated from Spark / XGBoost
     */
    public void read(SparkSession s) throws Exception {
        Dataset<Row> rs = s.read().csv("hdfs://*/iris.data");
        rs.createOrReplaceTempView("temp");
        // The CSV reader yields string columns; convert the four numeric ones with the UDF.
        rs = s.sql("select strToDouble(_c0) as _c0, strToDouble(_c1) as _c1, "
                + "strToDouble(_c2) as _c2, strToDouble(_c3) as _c3, _c4 from temp");
        rs.show(10, false);

        // Map the string label in _c4 to a numeric classIndex, then drop the original column.
        StringIndexerModel stringIndexer = new StringIndexer()
                .setInputCol("_c4")
                .setOutputCol("classIndex")
                .fit(rs);
        rs = stringIndexer.transform(rs).drop("_c4");

        // Merge the numeric columns into a single "features" vector column.
        VectorAssembler vectorAssembler = new VectorAssembler()
                .setInputCols(new String[] {"_c0", "_c1", "_c2", "_c3"})
                .setOutputCol("features");
        rs = vectorAssembler.transform(rs).select("features", "classIndex");

        // Split into training (90%) and test (10%) sets; the fixed seed makes runs reproducible.
        Dataset<Row>[] splitXgbInput = rs.randomSplit(new double[] {0.9, 0.1}, SPLIT_SEED);
        Dataset<Row> trainXgbInput = splitXgbInput[0];
        Dataset<Row> testXgbInput = splitXgbInput[1];

        // num_class must equal the number of distinct labels the indexer found.
        int numClass = stringIndexer.labels().length;
        XGBoostClassifier xgBoostEstimator = new XGBoostClassifier(
                ConvertUtil.<String, Object>toScalaImmutableMap(xgbParams(numClass)))
                .setFeaturesCol("features")
                .setLabelCol("classIndex");

        // Train, predict on the held-out rows, and print the first predictions.
        XGBoostClassificationModel model = xgBoostEstimator.fit(trainXgbInput);
        Dataset<Row> result = model.transform(testXgbInput);
        result.show(10, false);
    }

    /**
     * Builds the XGBoost training parameters.
     *
     * @param numClass number of target classes, passed as {@code num_class}
     * @return a new mutable map from XGBoost parameter name to value
     */
    private static Map<String, Object> xgbParams(int numClass) {
        Map<String, Object> params = new HashMap<>();
        params.put("objective", "multi:softprob");
        params.put("eta", 0.1);
        params.put("max_depth", 2);
        // An int, consistent with the other numeric parameters.
        params.put("num_round", 20);
        params.put("num_class", numClass);
        return params;
    }
}
来源:oschina
链接:https://my.oschina.net/u/4391429/blog/4267871