ML之PySpark:基于PySpark框架针对adult人口普查收入数据集结合Pipeline利用LoR/DT/RF算法(网格搜索+交叉验证评估+特征重要性)实现二分类预测(年收入是否超50k)案例

ML之PySpark:基于PySpark框架针对adult人口普查收入数据集结合Pipeline利用LoR/DT/RF算法(网格搜索+交叉验证评估+特征重要性)实现二分类预测(年收入是否超50k)案例应用

目录

基于PySpark框架针对adult人口普查收入数据集结合Pipeline利用LoR/DT/RF算法(网格搜索+交叉验证评估+特征重要性)实现二分类预测(年收入是否超50k)案例应用

# 1、定义数据集

# 1.1、创建SparkSession连接

# 1.2、读取数据集

# 1.3、划分特征类型

# 1.4、特征类型转换

# 2、数据预处理/特征工程

# 2.1、缺失值统计并填充

# 2.2、定义特征编码

# 2.2.1、对【类别型】特征编码

# 2.2.2、对目标变量编码

# 2.3、定义特征向量化

# 2.4、使用pipeline完成数据处理

# 2.5、查看数据预处理后的数据集

# 3、模型训练与评估

# 3.1、数据集切分

# 3.2、模型训练

# 3.3、模型预测与评估

# 3.4、模型调参并评估:网格搜索+交叉验证

# 查看最优模型参数

LoR:0.90406

DTC:0.7867

RFC:0.9107


相关文章
ML之PySpark:基于PySpark框架针对adult人口普查收入数据集结合Pipeline利用LoR/DT/RF算法(网格搜索+交叉验证评估+特征重要性)实现二分类预测(年收入是否超50k)案例应用
ML之PySpark:基于PySpark框架针对adult人口普查收入数据集结合Pipeline利用LoR/DT/RF算法(网格搜索+交叉验证评估+特征重要性)实现二分类预测(年收入是否超50k)案例应用实现代码

基于PySpark框架针对adult人口普查收入数据集结合Pipeline利用LoR/DT/RF算法(网格搜索+交叉验证评估+特征重要性)实现二分类预测(年收入是否超50k)案例应用

# 1、定义数据集

# 1.1、创建SparkSession连接

# 1.2、读取数据集

<class 'pyspark.sql.dataframe.DataFrame'>
+---+-----------------+------+----------+-------------+-------------------+------------------+--------------+------+-----+------------+------------+--------------+--------------+---
---+
|age|        workclass|fnlwgt| education|education-num|     marital-status|        occupation|  relationship|  race|  sex|capital-gain|capital-loss|hours-per-week|native-country|inc
ome|
+---+-----------------+------+----------+-------------+-------------------+------------------+--------------+------+-----+------------+------------+--------------+--------------+---
---+
| 39|        State-gov| 77516| Bachelors|           13|      Never-married|      Adm-clerical| Not-in-family| White| Male|        2174|           0|            40| United-States| <=
50K|
| 50| Self-emp-not-inc| 83311| Bachelors|           13| Married-civ-spouse|   Exec-managerial|       Husband| White| Male|           0|           0|            13| United-States| <=
50K|
| 38|          Private|215646|   HS-grad|            9|           Divorced| Handlers-cleaners| Not-in-family| White| Male|           0|           0|            40| United-States| <=
50K|
+---+-----------------+------+----------+-------------+-------------------+------------------+--------------+------+-----+------------+------------+--------------+--------------+---
---+
only showing top 3 rows

root
 |-- age: integer (nullable = true)
 |-- workclass: string (nullable = true)
 |-- fnlwgt: integer (nullable = true)
 |-- education: string (nullable = true)
 |-- education-num: integer (nullable = true)
 |-- marital-status: string (nullable = true)
 |-- occupation: string (nullable = true)
 |-- relationship: string (nullable = true)
 |-- race: string (nullable = true)
 |-- sex: string (nullable = true)
 |-- capital-gain: integer (nullable = true)
 |-- capital-loss: integer (nullable = true)
 |-- hours-per-week: integer (nullable = true)
 |-- native-country: string (nullable = true)
 |-- income: string (nullable = true)

15 ['age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week', 'nat
ive-country', 'income']

# 1.3、划分特征类型

cat_features 9 ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country', 'income']
num_features 6 ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', 'hours-per-week']
num_int_features 6 ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', 'hours-per-week']
num_double_features 0 []
<class 'pyspark.sql.dataframe.DataFrame'>

# 1.4、特征类型转换

# 2、数据预处理/特征工程

# 2.1、缺失值统计并填充

     age  workclass  fnlwgt  education  education-num  marital-status  occupation  relationship   race    sex  capital-gain  capital-loss  hours-per-week  native-country  income
0  32561      32561   32561      32561          32561           32561       32561         32561  32561  32561         32561         32561           32561           32561   32561    
                    0
age             32561
workclass       32561
fnlwgt          32561
education       32561
education-num   32561
marital-status  32561
occupation      32561
relationship    32561
hours-per-week  32561
native-country  32561
income          32561
<class 'pyspark.sql.dataframe.DataFrame'>

# 2.2、定义特征编码

# 2.2.1、对【类别型】特征编码

# 2.2.2、对目标变量编码

# 2.3、定义特征向量化

# 2.4、使用pipeline完成数据处理

# 2.5、查看数据预处理后的数据集

    label                                           features  age          workclass  fnlwgt  ... capital-gain  capital-loss hours-per-week  native-country  income
0     0.0  (0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ...   39          State-gov   77516  ...         2174             0             40   United-States   <=50K
1     0.0  (0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   50   Self-emp-not-inc   83311  ...            0             0             13   United-States   <=50K
2     0.0  (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, ...   38            Private  215646  ...            0             0             40   United-States   <=50K
3     0.0  (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   53            Private  234721  ...            0             0             40   United-States   <=50K
4     0.0  (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   28            Private  338409  ...            0             0             40            Cuba   <=50K
5     0.0  (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   37            Private  284582  ...            0             0             40   United-States   <=50K
6     0.0  (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   49            Private  160187  ...            0             0             16         Jamaica   <=50K
7     1.0  (0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, ...   52   Self-emp-not-inc  209642  ...            0             0             45   United-States    >50K
8     1.0  (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   31            Private   45781  ...        14084             0             50   United-States    >50K
9     1.0  (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   42            Private  159449  ...         5178             0             40   United-States    >50K
10    1.0  (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   37            Private  280464  ...            0             0             80   United-States    >50K
11    1.0  (0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ...   30          State-gov  141297  ...            0             0             40           India    >50K
12    0.0  (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   23            Private  122272  ...            0             0             30   United-States   <=50K
13    0.0  (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   32            Private  205019  ...            0             0             50   United-States   <=50K
14    1.0  (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   40            Private  121772  ...            0             0             40               ?    >50K
15    0.0  (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   34            Private  245487  ...            0             0             45          Mexico   <=50K
16    0.0  (0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, ...   25   Self-emp-not-inc  176756  ...            0             0             35   United-States   <=50K

[17 rows x 17 columns]

# 3、模型训练与评估

# 3.1、数据集切分

train_sparksql_df:  24362                                                                           
val_sparksql_df:  8199

# 3.2、模型训练

# 3.3、模型预测与评估

# 3.4、模型调参并评估:网格搜索+交叉验证

# 查看最优模型参数

LoR:0.90406

----------------LoR------------------
23/04/12 14:38:41 WARN ProcfsMetricsGetter: Exception when trying to compute pagesize, as a result reporting of ProcessTree metrics is stopped
AUC: 0.9031383924571014                                                        
+-----+------+----------+--------------------+---+
|label|income|prediction|         probability|age|
+-----+------+----------+--------------------+---+
|  0.0| <=50K|       1.0|[0.43352058980891...| 32|
|  0.0| <=50K|       0.0|[0.75107062535890...| 29|
|  0.0| <=50K|       0.0|[0.63638300478043...| 33|
|  0.0| <=50K|       0.0|[0.51958003762662...| 37|
|  0.0| <=50K|       1.0|[0.34111125408907...| 41|
+-----+------+----------+--------------------+---+
only showing top 20 rows

23/04/12 14:39:24 WARN BlockManager: Asked to remove block broadcast_2148, which does not exist
modelCV2fit.bestModel LogisticRegressionModel: uid=LogisticRegression_a1b1ca1092dc, numClasses=2, numFeatures=100
AUC: 0.9040755818589695

DTC:0.7867

-----------------DTC------------------
DecisionTreeClassificationModel: uid=DecisionTreeClassifier_334e98949900, depth=5, numNodes=29, numClasses=2, numFeatures=100
  If (feature 23 in {0.0})
   If (feature 97 <= 7559.0)
    Predict: 0.0
   Else (feature 97 > 7559.0)
    If (feature 94 <= 21.5)
     If (feature 44 in {0.0})
      Predict: 0.0
     Else (feature 44 not in {0.0})
      Predict: 1.0
    Else (feature 94 > 21.5)
     Predict: 1.0
  Else (feature 23 not in {0.0})
   If (feature 96 <= 12.5)
    If (feature 97 <= 7559.0)
     Predict: 0.0
    Else (feature 97 > 7559.0)
     If (feature 94 <= 58.5)
      Predict: 1.0
     Else (feature 94 > 58.5)
      If (feature 37 in {1.0})
       Predict: 0.0
      Else (feature 37 not in {1.0})
       Predict: 1.0
   Else (feature 96 > 12.5)
    If (feature 97 <= 7559.0)
     If (feature 98 <= 1862.0)
      If (feature 99 <= 30.5)
       Predict: 0.0
      Else (feature 99 > 30.5)
       Predict: 1.0
     Else (feature 98 > 1862.0)
      If (feature 76 in {1.0})
       Predict: 0.0
      Else (feature 76 not in {1.0})
       Predict: 1.0
    Else (feature 97 > 7559.0)
     If (feature 39 in {1.0})
      If (feature 1 in {1.0})
       Predict: 0.0
      Else (feature 1 not in {1.0})
       Predict: 1.0
     Else (feature 39 not in {1.0})
      Predict: 1.0

root
 |-- label: double (nullable = false)
 |-- features: vector (nullable = true)
 |-- age: integer (nullable = true)
 |-- workclass: string (nullable = true)
 |-- fnlwgt: integer (nullable = true)
 |-- education: string (nullable = true)
 |-- education-num: integer (nullable = true)
 |-- marital-status: string (nullable = true)
 |-- occupation: string (nullable = true)
 |-- relationship: string (nullable = true)
 |-- race: string (nullable = true)
 |-- sex: string (nullable = true)
 |-- capital-gain: integer (nullable = true)
 |-- capital-loss: integer (nullable = true)
 |-- hours-per-week: integer (nullable = true)
 |-- native-country: string (nullable = true)
 |-- income: string (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)

AUC: 0.5962768095770891
modelCV2fit.bestModel DecisionTreeClassificationModel: uid=DecisionTreeClassifier_334e98949900, depth=15, numNodes=1827, numClasses=2, numFeatures=100
numNodes =  1827
depth =  15
AUC: 0.7867756409676727

RFC:0.9107

-----------------RFC------------------
AUC: 0.8811337196636311   
AUC: 0.910702086397356  
modelCV2fit.bestModel RandomForestClassificationModel: uid=RandomForestClassifier_50ea87b083a8, numTrees=30, numClasses=2, numFeatures=100

   idx                                               name
0   94                                                age
1   95                                             fnlwgt
2   96                                      education-num
3   97                                       capital-gain
4   98                                       capital-loss
..  ..                                                ...
95  89                 native-country_one_hot_ Yugoslavia
96  90  native-country_one_hot_ Outlying-US(Guam-USVI-...
97  91                   native-country_one_hot_ Honduras
98  92                    native-country_one_hot_ Hungary
99  93                   native-country_one_hot_ Scotland

[100 rows x 2 columns]
   idx                                               name  feature_importance
97  91                   native-country_one_hot_ Honduras        1.495439e-01
23  17                     education_one_hot_ Prof-school        1.357221e-01
96  90  native-country_one_hot_ Outlying-US(Guam-USVI-...        9.849080e-02
43  37               occupation_one_hot_ Transport-moving        8.886728e-02
99  93                   native-country_one_hot_ Scotland        6.417345e-02
..  ..                                                ...                 ...
88  82                    native-country_one_hot_ Ecuador        1.718100e-05
86  80                     native-country_one_hot_ France        1.615439e-05
7    1                workclass_one_hot_ Self-emp-not-inc        3.834218e-06
91  85                   native-country_one_hot_ Cambodia        2.279066e-07
90  84                       native-country_one_hot_ Hong        0.000000e+00

猜你喜欢

转载自blog.csdn.net/qq_41185868/article/details/130113857