Spark决策树算法预测实现(Scala语言)

1. 准备数据

使用网上的Covtype 数据集,包含csv格式压缩数据文件 covtype.data.gz;

解压缩:

gzip covtype.data.gz -d

数据集记录科罗拉多州不同地块的森林植被类型。

Spark MLlib将特征抽象为 LabeledPoint; 由包含多个特征值的Spark MLlib Vector 和一个 label 的目标值组成。目标为double 类型; Vector本身也是多个 Double 类型值的抽象。

LabeledPoint 只适用于数值型特征;类别型特征需要进行 one-hot, label-encoder 等编码。

Covtype 数据集的类别特征已经过处理成数值型了。

2. 决策树模型

可以直接读取数据,或者把数据复制到HDFS(/user/ds 目录下).

把数据集分成训练集,验证集,测试集。

import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.regression._

val rawData = sc.textFile("/data/covtype.data")

val data = rawData.map{line =>
    val values = line.split(',').map(_.toDouble)
    val featureVector = Vectors.dense(values.init)
    val label = values.last - 1
    LabeledPoint(label,featureVector)
}

val Array(trainData,cvData,testData) = 
    data.randomSplit(Array(0.8,0.1,0.1))
trainData.cache()
cvData.cache()
testData.cache()

决策树模型:

import org.apache.spark.mllib.evaluation._
import org.apache.spark.mllib.tree._
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd._

def getMetrics(model: DecisionTreeModel, data: RDD[LabeledPoint]):
    MulticlassMetrics = {
        val predictionsAndLables = data.map(example =>
          (model.predict(example.features), example.label)
      )
        new MulticlassMetrics(predictionsAndLables)
    }
val model = DecisionTree.trainClassifier(
trainData, 7, Map[Int,Int](), "gini", 4, 100)
val metrics = getMetrics(model, cvData)

采用gini参数,最大深度4,最大桶数100。

MulticlassMetrics 以不同方式计算分类器预测质量的标准指标。

获取运算结果的混淆矩阵。

metrics.confusionMatrix

运行结果:

res1: org.apache.spark.mllib.linalg.Matrix =
13955.0  6703.0   9.0     1.0    0.0   0.0  298.0
5513.0   22254.0  464.0   24.0   3.0   0.0  41.0
0.0      409.0    3041.0  97.0   0.0   0.0  0.0
0.0      1.0      173.0   125.0  0.0   0.0  0.0
0.0      934.0    35.0    0.0    12.0  0.0  0.0
0.0      438.0    1267.0  99.0   0.0   0.0  0.0
1125.0   30.0     0.0     0.0    0.0   0.0  858.0
metrics.precision

结果:

res2: Double = 0.6949696938299746

参考:

  1. Spark 高级数据分析 Ryza 龚少成译
发布了521 篇原创文章 · 获赞 152 · 访问量 77万+

猜你喜欢

转载自blog.csdn.net/rosefun96/article/details/105511928