【推荐】2019 Java 开发者跳槽指南.pdf(吐血整理) >>>
## mmlspark
https://mvnrepository.com/artifact/Azure/mmlspark/0.15
## lightgbmlib
https://mvnrepository.com/artifact/com.microsoft.ml.lightgbm/lightgbmlib/2.2.200
[root@hadoop-1-1 ~]# more lgbm.sh
/app/spark2.3/bin/spark-submit \
--master yarn \
--jars /root/external_pkgs/mmlspark-0.15.jar,/root/external_pkgs/lightgbmlib-2.2.200.jar \
--class com.xx.demo.lgmClassifier /root/lgbm_demo.jar
nohup sh lgbm.sh > lgbm_20191226_001.log 2>&1 &
package com.xx.demo
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.feature.StandardScaler
import com.microsoft.ml.spark.LightGBMClassifier
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
/** Trains a LightGBM binary classifier on the Pima Indians diabetes dataset
  * using a Spark ML pipeline (VectorAssembler -> StandardScaler -> LightGBM),
  * tunes the learning rate with cross-validation, and reports AUC on a
  * held-out test split.
  */
object lgmClassifier {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf()
      .setAppName("lgbm_app")
      // Kryo serialization is faster and more compact than Java serialization.
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      // LightGBM's native allocations can exceed YARN's virtual-memory
      // estimate; disable the vmem check so executors are not killed.
      .set("yarn.nodemanager.vmem-check-enabled", "false")
    val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()

    val inputPath = "/user/spark/H2O/data/PimaIndian.csv"
    // Read via the session directly; sparkSession.sqlContext.read is a
    // deprecated indirection for the same DataFrameReader.
    val data = sparkSession.read
      .format("csv")
      .option("sep", ",")
      .option("inferSchema", "true")
      .option("header", "false")
      .load(inputPath)

    // The CSV has no header row, so assign column names explicitly.
    // (Names kept byte-identical to the original, typos included, since
    // downstream consumers may depend on them.)
    val schemas = Seq("Pregnancles", "Glucose", "BloodPressure", "SkinThickness",
      "Insulin", "BMI", "DiabetesPedigreeFuction", "Age", "Outcome")
    val dataset = data.toDF(schemas: _*)

    // Every column except the label forms the feature vector. Exact
    // inequality replaces the original substring test, which would also
    // drop any column merely containing "Outcome".
    val vectorAssembler = new VectorAssembler()
      .setInputCols(dataset.columns.filter(_ != "Outcome"))
      .setOutputCol("features")
    val scaler = new StandardScaler()
      .setInputCol("features")
      .setOutputCol("scaledFeatures")
      .setWithStd(true)
      .setWithMean(false)
    val lgbm = new LightGBMClassifier()
      .setLabelCol("Outcome")
      .setFeaturesCol("scaledFeatures")
    val pipeline = new Pipeline().setStages(Array(vectorAssembler, scaler, lgbm))

    // Grid over learning rate only; CrossValidator picks the best by AUC.
    val paramGrid = new ParamGridBuilder()
      .addGrid(lgbm.learningRate, Array(0.05, 0.1))
      .build()

    // BUG FIX: AUC must be computed from the classifier's continuous scores
    // in the "rawPrediction" column, not the hard 0/1 "prediction" column —
    // with binary predictions the ROC collapses to a single operating point
    // and the reported AUC is misleading.
    val evaluator = new BinaryClassificationEvaluator()
      .setLabelCol("Outcome")
      .setRawPredictionCol("rawPrediction")
      .setMetricName("areaUnderROC")

    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)
      .setSeed(0)

    // Fixed seed (0) keeps the 80/20 split reproducible across runs.
    val Array(training, test) = dataset.randomSplit(Array(0.8, 0.2), 0)
    val lgbmModel = cv.fit(training)
    val results = lgbmModel.transform(test)
    val auc = evaluator.evaluate(results)
    println("----AUC--------")
    println(s"The model's auc: $auc")

    sparkSession.stop()
  }
}
来源:oschina
链接:https://my.oschina.net/kyo4321/blog/3147761