Machine Learning: Predicting Exam Passing in Practice (Feature Mapping, One-Hot Encoding, Feature Importance Selection, Grid Search Tuning)

Disclaimer: this content is not original; the code comes from 葁sir.

import numpy as np
import pandas as pd
from pandas import Series,DataFrame
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression
# Load the data
samples = pd.read_csv('data/student-data.csv')
samples.head()
school sex age address famsize Pstatus Medu Fedu Mjob Fjob ... higher internet romantic famrel freetime goout Dalc Walc health passed
0 GP F 18 U GT3 A 4 4 at_home teacher ... yes no no 4 3 4 1 1 3 no
1 GP F 17 U GT3 T 1 1 at_home other ... yes yes no 5 3 3 1 1 3 no
2 GP F 15 U LE3 T 1 1 at_home other ... yes yes no 4 3 2 2 3 3 yes
3 GP F 15 U GT3 T 4 2 health services ... yes yes yes 3 2 2 1 1 5 yes
4 GP F 16 U GT3 T 3 3 other other ... yes no no 4 3 2 1 2 5 yes

5 rows × 30 columns

samples.dtypes
school        object
sex           object
age            int64
address       object
famsize       object
Pstatus       object
Medu           int64
Fedu           int64
Mjob          object
Fjob          object
reason        object
guardian      object
traveltime     int64
studytime      int64
failures       int64
schoolsup     object
famsup        object
paid          object
activities    object
nursery       object
higher        object
internet      object
romantic      object
famrel         int64
freetime       int64
goout          int64
Dalc           int64
Walc           int64
health         int64
passed        object
dtype: object
samples.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 395 entries, 0 to 394
Data columns (total 30 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   school      395 non-null    object
 1   sex         395 non-null    object
 2   age         395 non-null    int64 
 3   address     395 non-null    object
 4   famsize     395 non-null    object
 5   Pstatus     395 non-null    object
 6   Medu        395 non-null    int64 
 7   Fedu        395 non-null    int64 
 8   Mjob        395 non-null    object
 9   Fjob        395 non-null    object
 10  reason      395 non-null    object
 11  guardian    395 non-null    object
 12  traveltime  395 non-null    int64 
 13  studytime   395 non-null    int64 
 14  failures    395 non-null    int64 
 15  schoolsup   395 non-null    object
 16  famsup      395 non-null    object
 17  paid        395 non-null    object
 18  activities  395 non-null    object
 19  nursery     395 non-null    object
 20  higher      395 non-null    object
 21  internet    395 non-null    object
 22  romantic    395 non-null    object
 23  famrel      395 non-null    int64 
 24  freetime    395 non-null    int64 
 25  goout       395 non-null    int64 
 26  Dalc        395 non-null    int64 
 27  Walc        395 non-null    int64 
 28  health      395 non-null    int64 
 29  passed      395 non-null    object
dtypes: int64(12), object(18)
memory usage: 92.7+ KB
# Take a look at the feature columns
samples.columns.tolist()
['school',
 'sex',
 'age',
 'address',
 'famsize',
 'Pstatus',
 'Medu',
 'Fedu',
 'Mjob',
 'Fjob',
 'reason',
 'guardian',
 'traveltime',
 'studytime',
 'failures',
 'schoolsup',
 'famsup',
 'paid',
 'activities',
 'nursery',
 'higher',
 'internet',
 'romantic',
 'famrel',
 'freetime',
 'goout',
 'Dalc',
 'Walc',
 'health',
 'passed']
# Check summary statistics; describe() only covers numeric columns, which confirms some columns are not numeric
samples.describe().columns.tolist()
['age',
 'Medu',
 'Fedu',
 'traveltime',
 'studytime',
 'failures',
 'famrel',
 'freetime',
 'goout',
 'Dalc',
 'Walc',
 'health']
samples.describe([0.01,0.99]).T
count mean std min 1% 50% 99% max
age 395.0 16.696203 1.276043 15.0 15.0 17.0 20.0 22.0
Medu 395.0 2.749367 1.094735 0.0 1.0 3.0 4.0 4.0
Fedu 395.0 2.521519 1.088201 0.0 1.0 2.0 4.0 4.0
traveltime 395.0 1.448101 0.697505 1.0 1.0 1.0 4.0 4.0
studytime 395.0 2.035443 0.839240 1.0 1.0 2.0 4.0 4.0
failures 395.0 0.334177 0.743651 0.0 0.0 0.0 3.0 3.0
famrel 395.0 3.944304 0.896659 1.0 1.0 4.0 5.0 5.0
freetime 395.0 3.235443 0.998862 1.0 1.0 3.0 5.0 5.0
goout 395.0 3.108861 1.113278 1.0 1.0 3.0 5.0 5.0
Dalc 395.0 1.481013 0.890741 1.0 1.0 1.0 5.0 5.0
Walc 395.0 2.291139 1.287897 1.0 1.0 2.0 5.0 5.0
health 395.0 3.554430 1.390303 1.0 1.0 4.0 5.0 5.0
# The features are on different scales, so scaling is needed
# There are no obvious outliers, so we can standardize (z-score transform) directly

import seaborn as sns
import matplotlib.pyplot as plt
sns.set()
# Check for missing values
# .mean() on isnull()/notnull() gives the per-column missing/present ratio
# If the missing ratio is high, consider imputation (once it climbs toward 10% and the sample is small, don't just delete)
# If the ratio is very low, the missing rows can simply be dropped (though prefer not to delete if you can avoid it)
plt.figure(figsize=(12,4))
samples.notnull().mean().plot(kind='bar')
<AxesSubplot:>

[Figure: bar chart of the per-column non-null fraction; every bar sits at 1.0, i.e. no missing values]
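This dataset happens to have no missing values, so the plot is flat; still, as the comments above suggest, a simple threshold rule is the usual way to act on missing ratios. A hypothetical sketch (the 10% cutoff is illustrative, not from the original post):

# Hypothetical sketch: act on per-column missing ratios with simple thresholds
missing_ratio = samples.isnull().mean()
low_missing = missing_ratio[(missing_ratio > 0) & (missing_ratio <= 0.10)].index

filled = samples.copy()
for col in low_missing:
    if filled[col].dtype == 'object':
        filled[col] = filled[col].fillna(filled[col].mode()[0])   # most frequent value
    else:
        filled[col] = filled[col].fillna(filled[col].median())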

# LabelEncoder turns a column of text into integer codes
# OneHotEncoder turns a column of text into one or more columns containing only 0s and 1s
# Mapping with map()
data = ['男','女','男']  # male, female, male
# Dict-based mapping
map_dict = {'男': 1, '女': 0}
Series(data).map(map_dict)
0    1
1    0
2    1
dtype: int64
# Function-based mapping
f = lambda x: (x=='男')*1
Series(data).map(f)
0    1
1    0
2    1
dtype: int64
# One-hot encoding approach
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
OneHotEncoder().fit_transform(LabelEncoder().fit_transform(data).reshape(-1,1)).toarray()
array([[0., 1.],
       [1., 0.],
       [0., 1.]])
m_data =OneHotEncoder().fit_transform(LabelEncoder().fit_transform(data).reshape(-1,1)).toarray()
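For reference, pandas can produce the same kind of 0/1 columns in a single call with get_dummies (a minimal sketch, not part of the original notebook; newer pandas versions may show the values as True/False rather than 1/0):

# pd.get_dummies one-hot encodes text values directly, no LabelEncoder step needed
pd.get_dummies(Series(data))
#    女  男
# 0   0   1
# 1   1   0
# 2   0   1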
# Build X and y
X = samples.iloc[:,:-1].copy()
y = samples.iloc[:,-1].copy()
# Find out which columns are not numeric
for col_name in X:
    data = X[col_name]
    if data.dtype == 'object':
        print(col_name,'---->',data.unique())
school ----> ['GP' 'MS']
sex ----> ['F' 'M']
address ----> ['U' 'R']
famsize ----> ['GT3' 'LE3']
Pstatus ----> ['A' 'T']
Mjob ----> ['at_home' 'health' 'other' 'services' 'teacher']
Fjob ----> ['teacher' 'other' 'services' 'health' 'at_home']
reason ----> ['course' 'other' 'home' 'reputation']
guardian ----> ['mother' 'father' 'other']
schoolsup ----> ['yes' 'no']
famsup ----> ['no' 'yes']
paid ----> ['no' 'yes']
activities ----> ['no' 'yes']
nursery ----> ['yes' 'no']
higher ----> ['yes' 'no']
internet ----> ['no' 'yes']
romantic ----> ['no' 'yes']
# When mapping categories to numbers, think about whether a real ordinal relationship exists (i.e., how the encoding affects the prediction target)
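For instance, for a yes/no column such as higher, an explicit mapping keeps the direction of the encoding meaningful, whereas LabelEncoder simply assigns codes in alphabetical order. A hypothetical sketch that does not modify X:

# Hypothetical: an explicit binary mapping with a meaningful direction (1 = yes)
X['higher'].map({'yes': 1, 'no': 0}).head()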

for col_name in X:
    data = X[col_name]
    if data.dtype == 'object':
        X[col_name] = LabelEncoder().fit_transform(data)
X
school sex age address famsize Pstatus Medu Fedu Mjob Fjob ... nursery higher internet romantic famrel freetime goout Dalc Walc health
0 0 0 18 1 0 0 4 4 0 4 ... 1 1 0 0 4 3 4 1 1 3
1 0 0 17 1 0 1 1 1 0 2 ... 0 1 1 0 5 3 3 1 1 3
2 0 0 15 1 1 1 1 1 0 2 ... 1 1 1 0 4 3 2 2 3 3
3 0 0 15 1 0 1 4 2 1 3 ... 1 1 1 1 3 2 2 1 1 5
4 0 0 16 1 0 1 3 3 2 2 ... 1 1 0 0 4 3 2 1 2 5
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
390 1 1 20 1 1 0 2 2 3 3 ... 1 1 0 0 5 5 4 4 5 4
391 1 1 17 1 1 1 3 1 3 3 ... 0 1 1 0 2 4 5 3 4 2
392 1 1 21 0 0 1 1 1 2 2 ... 0 1 0 0 5 5 3 3 3 3
393 1 1 18 0 1 1 3 2 3 2 ... 0 1 1 0 4 4 1 3 4 5
394 1 1 19 1 1 1 1 1 2 0 ... 1 1 1 0 3 2 3 3 3 5

395 rows × 29 columns

# Standardize the features (z-score scaling)
from sklearn.preprocessing import StandardScaler
ss_X = StandardScaler().fit_transform(X)
ss_X = DataFrame(data=ss_X, columns=X.columns)
ss_X
school sex age address famsize Pstatus Medu Fedu Mjob Fjob ... nursery higher internet romantic famrel freetime goout Dalc Walc health
0 -0.363050 -0.948176 1.023046 0.535392 -0.636941 -2.938392 1.143856 1.360371 -1.769793 1.993149 ... 0.507899 0.23094 -2.232677 -0.708450 0.062194 -0.236010 0.801479 -0.540699 -1.003789 -0.399289
1 -0.363050 -0.948176 0.238380 0.535392 -0.636941 0.340322 -1.600009 -1.399970 -1.769793 -0.325831 ... -1.968894 0.23094 0.447893 -0.708450 1.178860 -0.236010 -0.097908 -0.540699 -1.003789 -0.399289
2 -0.363050 -0.948176 -1.330954 0.535392 1.570004 0.340322 -1.600009 -1.399970 -1.769793 -0.325831 ... 0.507899 0.23094 0.447893 -0.708450 0.062194 -0.236010 -0.997295 0.583385 0.551100 -0.399289
3 -0.363050 -0.948176 -1.330954 0.535392 -0.636941 0.340322 1.143856 -0.479857 -0.954077 0.833659 ... 0.507899 0.23094 0.447893 1.411533 -1.054472 -1.238419 -0.997295 -0.540699 -1.003789 1.041070
4 -0.363050 -0.948176 -0.546287 0.535392 -0.636941 0.340322 0.229234 0.440257 -0.138362 -0.325831 ... 0.507899 0.23094 -2.232677 -0.708450 0.062194 -0.236010 -0.997295 -0.540699 -0.226345 1.041070
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
390 2.754443 1.054656 2.592380 0.535392 1.570004 -2.938392 -0.685387 -0.479857 0.677354 0.833659 ... 0.507899 0.23094 -2.232677 -0.708450 1.178860 1.768808 0.801479 2.831553 2.105989 0.320890
391 2.754443 1.054656 0.238380 0.535392 1.570004 0.340322 0.229234 -1.399970 0.677354 0.833659 ... -1.968894 0.23094 0.447893 -0.708450 -2.171138 0.766399 1.700867 1.707469 1.328545 -1.119469
392 2.754443 1.054656 3.377047 -1.867789 -0.636941 0.340322 -1.600009 -1.399970 -0.138362 -0.325831 ... -1.968894 0.23094 -2.232677 -0.708450 1.178860 1.768808 -0.097908 1.707469 0.551100 -0.399289
393 2.754443 1.054656 1.023046 -1.867789 1.570004 0.340322 0.229234 -0.479857 0.677354 -0.325831 ... -1.968894 0.23094 0.447893 -0.708450 0.062194 0.766399 -1.896683 1.707469 1.328545 1.041070
394 2.754443 1.054656 1.807713 0.535392 1.570004 0.340322 -1.600009 -1.399970 -0.138362 -2.644812 ... 0.507899 0.23094 0.447893 -0.708450 -1.054472 -1.238419 -0.097908 1.707469 0.551100 1.041070

395 rows × 29 columns

ss_X.describe([0.01,0.99]).T
count mean std min 1% 50% 99% max
school 395.0 1.866299e-16 1.001268 -0.363050 -0.363050 -0.363050 2.754443 2.754443
sex 395.0 -4.834389e-17 1.001268 -0.948176 -0.948176 -0.948176 1.054656 1.054656
age 395.0 1.411529e-15 1.001268 -1.330954 -1.330954 0.238380 2.592380 4.161713
address 395.0 6.998621e-17 1.001268 -1.867789 -1.867789 0.535392 0.535392 0.535392
famsize 395.0 1.281675e-16 1.001268 -0.636941 -0.636941 -0.636941 1.570004 1.570004
Pstatus 395.0 -1.503720e-16 1.001268 -2.938392 -2.938392 0.340322 0.340322 0.340322
Medu 395.0 8.432074e-18 1.001268 -2.514630 -1.600009 0.229234 1.143856 1.143856
Fedu 395.0 -1.264811e-16 1.001268 -2.320084 -1.399970 -0.479857 1.360371 1.360371
Mjob 395.0 -1.158707e-16 1.001268 -1.769793 -1.769793 -0.138362 1.493069 1.493069
Fjob 395.0 -1.607715e-16 1.001268 -2.644812 -2.644812 -0.325831 1.993149 1.993149
reason 395.0 4.384678e-17 1.001268 -1.040599 -1.040599 -0.211896 1.445509 1.445509
guardian 395.0 3.091760e-17 1.001268 -1.591714 -1.591714 0.273945 2.139603 2.139603
traveltime 395.0 -2.203582e-16 1.001268 -0.643249 -0.643249 -0.643249 3.663251 3.663251
studytime 395.0 -2.709506e-16 1.001268 -1.235351 -1.235351 -0.042286 2.343844 2.343844
failures 395.0 -2.599889e-16 1.001268 -0.449944 -0.449944 -0.449944 3.589323 3.589323
schoolsup 395.0 -1.577360e-15 1.001268 -0.385040 -0.385040 -0.385040 2.597133 2.597133
famsup 395.0 8.207218e-17 1.001268 -1.257656 -1.257656 0.795130 0.795130 0.795130
paid 395.0 3.766326e-17 1.001268 -0.919671 -0.919671 -0.919671 1.087346 1.087346
activities 395.0 5.846238e-17 1.001268 -1.017881 -1.017881 0.982433 0.982433 0.982433
nursery 395.0 -8.966105e-17 1.001268 -1.968894 -1.968894 0.507899 0.507899 0.507899
higher 395.0 8.923945e-17 1.001268 -4.330127 -4.330127 0.230940 0.230940 0.230940
internet 395.0 1.767925e-16 1.001268 -2.232677 -2.232677 0.447893 0.447893 0.447893
romantic 395.0 -1.062441e-16 1.001268 -0.708450 -0.708450 -0.708450 1.411533 1.411533
famrel 395.0 -1.410967e-16 1.001268 -3.287804 -3.287804 0.062194 1.178860 1.178860
freetime 395.0 1.028713e-16 1.001268 -2.240828 -2.240828 -0.236010 1.768808 1.768808
goout 395.0 -2.062345e-17 1.001268 -1.896683 -1.896683 -0.097908 1.700867 1.700867
Dalc 395.0 8.769357e-17 1.001268 -0.540699 -0.540699 -0.540699 3.955638 3.955638
Walc 395.0 -3.091760e-17 1.001268 -1.003789 -1.003789 -0.226345 2.105989 2.105989
health 395.0 1.169248e-16 1.001268 -1.839649 -1.839649 0.320890 1.041070 1.041070
X_train, X_test, y_train,y_test = train_test_split(ss_X,y,test_size=0.2,random_state=1)
knn = KNeighborsClassifier()
knn.fit(X_train,y_train)
knn.score(X_test,y_test)
0.6962025316455697
lr = LogisticRegression()
lr.fit(X_train,y_train)
lr.score(X_test,y_test)
0.7088607594936709
# Going further: feature selection
# coef_ holds the learned coefficients, one per feature
lr.coef_
array([[ 0.03520344,  0.22443684, -0.36334047,  0.03790903,  0.23429034,
        -0.14052023,  0.1929704 ,  0.0183496 , -0.24143   ,  0.09296691,
         0.19759787, -0.17086501, -0.05668839,  0.28625693, -0.52088123,
        -0.29399977, -0.33164972,  0.12570786, -0.17226849, -0.10907161,
         0.35411433,  0.07470418, -0.19945335,  0.11104188,  0.16833114,
        -0.41107393, -0.05262516,  0.29150008, -0.02790423]])
plt.figure(figsize=(12,4))
np.abs(Series(data=lr.coef_[0], index=X.columns)).sort_values(ascending=False).plot(kind='bar')
<AxesSubplot:>

[Figure: bar chart of absolute logistic-regression coefficients per feature, sorted descending]

# Features can be screened in several ways: a fixed cutoff (e.g. 0.1), the top few, or everything above the mean
# Logistic regression comes with l1/l2 regularization built in, which already has a coefficient-selection effect
np.abs(Series(data=lr.coef_[0],index=X.columns)).mean()
0.190936289299829
# Screening by coefficients is, at heart, still trial and error
# The default penalty can also be switched
lr = LogisticRegression(penalty='l1',solver='liblinear')
lr.fit(X_train,y_train)
LogisticRegression(penalty='l1', solver='liblinear')
# Score with the L1 penalty
lr.score(X_test,y_test)
0.7088607594936709
# With the L1 penalty, find the features whose |coefficient| exceeds the mean
# Approach: write a boolean condition and filter with it
condition = np.abs(Series(data=lr.coef_[0],index=X.columns)) > np.abs(Series(data=lr.coef_[0],index=X.columns)).mean()
np.abs(Series(data=lr.coef_[0],index=X.columns))[condition].index
Index(['sex', 'age', 'famsize', 'Mjob', 'reason', 'studytime', 'failures',
       'schoolsup', 'famsup', 'higher', 'romantic', 'goout', 'Walc'],
      dtype='object')
# Training and score with the L2 penalty (the default)
lr = LogisticRegression()
lr.fit(X_train,y_train)
lr.score(X_test,y_test)
0.7088607594936709
# Feature selection based on L2
condition = np.abs(Series(data=lr.coef_[0],index=X.columns)) > np.abs(Series(data=lr.coef_[0],index=X.columns)).mean()
np.abs(Series(data=lr.coef_[0],index=X.columns))[condition].index
Index(['sex', 'age', 'famsize', 'Medu', 'Mjob', 'reason', 'studytime',
       'failures', 'schoolsup', 'famsup', 'higher', 'romantic', 'goout',
       'Walc'],
      dtype='object')
l1_index = np.array(['sex', 'age', 'famsize', 'Mjob', 'reason', 'studytime', 'failures',
       'schoolsup', 'famsup', 'higher', 'romantic', 'goout', 'Walc'])
l2_index = np.array(['sex', 'age', 'famsize', 'Medu', 'Mjob', 'reason', 'studytime',
       'failures', 'schoolsup', 'famsup', 'higher', 'romantic', 'goout',
       'Walc'])
importance_feature = list(set(l1_index) & set(l2_index))
importance_feature
['schoolsup',
 'goout',
 'studytime',
 'sex',
 'famsup',
 'age',
 'Mjob',
 'famsize',
 'failures',
 'romantic',
 'higher',
 'reason',
 'Walc']
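As an aside, sklearn packages this coefficient-above-threshold pattern as SelectFromModel (a minimal sketch, not from the original notebook):

from sklearn.feature_selection import SelectFromModel

# threshold='mean' keeps the features whose |coef_| exceeds the mean |coef_|
selector = SelectFromModel(
    LogisticRegression(penalty='l1', solver='liblinear'), threshold='mean'
)
selector.fit(ss_X, y)
ss_X.columns[selector.get_support()].tolist()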
good_X = X[importance_feature]
good_X
schoolsup goout studytime sex famsup age Mjob famsize failures romantic higher reason Walc
0 1 4 2 0 0 18 0 0 0 0 1 0 1
1 0 3 2 0 1 17 0 0 0 0 1 0 1
2 1 2 2 0 0 15 0 1 3 0 1 2 3
3 0 2 3 0 1 15 1 0 0 1 1 1 1
4 0 2 2 0 1 16 2 0 0 0 1 1 2
... ... ... ... ... ... ... ... ... ... ... ... ... ...
390 0 4 2 1 1 20 3 1 2 0 1 0 5
391 0 5 1 1 0 17 3 1 0 0 1 0 4
392 0 3 1 1 0 21 2 0 3 0 1 0 3
393 0 1 1 1 0 18 3 1 0 0 1 0 4
394 0 3 1 1 0 19 2 1 0 0 1 0 3

395 rows × 13 columns

X_train,X_test,y_train,y_test = train_test_split(good_X,y,test_size=0.2,random_state=1)
lr = LogisticRegression()
lr.fit(X_train,y_train)
D:\software\anaconda\lib\site-packages\sklearn\linear_model\_logistic.py:763: ConvergenceWarning: lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
LogisticRegression()
lr.score(X_test,y_test)
0.7341772151898734
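The ConvergenceWarning above appears because good_X holds the raw label-encoded values rather than the standardized ones; standardizing the selected columns (or raising max_iter) lets lbfgs converge. A minimal sketch:

# Standardize the selected features to silence the lbfgs convergence warning
good_ss_X = DataFrame(
    data=StandardScaler().fit_transform(good_X), columns=good_X.columns
)
X_train, X_test, y_train, y_test = train_test_split(
    good_ss_X, y, test_size=0.2, random_state=1
)
lr = LogisticRegression()
lr.fit(X_train, y_train)
lr.score(X_test, y_test)  # the score may shift slightly relative to the unscaled run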

Hyperparameter Tuning

# K-fold cross-validation (CV): guarantees every sample gets a turn as both training data and test data
# cross_val_score: k-fold-based scoring; trains and evaluates on each of the k splits and returns one score per fold
# GridSearchCV: grid search over a parameter grid, using k-fold splits to find the best parameter combination
from sklearn.model_selection import KFold
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import GridSearchCV
kfold = KFold(n_splits=3) # split into three folds
kfold.split(good_X,y) # returns a generator
<generator object _BaseKFold.split at 0x000001BF8CD4EE40>
# It actually yields three pairs of index arrays
generator = kfold.split(good_X,y)
for g in generator:
    print(g)
(array([132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144,
       145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157,
       158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170,
       171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
       184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196,
       197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209,
       210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222,
       223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235,
       236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248,
       249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261,
       262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274,
       275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287,
       288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300,
       301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313,
       314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326,
       327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
       340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352,
       353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365,
       366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378,
       379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391,
       392, 393, 394]), array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
        78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
        91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
       104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
       117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129,
       130, 131]))
(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
        78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
        91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
       104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
       117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129,
       130, 131, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274,
       275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287,
       288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300,
       301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313,
       314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326,
       327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
       340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352,
       353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365,
       366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378,
       379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391,
       392, 393, 394]), array([132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144,
       145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157,
       158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170,
       171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
       184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196,
       197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209,
       210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222,
       223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235,
       236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248,
       249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261,
       262, 263]))
(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
        78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
        91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
       104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
       117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129,
       130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142,
       143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
       156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168,
       169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181,
       182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194,
       195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207,
       208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220,
       221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233,
       234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246,
       247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259,
       260, 261, 262, 263]), array([264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276,
       277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289,
       290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302,
       303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315,
       316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328,
       329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341,
       342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354,
       355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
       368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380,
       381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393,
       394]))
X.shape
(395, 29)
lr = LogisticRegression()
generator = kfold.split(good_X,y)
temp = None
# The generator yields one tuple per fold: (train_index, test_index)
for g in generator:
    # print('g0:', g[0])
    # print('g1:', g[1])
    X_train = good_X.loc[g[0]]
    X_test = good_X.loc[g[1]]
    y_train = y[g[0]]
    y_test = y[g[1]]
    lr.fit(X_train,y_train)
    print(lr.score(X_test,y_test))
0.7954545454545454
0.6439393939393939
0.5801526717557252
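The fold scores vary widely because KFold(n_splits=3) takes consecutive, unshuffled blocks, as the index dump above shows. Shuffling, or stratifying on the label, usually evens this out (a sketch, not part of the original notebook):

from sklearn.model_selection import StratifiedKFold

# Shuffled, label-stratified folds: each fold keeps the overall yes/no ratio
skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=1)
for train_idx, test_idx in skf.split(good_X, y):
    lr.fit(good_X.iloc[train_idx], y.iloc[train_idx])
    print(lr.score(good_X.iloc[test_idx], y.iloc[test_idx]))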
# Use cross_val_score to score the model with k-fold CV
# estimator: the model object
# X, y: the feature matrix and label vector
# cv: the number of folds, e.g. 3, 5, or 10
lr = LogisticRegression()
result = cross_val_score(lr,good_X,y,cv=5)
result
array([0.72151899, 0.79746835, 0.65822785, 0.70886076, 0.6835443 ])
result.mean()
0.7139240506329114
result.std()
0.04709132971579312
lr = LogisticRegression()
result = cross_val_score(lr,good_X,y,cv=3)
result
array([0.71212121, 0.71212121, 0.57251908])
# LogisticRegression()
# C: the inverse of regularization strength; smaller C means a stronger penalty, i.e. a simpler model that is less prone to overfitting
# penalty: 'l1' or 'l2'
# max_iter: cap on the number of solver iterations

# Candidate values:
# C = [0.01, 0.1, 1, 10, 20]
# penalty = ['l1', 'l2']
# max_iter = [50, 100, 200, 300, 500, 1000]
# e.g., build a grid from the candidates above
param_grid = {
    'C': [0.01, 0.1, 1, 10, 20],
    'penalty': ['l1', 'l2'],
    'max_iter': [50, 100, 200, 300, 500, 1000]
}
param_grid
{'C': [0.01, 0.1, 1, 10, 20],
 'penalty': ['l1', 'l2'],
 'max_iter': [50, 100, 200, 300, 500, 1000]}
gscv = GridSearchCV(estimator=LogisticRegression(),param_grid=param_grid,cv=10)
gscv.fit(good_X,y)
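The post ends here. Two notes on finishing the search: with the default lbfgs solver the penalty='l1' candidates cannot fit (depending on the sklearn version they either raise or get scored as NaN with a FitFailedWarning, so consider adding 'solver': ['liblinear'] to the grid), and the fitted object exposes the winner roughly like this (a sketch; the printed values are hypothetical):

# Hypothetical follow-up: inspect the winning parameter combination
print(gscv.best_params_)  # e.g. {'C': 1, 'max_iter': 100, 'penalty': 'l2'}
print(gscv.best_score_)   # mean 10-fold CV accuracy of the best candidate
best_lr = gscv.best_estimator_  # already refit on all of good_X by default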

Reposted from blog.csdn.net/qq_33489955/article/details/124186543