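ML with PySpark: binary classification on the adult census income dataset (does annual income exceed 50K?) using a Pipeline with the LoR/DT/RF algorithms (grid search + cross-validation evaluation + feature importance), a case study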
Related articles
ML with PySpark: binary classification on the adult census income dataset (does annual income exceed 50K?) using a Pipeline with the LoR/DT/RF algorithms (grid search + cross-validation evaluation + feature importance), implementation code
# 1. Define the dataset
# 1.1. Create the SparkSession
# 1.2. Read the dataset
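Steps 1.1 and 1.2 might look like the following minimal sketch. The file path `adult.csv`, the application name, and the presence of a header row are assumptions that do not come from the original code; `inferSchema=True` is one way to obtain the integer columns seen in the printed schema.

```python
# Column names as printed in the output of this step.
ADULT_COLUMNS = [
    "age", "workclass", "fnlwgt", "education", "education-num",
    "marital-status", "occupation", "relationship", "race", "sex",
    "capital-gain", "capital-loss", "hours-per-week", "native-country",
    "income",
]


def create_spark_session(app_name="adult_income"):
    """Create (or reuse) a local SparkSession; app_name is a hypothetical label."""
    # Imported lazily so this module can be loaded without Spark installed.
    from pyspark.sql import SparkSession

    return SparkSession.builder.appName(app_name).master("local[*]").getOrCreate()


def read_adult(spark, path="adult.csv", has_header=True):
    """Read the CSV and let Spark infer integer vs. string column types."""
    df = spark.read.csv(path, header=has_header, inferSchema=True)
    # Without a header row, apply the known column names instead.
    return df if has_header else df.toDF(*ADULT_COLUMNS)
```

With a session from `create_spark_session()`, calling `df = read_adult(spark)` followed by `df.show(3)`, `df.printSchema()` and `print(len(df.columns), df.columns)` would produce output of the shape shown in this step.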
<class 'pyspark.sql.dataframe.DataFrame'>
+---+-----------------+------+----------+-------------+-------------------+------------------+--------------+------+-----+------------+------------+--------------+--------------+------+
|age|        workclass|fnlwgt| education|education-num|     marital-status|        occupation|  relationship|  race|  sex|capital-gain|capital-loss|hours-per-week|native-country|income|
+---+-----------------+------+----------+-------------+-------------------+------------------+--------------+------+-----+------------+------------+--------------+--------------+------+
| 39|        State-gov| 77516| Bachelors|           13|      Never-married|      Adm-clerical| Not-in-family| White| Male|        2174|           0|            40| United-States| <=50K|
| 50| Self-emp-not-inc| 83311| Bachelors|           13| Married-civ-spouse|   Exec-managerial|       Husband| White| Male|           0|           0|            13| United-States| <=50K|
| 38|          Private|215646|   HS-grad|            9|           Divorced| Handlers-cleaners| Not-in-family| White| Male|           0|           0|            40| United-States| <=50K|
+---+-----------------+------+----------+-------------+-------------------+------------------+--------------+------+-----+------------+------------+--------------+--------------+------+
only showing top 3 rows
root
|-- age: integer (nullable = true)
|-- workclass: string (nullable = true)
|-- fnlwgt: integer (nullable = true)
|-- education: string (nullable = true)
|-- education-num: integer (nullable = true)
|-- marital-status: string (nullable = true)
|-- occupation: string (nullable = true)
|-- relationship: string (nullable = true)
|-- race: string (nullable = true)
|-- sex: string (nullable = true)
|-- capital-gain: integer (nullable = true)
|-- capital-loss: integer (nullable = true)
|-- hours-per-week: integer (nullable = true)
|-- native-country: string (nullable = true)
|-- income: string (nullable = true)
15 ['age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week', 'native-country', 'income']
# 1.3. Split features by type
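One way to derive these groups is from the `(name, type)` pairs that `DataFrame.dtypes` returns. The sketch below is an illustration, not the original code; the helper name `split_feature_types` is hypothetical, and the pairs are copied by hand from the schema printed in step 1.2 so the example runs without Spark. Note that `income` lands among the categorical columns here because it is a string; it is the target and is encoded separately in step 2.2.2.

```python
def split_feature_types(dtypes):
    """Group column names by Spark type string.

    `dtypes` is a list of (name, type) pairs, the shape of `DataFrame.dtypes`.
    """
    cat_features = [name for name, t in dtypes if t == "string"]
    num_int_features = [name for name, t in dtypes if t in ("int", "bigint")]
    num_double_features = [name for name, t in dtypes if t in ("float", "double")]
    return {
        "cat_features": cat_features,
        "num_features": num_int_features + num_double_features,
        "num_int_features": num_int_features,
        "num_double_features": num_double_features,
    }


# (name, type) pairs implied by the schema printed in step 1.2.
ADULT_DTYPES = [
    ("age", "int"), ("workclass", "string"), ("fnlwgt", "int"),
    ("education", "string"), ("education-num", "int"),
    ("marital-status", "string"), ("occupation", "string"),
    ("relationship", "string"), ("race", "string"), ("sex", "string"),
    ("capital-gain", "int"), ("capital-loss", "int"),
    ("hours-per-week", "int"), ("native-country", "string"),
    ("income", "string"),
]

groups = split_feature_types(ADULT_DTYPES)
for key, names in groups.items():
    print(key, len(names), names)
```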
cat_features 9 ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country', 'income']
num_features 6 ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', 'hours-per-week']
num_int_features 6 ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', 'hours-per-week']
num_double_features 0 []
<class 'pyspark.sql.dataframe.DataFrame'>
# 1.4. Convert feature types
# 2. Data preprocessing / feature engineering
# 2.1. Count and fill missing values
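A possible shape for this step is sketched below: count the non-null values per column, subtract from the row total, then fill. The fill values (`"unknown"` and `0`) and all function names are assumptions made for illustration. Also worth noting: in this dataset unknown values appear as the literal string `?` (visible in the `native-country` column of the step 2.5 output), which Spark does not treat as null, so a null count of zero does not mean every value is known.

```python
def missing_from_counts(total_rows, non_null_counts):
    """Turn per-column non-null counts into per-column missing counts."""
    return {col: total_rows - n for col, n in non_null_counts.items()}


def count_non_null(df):
    """Return {column: non-null count} for a Spark DataFrame."""
    # Imported lazily so this module can be loaded without Spark installed.
    from pyspark.sql import functions as F

    # F.count on a column skips nulls, so this is a per-column non-null count.
    row = df.select([F.count(df[c]).alias(c) for c in df.columns]).first()
    return row.asDict()


def fill_missing(df, cat_cols, num_cols, cat_fill="unknown", num_fill=0):
    """Fill nulls; the default fill values here are hypothetical choices."""
    return df.na.fill(cat_fill, subset=cat_cols).na.fill(num_fill, subset=num_cols)


# Worked check against the counts printed in this step (two columns shown).
missing = missing_from_counts(32561, {"age": 32561, "workclass": 32561})
print(sum(missing.values()))
```

On a real DataFrame the same check would read `missing_from_counts(df.count(), count_non_null(df))`.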
age workclass fnlwgt education education-num marital-status occupation relationship race sex capital-gain capital-loss hours-per-week native-country income
0 32561 32561 32561 32561 32561 32561 32561 32561 32561 32561 32561 32561 32561 32561 32561
0
age 32561
workclass 32561
fnlwgt 32561
education 32561
education-num 32561
marital-status 32561
occupation 32561
relationship 32561
hours-per-week 32561
native-country 32561
income 32561
<class 'pyspark.sql.dataframe.DataFrame'>
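# 2.2. Define feature encoding
# 2.2.1. Encode the categorical features
# 2.2.2. Encode the target variable
# 2.3. Define feature vectorization
# 2.4. Run the data processing with a Pipeline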
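Steps 2.2 to 2.4 can be sketched as a single `Pipeline`, assuming Spark 3.x (where `OneHotEncoder` is an estimator that accepts `inputCol`/`outputCol`). The `_one_hot` suffix matches the feature names printed later in this article (for example `native-country_one_hot_ Honduras`); the `_index` suffix and the function names are hypothetical. `cat_cols` is meant to hold the eight categorical predictors, without `income`.

```python
def encoded_column_names(cat_cols):
    """Names of the intermediate index and one-hot columns for each feature."""
    index_cols = [f"{c}_index" for c in cat_cols]
    onehot_cols = [f"{c}_one_hot" for c in cat_cols]
    return index_cols, onehot_cols


def build_preprocessing_pipeline(cat_cols, num_cols, label_col="income"):
    """StringIndexer + OneHotEncoder per categorical column, then assemble."""
    # Imported lazily so this module can be loaded without Spark installed.
    from pyspark.ml import Pipeline
    from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler

    index_cols, onehot_cols = encoded_column_names(cat_cols)
    stages = []
    for col, idx_col, oh_col in zip(cat_cols, index_cols, onehot_cols):
        stages.append(StringIndexer(inputCol=col, outputCol=idx_col))
        # dropLast defaults to True: k categories become k - 1 vector slots.
        stages.append(OneHotEncoder(inputCol=idx_col, outputCol=oh_col))
    # The most frequent label gets index 0.0, so "<=50K" maps to 0.0.
    stages.append(StringIndexer(inputCol=label_col, outputCol="label"))
    # One-hot blocks first, numeric columns last, in one "features" vector.
    stages.append(VectorAssembler(inputCols=onehot_cols + num_cols, outputCol="features"))
    return Pipeline(stages=stages)


print(encoded_column_names(["sex", "race"]))
```

A usage along the lines of `pipeline.fit(df).transform(df).select(["label", "features"] + df.columns)` would leave the 17 columns shown in step 2.5. Placing the numeric columns last is consistent with the index table printed later, where `age` has `idx` 94.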
# 2.5. Inspect the preprocessed dataset
label features age workclass fnlwgt ... capital-gain capital-loss hours-per-week native-country income
0 0.0 (0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ... 39 State-gov 77516 ... 2174 0 40 United-States <=50K
1 0.0 (0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... 50 Self-emp-not-inc 83311 ... 0 0 13 United-States <=50K
2 0.0 (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, ... 38 Private 215646 ... 0 0 40 United-States <=50K
3 0.0 (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... 53 Private 234721 ... 0 0 40 United-States <=50K
4 0.0 (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... 28 Private 338409 ... 0 0 40 Cuba <=50K
5 0.0 (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... 37 Private 284582 ... 0 0 40 United-States <=50K
6 0.0 (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... 49 Private 160187 ... 0 0 16 Jamaica <=50K
7 1.0 (0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, ... 52 Self-emp-not-inc 209642 ... 0 0 45 United-States >50K
8 1.0 (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... 31 Private 45781 ... 14084 0 50 United-States >50K
9 1.0 (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... 42 Private 159449 ... 5178 0 40 United-States >50K
10 1.0 (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... 37 Private 280464 ... 0 0 80 United-States >50K
11 1.0 (0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ... 30 State-gov 141297 ... 0 0 40 India >50K
12 0.0 (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... 23 Private 122272 ... 0 0 30 United-States <=50K
13 0.0 (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... 32 Private 205019 ... 0 0 50 United-States <=50K
14 1.0 (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... 40 Private 121772 ... 0 0 40 ? >50K
15 0.0 (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... 34 Private 245487 ... 0 0 45 Mexico <=50K
16 0.0 (0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, ... 25 Self-emp-not-inc 176756 ... 0 0 35 United-States <=50K
[17 rows x 17 columns]
# 3. Model training and evaluation
# 3.1. Split the dataset
train_sparksql_df: 24362
val_sparksql_df: 8199
# 3.2. Train the model
# 3.3. Predict and evaluate
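Steps 3.1 to 3.3 can be sketched together. The row counts printed in step 3.1 work out to roughly a 75/25 split, so the weights below assume `randomSplit([0.75, 0.25])`; the seed and the function names are hypothetical. The AUC is computed with `BinaryClassificationEvaluator`, which reads the `rawPrediction` column that appears in the prediction schema.

```python
def observed_fractions(train_rows, val_rows):
    """Fractions of the data that ended up in each split."""
    total = train_rows + val_rows
    return train_rows / total, val_rows / total


def split_train_evaluate(df, estimator, weights=(0.75, 0.25), seed=42):
    """Split, fit the estimator on train, and return the AUC on validation."""
    # Imported lazily so this module can be loaded without Spark installed.
    from pyspark.ml.evaluation import BinaryClassificationEvaluator

    # randomSplit is approximate, so the realized sizes vary around the weights.
    train_df, val_df = df.randomSplit(list(weights), seed=seed)
    model = estimator.fit(train_df)
    predictions = model.transform(val_df)
    evaluator = BinaryClassificationEvaluator(
        labelCol="label",
        rawPredictionCol="rawPrediction",
        metricName="areaUnderROC",
    )
    return model, predictions, evaluator.evaluate(predictions)


# Row counts taken from the output of step 3.1.
train_frac, val_frac = observed_fractions(24362, 8199)
print(round(train_frac, 3), round(val_frac, 3))  # 0.748 0.252
```

The `estimator` argument could be, for instance, `LogisticRegression(featuresCol="features", labelCol="label")`, or the corresponding `DecisionTreeClassifier` or `RandomForestClassifier` from `pyspark.ml.classification`.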
# 3.4. Tune and evaluate: grid search + cross-validation
# Inspect the best model's parameters
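Grid search with cross-validation in PySpark is typically built from `ParamGridBuilder` and `CrossValidator`. The sketch below is an assumption about how this step could be written, not the original code: the grid values, fold count, and seed are hypothetical. The small helper shows why the grid should stay modest, since every parameter combination is fitted once per fold, plus one final refit of the best combination on the full training set.

```python
def count_fits(grid, num_folds):
    """Number of model fits CrossValidator performs for a {param: values} grid."""
    combos = 1
    for values in grid.values():
        combos *= len(values)
    # Each combination is fitted once per fold, then the best one is refitted.
    return combos * num_folds + 1


def cross_validate(estimator, grid, train_df, num_folds=3, seed=42):
    """Run grid search + k-fold CV and return the best fitted model."""
    # Imported lazily so this module can be loaded without Spark installed.
    from pyspark.ml.evaluation import BinaryClassificationEvaluator
    from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

    builder = ParamGridBuilder()
    for name, values in grid.items():
        # Look up the Param object by name on the estimator itself.
        builder = builder.addGrid(estimator.getParam(name), values)
    cv = CrossValidator(
        estimator=estimator,
        estimatorParamMaps=builder.build(),
        evaluator=BinaryClassificationEvaluator(labelCol="label"),
        numFolds=num_folds,
        seed=seed,
    )
    return cv.fit(train_df).bestModel


# Hypothetical grid for a decision tree; the values are illustrative only.
DT_GRID = {"maxDepth": [5, 10, 15], "maxBins": [32, 64]}
print(count_fits(DT_GRID, num_folds=3))  # 19
```

The returned model is what the output refers to as `modelCV2fit.bestModel`; its tuned values can be read back with, for example, `best_model.extractParamMap()`.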
LoR:0.90406
----------------LoR------------------
23/04/12 14:38:41 WARN ProcfsMetricsGetter: Exception when trying to compute pagesize, as a result reporting of ProcessTree metrics is stopped
AUC: 0.9031383924571014
+-----+------+----------+--------------------+---+
|label|income|prediction| probability|age|
+-----+------+----------+--------------------+---+
| 0.0| <=50K| 1.0|[0.43352058980891...| 32|
| 0.0| <=50K| 0.0|[0.75107062535890...| 29|
| 0.0| <=50K| 0.0|[0.63638300478043...| 33|
| 0.0| <=50K| 0.0|[0.51958003762662...| 37|
| 0.0| <=50K| 1.0|[0.34111125408907...| 41|
+-----+------+----------+--------------------+---+
only showing top 20 rows
23/04/12 14:39:24 WARN BlockManager: Asked to remove block broadcast_2148, which does not exist
modelCV2fit.bestModel LogisticRegressionModel: uid=LogisticRegression_a1b1ca1092dc, numClasses=2, numFeatures=100
AUC: 0.9040755818589695
DTC:0.7867
-----------------DTC------------------
DecisionTreeClassificationModel: uid=DecisionTreeClassifier_334e98949900, depth=5, numNodes=29, numClasses=2, numFeatures=100
  If (feature 23 in {0.0})
   If (feature 97 <= 7559.0)
    Predict: 0.0
   Else (feature 97 > 7559.0)
    If (feature 94 <= 21.5)
     If (feature 44 in {0.0})
      Predict: 0.0
     Else (feature 44 not in {0.0})
      Predict: 1.0
    Else (feature 94 > 21.5)
     Predict: 1.0
  Else (feature 23 not in {0.0})
   If (feature 96 <= 12.5)
    If (feature 97 <= 7559.0)
     Predict: 0.0
    Else (feature 97 > 7559.0)
     If (feature 94 <= 58.5)
      Predict: 1.0
     Else (feature 94 > 58.5)
      If (feature 37 in {1.0})
       Predict: 0.0
      Else (feature 37 not in {1.0})
       Predict: 1.0
   Else (feature 96 > 12.5)
    If (feature 97 <= 7559.0)
     If (feature 98 <= 1862.0)
      If (feature 99 <= 30.5)
       Predict: 0.0
      Else (feature 99 > 30.5)
       Predict: 1.0
     Else (feature 98 > 1862.0)
      If (feature 76 in {1.0})
       Predict: 0.0
      Else (feature 76 not in {1.0})
       Predict: 1.0
    Else (feature 97 > 7559.0)
     If (feature 39 in {1.0})
      If (feature 1 in {1.0})
       Predict: 0.0
      Else (feature 1 not in {1.0})
       Predict: 1.0
     Else (feature 39 not in {1.0})
      Predict: 1.0
root
|-- label: double (nullable = false)
|-- features: vector (nullable = true)
|-- age: integer (nullable = true)
|-- workclass: string (nullable = true)
|-- fnlwgt: integer (nullable = true)
|-- education: string (nullable = true)
|-- education-num: integer (nullable = true)
|-- marital-status: string (nullable = true)
|-- occupation: string (nullable = true)
|-- relationship: string (nullable = true)
|-- race: string (nullable = true)
|-- sex: string (nullable = true)
|-- capital-gain: integer (nullable = true)
|-- capital-loss: integer (nullable = true)
|-- hours-per-week: integer (nullable = true)
|-- native-country: string (nullable = true)
|-- income: string (nullable = true)
|-- rawPrediction: vector (nullable = true)
|-- probability: vector (nullable = true)
|-- prediction: double (nullable = false)
AUC: 0.5962768095770891
modelCV2fit.bestModel DecisionTreeClassificationModel: uid=DecisionTreeClassifier_334e98949900, depth=15, numNodes=1827, numClasses=2, numFeatures=100
numNodes = 1827
depth = 15
AUC: 0.7867756409676727
RFC:0.9107
-----------------RFC------------------
AUC: 0.8811337196636311
AUC: 0.910702086397356
modelCV2fit.bestModel RandomForestClassificationModel: uid=RandomForestClassifier_50ea87b083a8, numTrees=30, numClasses=2, numFeatures=100
    idx  name
0    94  age
1    95  fnlwgt
2    96  education-num
3    97  capital-gain
4    98  capital-loss
..   ..  ...
95   89  native-country_one_hot_ Yugoslavia
96   90  native-country_one_hot_ Outlying-US(Guam-USVI-...
97   91  native-country_one_hot_ Honduras
98   92  native-country_one_hot_ Hungary
99   93  native-country_one_hot_ Scotland
[100 rows x 2 columns]
    idx  name                                               feature_importance
97   91  native-country_one_hot_ Honduras                   1.495439e-01
23   17  education_one_hot_ Prof-school                     1.357221e-01
96   90  native-country_one_hot_ Outlying-US(Guam-USVI-...  9.849080e-02
43   37  occupation_one_hot_ Transport-moving               8.886728e-02
99   93  native-country_one_hot_ Scotland                   6.417345e-02
..   ..  ...                                                ...
88   82  native-country_one_hot_ Ecuador                    1.718100e-05
86   80  native-country_one_hot_ France                     1.615439e-05
7     1  workclass_one_hot_ Self-emp-not-inc                3.834218e-06
91   85  native-country_one_hot_ Cambodia                   2.279066e-07
90   84  native-country_one_hot_ Hong                       0.000000e+00
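The `idx`/`name` pairs above are the kind of attribute metadata that `VectorAssembler` stores on the `features` column, and the scores come from the model's `featureImportances` vector, which is indexed by position in the feature vector. A minimal sketch of joining the two is given below; the function name and the tiny sample data are hypothetical. Row position in the attribute table differs from `idx` (row 0 has `idx` 94, `age`), so the lookup should use `idx`. In the ranking above the top score sits on row 97, whose `idx` is 91; if the scores were attached by row position, that value would instead belong to vector index 97, `capital-gain`, which is also the feature the decision tree splits on at 7559.0.

```python
def feature_importance_table(ml_attr, importances):
    """Pair each feature name with importances[idx], highest score first.

    `ml_attr` has the shape of df.schema["features"].metadata["ml_attr"]:
    {"attrs": {"numeric": [{"idx": ..., "name": ...}], "binary": [...]}}.
    `importances` is any sequence indexable by vector position.
    """
    rows = []
    for group in ml_attr["attrs"].values():  # e.g. "numeric", "binary"
        for attr in group:
            rows.append({
                "idx": attr["idx"],
                "name": attr["name"],
                # Look up by vector index, never by row position.
                "feature_importance": float(importances[attr["idx"]]),
            })
    return sorted(rows, key=lambda r: r["feature_importance"], reverse=True)


# Tiny hand-made example (not real results): numeric attributes are listed
# first in the metadata even though they sit last in the vector.
sample_attr = {
    "attrs": {
        "numeric": [{"idx": 2, "name": "age"}],
        "binary": [
            {"idx": 0, "name": "sex_one_hot_ Male"},
            {"idx": 1, "name": "race_one_hot_ White"},
        ],
    }
}
table = feature_importance_table(sample_attr, [0.1, 0.2, 0.7])
for row in table:
    print(row["idx"], row["name"], row["feature_importance"])
```

With a fitted random forest, the call would be along the lines of `feature_importance_table(predictions.schema["features"].metadata["ml_attr"], best_model.featureImportances.toArray())`.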