SparkML -- LightGBM On Spark LightGBMRanker特征重要性示例

MAVEN

<dependency>
     <groupId>com.microsoft.ml.spark</groupId>
     <artifactId>mmlspark_2.11</artifactId>
     <version>0.18.0</version>
 </dependency>
 <dependency>
     <groupId>com.microsoft.ml.lightgbm</groupId>
     <artifactId>lightgbmlib</artifactId>
     <version>2.2.350</version>
 </dependency>

测试数据

http://archive.ics.uci.edu/ml/machine-learning-databases/00275/Bike-Sharing-Dataset.zip

hour.csv和day.csv都有如下属性,除了day.csv文件中没有hr属性以外

  • instant: 记录ID
  • dteday : 时间日期
  • season : 季节 (1:春季, 2:夏季, 3:秋季, 4:冬季)
  • yr : 年份 (0: 2011, 1:2012)
  • mnth : 月份 ( 1 to 12)
  • hr : 当天时刻 (0 to 23)
  • holiday : 当天是否是节假日(extracted from http://dchr.dc.gov/page/holiday-schedule)
  • weekday : 周几
  • workingday : 工作日 is 1, 其他 is 0.
  • weathersit : 天气
  • 1: Clear, Few clouds, Partly cloudy, Partly cloudy
  • 2: Mist + Cloudy, Mist + Broken clouds, Mist + Few clouds, Mist
  • 3: Light Snow, Light Rain + Thunderstorm + Scattered clouds, Light Rain + Scattered clouds
  • 4: Heavy Rain + Ice Pallets + Thunderstorm + Mist, Snow + Fog
  • temp : 气温 Normalized temperature in Celsius. The values are divided to 41 (max)
  • atemp: 体感温度 Normalized feeling temperature in Celsius. The values are divided to 50 (max)
  • hum: 湿度 Normalized humidity. The values are divided to 100 (max)
  • windspeed: 风速Normalized wind speed. The values are divided to 67 (max)
  • casual: 临时用户数count of casual users
  • registered: 注册用户数count of registered users
  • cnt: 目标变量,每小时的自行车的租用量,包括临时用户和注册用户count of total rental bikes including both casual and registered

代码示例

package com.bigblue.lightgbm

import com.microsoft.ml.spark.lightgbm.{LightGBMRanker, LightGBMRankerModel}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.types.{DoubleType, IntegerType}
import org.apache.spark.sql.{DataFrame, SparkSession}

/**
 * Created By TheBigBlue on 2020/3/4
 * Description :
 */
object LightGBMRankerTest {

  /**
   * Entry point: trains a LightGBMRanker on the UCI bike-sharing hour.csv data,
   * extracts split-based feature importances from the fitted model, and keeps
   * only the top 60% most important feature columns.
   */
  def main(args: Array[String]): Unit = {

    val spark: SparkSession = SparkSession.builder().appName("test-lightgbm").master("local[2]").getOrCreate()
    spark.sparkContext.setLogLevel("WARN")
    val originalData: DataFrame = spark.read.option("header", "true") // first row is the header/schema
      .option("inferSchema", "true") // infer column types from the data
      //      .csv("/home/hdfs/hour.csv")
      .csv("file:///D:/Cache/ProgramCache/TestData/dataSource/lightgbm/hour.csv")

    val labelCol = "workingday"
    // categorical feature columns
    val cateCols = Array[String]("season", "yr", "mnth", "hr")
    // continuous feature columns
    val conCols: Array[String] = Array("temp", "atemp", "hum", "casual", "cnt")
    // all feature columns
    val vecCols = conCols ++ cateCols

    import spark.implicits._
    // Cast every feature column to Double and the label to Int,
    // folding over the columns instead of mutating a var.
    val typedDF = vecCols
      .foldLeft(originalData.select(labelCol, vecCols: _*)) { (df, col) =>
        df.withColumn(col, $"$col".cast(DoubleType))
      }
      .withColumn(labelCol, $"$labelCol".cast(IntegerType))

    // Append an index column used as the group column; LightGBMRanker fails without a groupCol.
    // NOTE(review): a monotonically increasing id is unique per row, so every row becomes its
    // own ranking group — unusual for a lambdarank objective; confirm this grouping is intended.
    import org.apache.spark.sql.functions._
    val inputDF = typedDF.withColumn("index", monotonically_increasing_id)
    inputDF.show()

    val assembler = new VectorAssembler().setInputCols(vecCols).setOutputCol("features")

    // groupCol is mandatory for LightGBMRanker
    val classifier: LightGBMRanker = new LightGBMRanker().setNumIterations(100).setNumLeaves(31)
      .setBoostFromAverage(false).setFeatureFraction(1.0).setMaxDepth(-1).setMaxBin(255)
      .setLearningRate(0.1).setMinSumHessianInLeaf(0.001).setLambdaL1(0.0).setLambdaL2(0.0)
      .setBaggingFraction(1.0).setBaggingFreq(0).setBaggingSeed(1).setObjective("lambdarank")
      .setLabelCol(labelCol).setCategoricalSlotNames(cateCols).setFeaturesCol("features")
      .setGroupCol("index").setBoostingType("gbdt")

    val pipelineModel = new Pipeline().setStages(Array(assembler, classifier)).fit(inputDF)
    val rankerModel = pipelineModel.stages(1).asInstanceOf[LightGBMRankerModel]
    val importanceValues = rankerModel.getFeatureImportances("split")
    // Sort by importance descending and keep the top 60% of the features.
    val filteredTuples = vecCols.zip(importanceValues).sortWith(_._2 > _._2)
      .take((0.6 * vecCols.length).toInt)
    // Build the importance rows with 1-based ranks (zipWithIndex instead of a mutable counter).
    val importanceRows: Array[LightGBMRankerTest] = filteredTuples.zipWithIndex.map {
      case ((featureName, importance), i) => LightGBMRankerTest(i + 1, featureName, importance)
    }
    val importanceDF = spark.createDataFrame(importanceRows)
    importanceDF.show()

    // Feature data restricted to the selected (most important) columns.
    val filteredCols: Array[String] = filteredTuples.map(_._1)
    val finalDF = inputDF.select(labelCol, filteredCols: _*)
    finalDF.show()

    // Release the local Spark context before exiting.
    spark.stop()
  }
}
/** One row of the feature-importance result: 1-based rank, feature name, and its importance value. */
case class LightGBMRankerTest(id: Long, feature_name: String, value: Double)

结果

（此处为运行结果截图，略）

发布了73 篇原创文章 · 获赞 18 · 访问量 2万+

猜你喜欢

转载自blog.csdn.net/Aeve_imp/article/details/105049048