基于spark mllib的gbt算法实例

背景:公司需要使用spark mllib进行预测,基于这个需求,使用spark mllib自带的gbm进行预测。

代码1:

博客
学院
下载
图文课
论坛
APP
问答
商城
VIP会员
活动
招聘
ITeye
GitChat

搜CSDN
写博客赚零钱传资源

关注和收藏在这里
Markdown编辑器
富文本编辑器
查看主页
内容
文章管理
专栏管理
评论管理
个人分类管理
Chat快问 new
博客搬家
设置
博客设置
栏目管理
CSDN博客QQ交流群


扫一扫二维码
或点击这里加入群聊


输入文章标题

文章标签:
添加标签
最多添加5个标签

个人分类:
添加新分类
文章类型:
 *
博客分类:
 *
私密文章:
  


import ml.dmlc.xgboost4j.scala.spark.XGBoost
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql._
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.{Vector, Vectors}

/**
  *  使用spark自带的算法
  *  数据集为my_train.csv和my_test.csv
  *  这是个稳定的版本  只有测试成功再往里面加东西
  *
  */
object myCallXGBoost {
  Logger.getLogger("org").setLevel(Level.WARN)

  def main(args: Array[String]): Unit = {

    //val inputPath = args(0)
    val inputPath = "data"
    print("*******************"+inputPath)
    // create SparkSession
    val spark = SparkSession
      .builder()
      .appName("myCallXGBoost")
      .config("spark.executor.memory", "2G")
      .config("spark.executor.cores", "4")
      .config("hive.metastore.uris","thrift://xxxxxxxxxxxx")
      .config("spark.driver.memory", "1G")
      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .config("spark.default.parallelism", "4")
      .enableHiveSupport()
      //.master("local[*]")
      .getOrCreate()


    //step1 数据准备工作
    //从csv中读取数据
    //val myTrainCsv = spark.read.option("header", "true").option("inferSchema", true).csv(  inputPath+"/my_train.csv")
    //val myTestCsv = spark.read.option("header", "true").option("inferSchema", true).csv(  inputPath+"/my_test.csv")

    val myTrainCsv = spark.sql("select * from dm_analysis.lsm_xgboost_train")
    val myTestCsv = spark.sql("select * from dm_analysis.lsm_xgboost_test")

    myTrainCsv.show(10)

    // 动态数据类型转化 将any类型转化为double
    def toDoubleDynamic(x: Any) = x match {
      case s: String => s.toDouble
      case jn: java.lang.Number => jn.doubleValue()
      case _ => throw new ClassCastException("cannot cast to double")
    }

    import spark.implicits._
    
    //这里查看到数据应该已经全部转换成double类型了
    myTrainCsv.printSchema()

    //这是一个比较完整的版本  需要将features的所有行添加进来
    //直接使用row  每行来给程序赋值
    //需要提前将数据处理成 (label,features...)的格式
    val mydata = myTrainCsv.drop("_c0").map{row =>
      val row_len = row.length
      var myList = new Array[Double](row_len-1)
      for(i<- 1 to (row_len-1)){
        myList(i-1) = toDoubleDynamic(row(i))
      }
      val features = Vectors.dense(myList)
      LabeledPoint(toDoubleDynamic(row(0)), features)
    }

    mydata.show(10)
    val splits = mydata.randomSplit(Array(0.7, 0.3))
    val (trainingData, testData) = (splits(0), splits(1))


    //step2 准备训练模型
    val boostingStrategy = BoostingStrategy.defaultParams("Regression")
    boostingStrategy.numIterations = 3 // Note: Use more iterations in practice.
    boostingStrategy.treeStrategy.maxDepth = 5
    // Empty categoricalFeaturesInfo indicates all features are continuous.
    boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()

    val model = GradientBoostedTrees.train(trainingData.rdd, boostingStrategy)

    // Evaluate model on test instances and compute test error
    val labelsAndPredictions = testData.map { point =>
      val prediction = model.predict(point.features)
      (point.label, prediction)
    }
    val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}
    println(s"Test Mean Squared Error = $testMSE")
    println(s"Learned regression GBT model:\n ${model.toDebugString}")

    spark.stop()
  }
}

 

代码2:

猜你喜欢

转载自blog.csdn.net/abc50319/article/details/83866775
今日推荐