《Spark高级数据分析》——预测森林植被(决策树、随机森林)

《Spark高级数据分析》——预测森林植被(决策树、随机森林)

0. 简介

  • 来源: 《Spark高级数据分析》
  • 原书GitHub地址: https://github.com/sryza/aas
  • 内容简述:利用Spark中的决策树、随机森林算法,预测不同类型的森林植被

1. 数据准备

  • 读取森林植被特征、标签数据 covtype.data
val dataDF = loadData(spark)
dataDF.show()
def loadData(spark: SparkSession): DataFrame = {
 import spark.implicits._

 val dataWithoutHeaderDF = spark.read
   .option("inferSchema", true)
   .option("header", false)
   .csv("E:/Data/saa/Chapter4_covtype/covtype.data")

 // 重新定义字段名
 val colNames = Seq(
   "Elevation", "Aspect", "Slope",
   "Horizontal_Distance_To_Hydrology", "Vertical_Distance_To_Hydrology",
   "Horizontal_Distance_To_Roadways",
   "Hillshade_9am", "Hillshade_Noon", "Hillshade_3pm",
   "Horizontal_Distance_To_Fire_Points") ++
   (0 until 4).map(i => s"Wilderness_Area_$i") ++
   (0 until 40).map(i => s"Soil_Type_$i") ++
   Seq("Cover_Type")

 dataWithoutHeaderDF.toDF(colNames: _*)
   .withColumn("Cover_Type", $"Cover_Type".cast("double"))
}
  • 森林植被数据示例
+---------+------+-----+--------------------------------+------------------------------+-------------------------------+-------------+--------------+-------------+----------------------------------+-----------------+-----------------+-----------------+-----------------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+----------+
|Elevation|Aspect|Slope|Horizontal_Distance_To_Hydrology|Vertical_Distance_To_Hydrology|Horizontal_Distance_To_Roadways|Hillshade_9am|Hillshade_Noon|Hillshade_3pm|Horizontal_Distance_To_Fire_Points|Wilderness_Area_0|Wilderness_Area_1|Wilderness_Area_2|Wilderness_Area_3|Soil_Type_0|Soil_Type_1|Soil_Type_2|Soil_Type_3|Soil_Type_4|Soil_Type_5|Soil_Type_6|Soil_Type_7|Soil_Type_8|Soil_Type_9|Soil_Type_10|Soil_Type_11|Soil_Type_12|Soil_Type_13|Soil_Type_14|Soil_Type_15|Soil_Type_16|Soil_Type_17|Soil_Type_18|Soil_Type_19|Soil_Type_20|Soil_Type_21|Soil_Type_22|Soil_Type_23|Soil_Type_24|Soil_Type_25|Soil_Type_26|Soil_Type_27|Soil_Type_28|Soil_Type_29|Soil_Type_30|Soil_Type_31|Soil_Type_32|Soil_Type_33|Soil_Type_34|Soil_Type_35|Soil_Type_36|Soil_Type_37|Soil_Type_38|Soil_Type_39|Cover_Type|
+---------+------+-----+--------------------------------+------------------------------+-------------------------------+-------------+--------------+-------------+----------------------------------+-----------------+-----------------+-----------------+-----------------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+----------+
|     2596|    51|    3|                             258|                             0|                            510|          221|           232|          148|                              6279|                1|                0|                0|                0|          0|          0|          0|          0|          0|          0|          0|          0|          0|          0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           1|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|       5.0|
|     2590|    56|    2|                             212|                            -6|                            390|          220|           235|          151|                              6225|                1|                0|                0|                0|          0|          0|          0|          0|          0|          0|          0|          0|          0|          0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           1|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|       5.0|
|     2804|   139|    9|                             268|                            65|                           3180|          234|           238|          135|                              6121|                1|                0|                0|                0|          0|          0|          0|          0|          0|          0|          0|          0|          0|          0|           0|           1|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|       2.0|
|     2785|   155|   18|                             242|                           118|                           3090|          238|           238|          122|                              6211|                1|                0|                0|                0|          0|          0|          0|          0|          0|          0|          0|          0|          0|          0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           1|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|       2.0|
|     2595|    45|    2|                             153|                            -1|                            391|          220|           234|          150|                              6172|                1|                0|                0|                0|          0|          0|          0|          0|          0|          0|          0|          0|          0|          0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           1|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|       5.0|
|     2579|   132|    6|                             300|                           -15|                             67|          230|           237|          140|                              6031|                1|                0|                0|                0|          0|          0|          0|          0|          0|          0|          0|          0|          0|          0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           1|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|       2.0|
|     2606|    45|    7|                             270|                             5|                            633|          222|           225|          138|                              6256|                1|                0|                0|                0|          0|          0|          0|          0|          0|          0|          0|          0|          0|          0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           1|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|       5.0|
|     2605|    49|    4|                             234|                             7|                            573|          222|           230|          144|                              6228|                1|                0|                0|                0|          0|          0|          0|          0|          0|          0|          0|          0|          0|          0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           1|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|       5.0|
|     2617|    45|    9|                             240|                            56|                            666|          223|           221|          133|                              6244|                1|                0|                0|                0|          0|          0|          0|          0|          0|          0|          0|          0|          0|          0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           1|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|       5.0|
|     2612|    59|   10|                             247|                            11|                            636|          228|           219|          124|                              6230|                1|                0|                0|                0|          0|          0|          0|          0|          0|          0|          0|          0|          0|          0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           1|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|       5.0|
|     2612|   201|    4|                             180|                            51|                            735|          218|           243|          161|                              6222|                1|                0|                0|                0|          0|          0|          0|          0|          0|          0|          0|          0|          0|          0|           0|           0|           0|           0|           0|           0|           0|           1|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|       5.0|
|     2886|   151|   11|                             371|                            26|                           5253|          234|           240|          136|                              4051|                1|                0|                0|                0|          0|          0|          0|          0|          0|          0|          0|          0|          0|          0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           1|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|       2.0|
|     2742|   134|   22|                             150|                            69|                           3215|          248|           224|           92|                              6091|                1|                0|                0|                0|          0|          0|          0|          0|          0|          0|          0|          0|          0|          0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           1|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|       2.0|
|     2609|   214|    7|                             150|                            46|                            771|          213|           247|          170|                              6211|                1|                0|                0|                0|          0|          0|          0|          0|          0|          0|          0|          0|          0|          0|           0|           0|           0|           0|           0|           0|           0|           1|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|       5.0|
|     2503|   157|    4|                              67|                             4|                            674|          224|           240|          151|                              5600|                1|                0|                0|                0|          0|          0|          0|          0|          0|          0|          0|          0|          0|          0|           0|           0|           0|           0|           0|           0|           0|           1|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|       5.0|
|     2495|    51|    7|                              42|                             2|                            752|          224|           225|          137|                              5576|                1|                0|                0|                0|          0|          0|          0|          0|          0|          0|          0|          0|          0|          0|           0|           0|           0|           0|           0|           1|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|       5.0|
|     2610|   259|    1|                             120|                            -1|                            607|          216|           239|          161|                              6096|                1|                0|                0|                0|          0|          0|          0|          0|          0|          0|          0|          0|          0|          0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           1|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|       5.0|
|     2517|    72|    7|                              85|                             6|                            595|          228|           227|          133|                              5607|                1|                0|                0|                0|          0|          0|          0|          0|          0|          0|          0|          0|          0|          0|           0|           0|           0|           0|           0|           0|           0|           1|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|       5.0|
|     2504|     0|    4|                              95|                             5|                            691|          214|           232|          156|                              5572|                1|                0|                0|                0|          0|          0|          0|          0|          0|          0|          0|          0|          0|          0|           0|           0|           0|           0|           0|           0|           0|           1|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|       5.0|
|     2503|    38|    5|                              85|                            10|                            741|          220|           228|          144|                              5555|                1|                0|                0|                0|          0|          0|          0|          0|          0|          0|          0|          0|          0|          0|           0|           0|           0|           0|           0|           0|           0|           1|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|           0|       5.0|
+---------+------+-----+--------------------------------+------------------------------+-------------------------------+-------------+--------------+-------------+----------------------------------+-----------------+-----------------+-----------------+-----------------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+----------+

  • 拆分训练集、测试集
val Array(trainDF, testDF) = dataDF.randomSplit(Array(0.75, 0.25))
trainDF.persist()
testDF.persist()
  • 预处理训练集
val inputCols = trainDF.columns.filter(_ != "Cover_Type")
val assembler = new VectorAssembler()
  .setInputCols(inputCols)
  .setOutputCol("featureVector")

val assemblerTrainDF = assembler.transform(trainDF).persist()
assemblerTrainDF.select("featureVector").show(false)
  • 预处理测试集
val assemblerTestDF = new VectorAssembler()
   .setInputCols(inputCols)
   .setOutputCol("featureVector")
   .transform(testDF)
  • 训练集处理结果示例
+-----------------------------------------------------------------------------------------------------+
|featureVector                                                                                        |
+-----------------------------------------------------------------------------------------------------+
|(54,[0,1,2,3,4,5,6,7,8,9,13,15],[1863.0,37.0,17.0,120.0,18.0,90.0,217.0,202.0,115.0,769.0,1.0,1.0])  |
|(54,[0,1,2,5,6,7,8,9,13,18],[1874.0,18.0,14.0,90.0,208.0,209.0,135.0,793.0,1.0,1.0])                 |
|(54,[0,1,2,3,4,5,6,7,8,9,13,18],[1879.0,28.0,19.0,30.0,12.0,95.0,209.0,196.0,117.0,778.0,1.0,1.0])   |
|(54,[0,1,2,3,4,5,6,7,8,9,13,15],[1888.0,33.0,22.0,150.0,46.0,108.0,209.0,185.0,103.0,735.0,1.0,1.0]) |
|(54,[0,1,2,3,4,5,6,7,8,9,13,14],[1889.0,28.0,22.0,150.0,23.0,120.0,205.0,185.0,108.0,759.0,1.0,1.0]) |
|(54,[0,1,2,3,4,5,6,7,8,9,13,18],[1889.0,353.0,30.0,95.0,39.0,67.0,153.0,172.0,146.0,600.0,1.0,1.0])  |
|(54,[0,1,2,3,4,5,6,7,8,9,13,18],[1896.0,337.0,12.0,30.0,6.0,175.0,195.0,224.0,168.0,732.0,1.0,1.0])  |
|(54,[0,1,2,3,4,5,6,7,8,9,13,15],[1898.0,34.0,23.0,175.0,56.0,134.0,210.0,184.0,99.0,765.0,1.0,1.0])  |
|(54,[0,1,2,3,4,5,6,7,8,9,13,18],[1899.0,355.0,22.0,153.0,43.0,124.0,178.0,195.0,151.0,819.0,1.0,1.0])|
|(54,[0,1,2,3,4,5,6,7,8,9,13,14],[1901.0,311.0,9.0,30.0,2.0,190.0,195.0,234.0,179.0,726.0,1.0,1.0])   |
|(54,[0,1,2,3,4,5,6,7,8,9,13,16],[1903.0,67.0,16.0,108.0,36.0,120.0,234.0,207.0,100.0,969.0,1.0,1.0]) |
|(54,[0,1,2,3,4,5,6,7,8,9,13,18],[1905.0,19.0,27.0,134.0,58.0,120.0,188.0,171.0,108.0,636.0,1.0,1.0]) |
|(54,[0,1,2,3,4,5,6,7,8,9,13,14],[1905.0,33.0,27.0,90.0,46.0,150.0,204.0,171.0,89.0,725.0,1.0,1.0])   |
|(54,[0,1,2,3,4,5,6,7,8,9,13,16],[1905.0,77.0,21.0,90.0,38.0,120.0,241.0,196.0,75.0,1025.0,1.0,1.0])  |
|(54,[0,1,2,3,4,5,6,7,8,9,13,15],[1906.0,356.0,20.0,150.0,55.0,120.0,184.0,201.0,151.0,726.0,1.0,1.0])|
|(54,[0,1,2,3,4,5,6,7,8,9,13,18],[1908.0,323.0,32.0,150.0,52.0,120.0,125.0,190.0,196.0,765.0,1.0,1.0])|
|(54,[0,1,2,3,4,5,6,7,8,9,13,15],[1916.0,24.0,25.0,212.0,74.0,175.0,197.0,177.0,105.0,789.0,1.0,1.0]) |
|(54,[0,1,2,3,4,5,6,7,8,9,13,18],[1916.0,320.0,24.0,190.0,60.0,162.0,151.0,210.0,195.0,832.0,1.0,1.0])|
|(54,[0,1,2,3,4,5,6,7,8,9,13,23],[1918.0,321.0,28.0,42.0,17.0,85.0,139.0,201.0,196.0,402.0,1.0,1.0])  |
|(54,[0,1,2,3,4,5,6,7,8,9,13,14],[1919.0,30.0,22.0,67.0,9.0,256.0,208.0,188.0,107.0,661.0,1.0,1.0])   |
+-----------------------------------------------------------------------------------------------------+

2. 训练决策树模型

  • 构建DecisionTreeClassifier模型,开始训练
// 构建模型
val classifier = new DecisionTreeClassifier()
    .setSeed(Random.nextLong())
    .setLabelCol("Cover_Type")
    .setFeaturesCol("featureVector")
    .setPredictionCol("prediction")

// 训练模型
val model = classifier.fit(assemblerTrainDF)
  • 打印决策模型、不同特征的信息增益
// 决策模型
 println(model.toDebugString)
// 不同特征的信息增益,降序
 model.featureImportances
   .toArray
   .zip(inputCols)
   .sorted.reverse
   .foreach(println)
  • 打印示例
DecisionTreeClassificationModel (uid=dtc_b7ddf2a70cb5) of depth 5 with 63 nodes
  If (feature 0 <= 3052.0)
   If (feature 0 <= 2558.0)
    If (feature 10 <= 0.0)
     If (feature 0 <= 2440.0)
      If (feature 3 <= 0.0)
       Predict: 4.0
      Else (feature 3 > 0.0)
       Predict: 3.0
     Else (feature 0 > 2440.0)
      If (feature 17 <= 0.0)
       Predict: 3.0
      Else (feature 17 > 0.0)
       Predict: 3.0
    ……

(0.7792536945752957,Elevation)
(0.03867758671456936,Horizontal_Distance_To_Hydrology)
(0.032035474824597066,Wilderness_Area_0)
(0.030258022977407074,Soil_Type_3)
(0.030002164397023114,Hillshade_Noon)
(0.027754291761557144,Soil_Type_31)
(0.023639745113770847,Soil_Type_1)
(0.010979405745834852,Wilderness_Area_2)
(0.010136754139592311,Soil_Type_28)
(0.006542158483011739,Soil_Type_22)
……
(0.0,Soil_Type_0)
(0.0,Slope)
(0.0,Hillshade_9am)
(0.0,Aspect)

3. 预测森林植被

  • 预测
val predictionDF = model.transform(assemblerTestDF)
predictionDF.persist()
predictionDF.select("Cover_Type", "prediction", "probability")
  .show(false)
  • 预测结果示例
+----------+----------+-------------------------------------------------------------------------------------------------+
|Cover_Type|prediction|probability                                                                                      |
+----------+----------+-------------------------------------------------------------------------------------------------+
|6.0       |3.0       |[0.0,0.0,0.03027963630125236,0.6354005832904444,0.052024360953851434,0.0,0.28229541945445186,0.0]|
|6.0       |3.0       |[0.0,0.0,0.03027963630125236,0.6354005832904444,0.052024360953851434,0.0,0.28229541945445186,0.0]|
|6.0       |3.0       |[0.0,0.0,0.03027963630125236,0.6354005832904444,0.052024360953851434,0.0,0.28229541945445186,0.0]|
|6.0       |3.0       |[0.0,0.0,0.03027963630125236,0.6354005832904444,0.052024360953851434,0.0,0.28229541945445186,0.0]|
……
+----------+----------+-------------------------------------------------------------------------------------------------+
  • 数据评分 accuracy + f1
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("Cover_Type")
  .setPredictionCol("prediction")

val accuracy = evaluator.setMetricName("accuracy").evaluate(predictionDF)
val f1 = evaluator.setMetricName("f1").evaluate(predictionDF)
println(s"accuracy = $accuracy, f1 = $f1")
  • 评分结果
accuracy = 0.6986190873428979, f1 = 0.6820440997673965

4. 利用网格搜索与交叉验证API

  • 构建管道模型
val pipeline = new Pipeline().setStages(Array(assembler, classifier))
  • 构建网格参数
val paramGrid = new ParamGridBuilder()
  .addGrid(classifier.impurity, Seq("gini", "entropy"))
  .addGrid(classifier.maxDepth, Seq(1, 20))
  .addGrid(classifier.maxBins, Seq(40, 300))
  .addGrid(classifier.minInfoGain, Seq(0.0, 0.05))
  .build()
  • 构建分类模型的评估器
val multiclassEvaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("Cover_Type")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")
  • 开始网格搜索+交叉验证
// 构建模型
val validator = new TrainValidationSplit()
  .setSeed(Random.nextLong())
  .setEstimator(pipeline)
  .setEvaluator(multiclassEvaluator)
  .setEstimatorParamMaps(paramGrid)
  .setTrainRatio(0.8)

// 训练模型
val validatorModel = validator.fit(trainDF)
  • 查看训练结果的所有参数组合
validatorModel.validationMetrics
  .zip(validatorModel.getEstimatorParamMaps)
  .sortBy(-_._1)
  .foreach { case (metric, params) =>
    println("-----------------------------------------")
    println(metric)
    println(params)
  }
  • 获取训练结果的最佳模型,最佳参数
val bestModel = validatorModel.bestModel
println(bestModel.asInstanceOf[PipelineModel].stages.last.extractParamMap())
  • 最佳参数组合示例(准确率90%)
{
	dtc_b7ddf2a70cb5-cacheNodeIds: false,
	dtc_b7ddf2a70cb5-checkpointInterval: 10,
	dtc_b7ddf2a70cb5-featuresCol: featureVector,
	dtc_b7ddf2a70cb5-impurity: entropy,
	dtc_b7ddf2a70cb5-labelCol: Cover_Type,
	dtc_b7ddf2a70cb5-maxBins: 40,
	dtc_b7ddf2a70cb5-maxDepth: 20,
	dtc_b7ddf2a70cb5-maxMemoryInMB: 256,
	dtc_b7ddf2a70cb5-minInfoGain: 0.0,
	dtc_b7ddf2a70cb5-minInstancesPerNode: 1,
	dtc_b7ddf2a70cb5-predictionCol: prediction,
	dtc_b7ddf2a70cb5-probabilityCol: probability,
	dtc_b7ddf2a70cb5-rawPredictionCol: rawPrediction,
	dtc_b7ddf2a70cb5-seed: -6398219726571299260
}

5. 随机森林模型

  • 使用随机森林模型替换前面的决策树,提高准确率
val classifier = new RandomForestClassifier()
  .setSeed(Random.nextLong())
  .setLabelCol("Cover_Type")
  .setFeaturesCol("featureVector")
  .setPredictionCol("prediction")
  .setNumTrees(100)

6. 完整代码

import org.apache.spark.SparkConf
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.sql.{DataFrame, SparkSession}

import scala.util.Random

/**
  * 第四章 - 决策树 - 预测森林植被
  *
  * @author ALion
  */
object RunRDF {

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("Demo").setMaster("local[4]")
    val spark = SparkSession.builder()
      .config(conf)
      .enableHiveSupport()
      .getOrCreate()

    org.apache.log4j.Logger.getRootLogger.setLevel(
      org.apache.log4j.Level.toLevel("WARN")
    )

    import spark.implicits._

    // 1.准备数据
    val dataDF = loadData(spark)
    dataDF.show()

    // 2. 拆分数据集
    val Array(trainDF, testDF) = dataDF.randomSplit(Array(0.75, 0.25))
    trainDF.persist()
    testDF.persist()

    // 3. 预处理
    val inputCols = trainDF.columns.filter(_ != "Cover_Type")
    val assembler = new VectorAssembler()
      .setInputCols(inputCols)
      .setOutputCol("featureVector")
   // 训练集
    val assemblerTrainDF = assembler.transform(trainDF).persist()
    assemblerTrainDF.select("featureVector").show(false)
   // 测试集
    val assemblerTestDF = new VectorAssembler()
      .setInputCols(inputCols)
      .setOutputCol("featureVector")
      .transform(testDF)

    // 4. 构建决策树模型
    // val classifier = new DecisionTreeClassifier()
   //    .setSeed(Random.nextLong())
   //    .setLabelCol("Cover_Type")
   //    .setFeaturesCol("featureVector")
   //    .setPredictionCol("prediction")
      
    // 使用随机森林模型替换前面的决策树,提高准确率
        val classifier = new RandomForestClassifier()
           .setSeed(Random.nextLong())
           .setLabelCol("Cover_Type")
           .setFeaturesCol("featureVector")
           .setPredictionCol("prediction")
           .setNumTrees(100)

    // 训练数据
    val model = classifier.fit(assemblerTrainDF)

    println(model.toDebugString) // 打印决策模型

   // 打印不同特征的信息增益
    model.featureImportances
      .toArray
      .zip(inputCols)
      .sorted.reverse
      .foreach(println)

    // 5. 预测植被
    val predictionDF = model.transform(assemblerTestDF)
    predictionDF.persist()
    predictionDF.select("Cover_Type", "prediction", "probability")
      .show(false)

    // 评分
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("Cover_Type")
      .setPredictionCol("prediction")

    val accuracy = evaluator.setMetricName("accuracy").evaluate(predictionDF)
    val f1 = evaluator.setMetricName("f1").evaluate(predictionDF)
    println(s"accuracy = $accuracy, f1 = $f1")

    // 计算混淆矩阵
    // 方法1
    val predictionRDD = predictionDF
      .select("prediction", "Cover_Type")
      .as[(Double, Double)]
      .rdd

    val multiclassMetrics = new MulticlassMetrics(predictionRDD)
    println(multiclassMetrics.confusionMatrix)

    // 方法2
    val confusionMatrix = predictionDF
      .groupBy("Cover_Type")
      .pivot("prediction", 1 to 7)
      .count()
      .na.fill(0.0)
      .orderBy("Cover_Type")
    confusionMatrix.show()


    // 6. 网格搜索+交叉验证
    // 构建管道模型
    val pipeline = new Pipeline().setStages(Array(assembler, classifier))
    // 构建网格参数
    val paramGrid = new ParamGridBuilder()
      .addGrid(classifier.impurity, Seq("gini", "entropy"))
      .addGrid(classifier.maxDepth, Seq(1, 20))
      .addGrid(classifier.maxBins, Seq(40, 300))
      .addGrid(classifier.minInfoGain, Seq(0.0, 0.05))
      .build()
    // 构建分类模型的评估器
    val multiclassEvaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("Cover_Type")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    
    // 开始网格搜索+交叉验证
    val validator = new TrainValidationSplit()
      .setSeed(Random.nextLong())
      .setEstimator(pipeline)
      .setEvaluator(multiclassEvaluator)
      .setEstimatorParamMaps(paramGrid)
      .setTrainRatio(0.8)

    val validatorModel = validator.fit(trainDF)

    // 获取训练结果的最佳模型,最佳参数
    val bestModel = validatorModel.bestModel
    println(bestModel.asInstanceOf[PipelineModel].stages.last.extractParamMap())

    // 查看所有参数组合
    validatorModel.validationMetrics
      .zip(validatorModel.getEstimatorParamMaps)
      .sortBy(-_._1)
      .foreach { case (metric, params) =>
        println("-----------------------------------------")
        println(metric)
        println(params)
      }


    spark.stop()
  }

  /**
    * 加载原始数据
    * @param spark SparkSession
    * @return DataFrame
    */
  def loadData(spark: SparkSession): DataFrame = {
    import spark.implicits._

    val dataWithoutHeaderDF = spark.read
      .option("inferSchema", true)
      .option("header", false)
      .csv("E:/Data/saa/Chapter4_covtype/covtype.data")

    // 重新定义字段名
    val colNames = Seq(
      "Elevation", "Aspect", "Slope",
      "Horizontal_Distance_To_Hydrology", "Vertical_Distance_To_Hydrology",
      "Horizontal_Distance_To_Roadways",
      "Hillshade_9am", "Hillshade_Noon", "Hillshade_3pm",
      "Horizontal_Distance_To_Fire_Points") ++
      (0 until 4).map(i => s"Wilderness_Area_$i") ++
      (0 until 40).map(i => s"Soil_Type_$i") ++
      Seq("Cover_Type")

    dataWithoutHeaderDF.toDF(colNames: _*)
      .withColumn("Cover_Type", $"Cover_Type".cast("double"))
  }

}

发布了128 篇原创文章 · 获赞 45 · 访问量 15万+

猜你喜欢

转载自blog.csdn.net/alionsss/article/details/91152789
今日推荐