集群提交lightGBM算法

烈酒焚心 提交于 2019-12-26 17:11:17

【推荐】2019 Java 开发者跳槽指南.pdf(吐血整理) >>>

## mmlspark
https://mvnrepository.com/artifact/Azure/mmlspark/0.15

## lightgbmlib
https://mvnrepository.com/artifact/com.microsoft.ml.lightgbm/lightgbmlib/2.2.200
[root@hadoop-1-1 ~]# more lgbm.sh
/app/spark2.3/bin/spark-submit \
--master yarn \
--jars /root/external_pkgs/mmlspark-0.15.jar,/root/external_pkgs/lightgbmlib-2.2.200.jar \
--class com.xx.demo.lgmClassifier /root/lgbm_demo.jar
nohup sh lgbm.sh > lgbm_20191226_001.log 2>&1 &
package com.xx.demo

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.feature.StandardScaler
import com.microsoft.ml.spark.LightGBMClassifier
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}

object lgmClassifier {
  /**
   * Entry point: trains a LightGBM binary classifier on the Pima Indian
   * diabetes dataset, tuning the learning rate via cross-validation, then
   * reports areaUnderROC on a held-out test split.
   */
  def main(args: Array[String]): Unit = {
    // Kryo serialization + disabled YARN vmem check: LightGBM's native-memory
    // allocations can otherwise trip the NodeManager virtual-memory limit.
    val sparkConf = new SparkConf()
      .setAppName("lgbm_app")
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .set("yarn.nodemanager.vmem-check-enabled", "false")
    val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()

    val input_path = "/user/spark/H2O/data/PimaIndian.csv"
    // Read through SparkSession directly; sqlContext is deprecated since Spark 2.0.
    val data = sparkSession.read
      .format("csv")
      .option("sep", ",")
      .option("inferSchema", "true")
      .option("header", "false")
      .load(input_path)

    // The CSV has no header row; assign the dataset's canonical column names.
    val schemas = Seq("Pregnancles", "Glucose", "BloodPressure", "SkinThickness",
      "Insulin", "BMI", "DiabetesPedigreeFuction", "Age", "Outcome")
    val dataset = data.toDF(schemas: _*)

    // Assemble every column except the label into a single feature vector.
    // Exact inequality, not substring containment, to avoid accidentally
    // dropping any future column whose name merely contains "Outcome".
    val vectorAssembler = new VectorAssembler()
      .setInputCols(dataset.columns.filter(_ != "Outcome"))
      .setOutputCol("features")

    // Scale to unit std-dev without centering (keeps sparse vectors sparse).
    val scaler = new StandardScaler()
      .setInputCol("features")
      .setOutputCol("scaledFeatures")
      .setWithStd(true)
      .setWithMean(false)

    val lgbm = new LightGBMClassifier()
      .setLabelCol("Outcome")
      .setFeaturesCol("scaledFeatures")

    val pipeline = new Pipeline().setStages(Array(vectorAssembler, scaler, lgbm))

    val paramGrid = new ParamGridBuilder()
      .addGrid(lgbm.learningRate, Array(0.05, 0.1))
      .build()

    // Evaluate on the classifier's raw score column ("rawPrediction", the
    // evaluator's documented default), NOT the hard 0/1 "prediction" column:
    // areaUnderROC computed from thresholded predictions collapses the
    // ranking information and yields a misleading AUC.
    val evaluator = new BinaryClassificationEvaluator()
      .setLabelCol("Outcome")
      .setRawPredictionCol("rawPrediction")
      .setMetricName("areaUnderROC")

    // Cross-validation over the learning-rate grid; fixed seed for
    // reproducible fold assignment.
    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(3) // explicit; same value as the Spark default
      .setSeed(0)

    // 80/20 train/test split with a fixed seed for reproducibility.
    val Array(training, test) = dataset.randomSplit(Array(0.8, 0.2), 0)

    val lgbmModel = cv.fit(training)
    val results = lgbmModel.transform(test)

    val auc = evaluator.evaluate(results)
    println("----AUC--------")
    println(s"The model's auc: $auc")

    sparkSession.stop()
  }

}

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!