Disclaimer: this content is not original; the code comes from 葁sir.
import numpy as np
import pandas as pd
from pandas import Series, DataFrame
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression

samples = pd.read_csv('data/student-data.csv')
samples.head()
  school sex  age address famsize Pstatus  Medu  Fedu     Mjob      Fjob  ... higher internet romantic famrel freetime goout Dalc Walc health passed
0     GP   F   18       U     GT3       A     4     4  at_home   teacher  ...    yes       no       no      4        3     4    1    1      3     no
1     GP   F   17       U     GT3       T     1     1  at_home     other  ...    yes      yes       no      5        3     3    1    1      3     no
2     GP   F   15       U     LE3       T     1     1  at_home     other  ...    yes      yes       no      4        3     2    2    3      3    yes
3     GP   F   15       U     GT3       T     4     2   health  services  ...    yes      yes      yes      3        2     2    1    1      5    yes
4     GP   F   16       U     GT3       T     3     3    other     other  ...    yes       no       no      4        3     2    1    2      5    yes

5 rows × 30 columns
samples.dtypes
school object
sex object
age int64
address object
famsize object
Pstatus object
Medu int64
Fedu int64
Mjob object
Fjob object
reason object
guardian object
traveltime int64
studytime int64
failures int64
schoolsup object
famsup object
paid object
activities object
nursery object
higher object
internet object
romantic object
famrel int64
freetime int64
goout int64
Dalc int64
Walc int64
health int64
passed object
dtype: object
samples.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 395 entries, 0 to 394
Data columns (total 30 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 school 395 non-null object
1 sex 395 non-null object
2 age 395 non-null int64
3 address 395 non-null object
4 famsize 395 non-null object
5 Pstatus 395 non-null object
6 Medu 395 non-null int64
7 Fedu 395 non-null int64
8 Mjob 395 non-null object
9 Fjob 395 non-null object
10 reason 395 non-null object
11 guardian 395 non-null object
12 traveltime 395 non-null int64
13 studytime 395 non-null int64
14 failures 395 non-null int64
15 schoolsup 395 non-null object
16 famsup 395 non-null object
17 paid 395 non-null object
18 activities 395 non-null object
19 nursery 395 non-null object
20 higher 395 non-null object
21 internet 395 non-null object
22 romantic 395 non-null object
23 famrel 395 non-null int64
24 freetime 395 non-null int64
25 goout 395 non-null int64
26 Dalc 395 non-null int64
27 Walc 395 non-null int64
28 health 395 non-null int64
29 passed 395 non-null object
dtypes: int64(12), object(18)
memory usage: 92.7+ KB
samples.columns.tolist()
['school',
'sex',
'age',
'address',
'famsize',
'Pstatus',
'Medu',
'Fedu',
'Mjob',
'Fjob',
'reason',
'guardian',
'traveltime',
'studytime',
'failures',
'schoolsup',
'famsup',
'paid',
'activities',
'nursery',
'higher',
'internet',
'romantic',
'famrel',
'freetime',
'goout',
'Dalc',
'Walc',
'health',
'passed']
samples.describe().columns.tolist()
['age',
'Medu',
'Fedu',
'traveltime',
'studytime',
'failures',
'famrel',
'freetime',
'goout',
'Dalc',
'Walc',
'health']
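describe() silently restricts itself to numeric columns, which is why only 12 of the 30 names come back. If you want that selection made explicit, select_dtypes does the same job (a small aside, not from the original notebook):

samples.select_dtypes(include='number').columns.tolist()  # same 12 numeric columns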
samples.describe([0.01, 0.99]).T
            count       mean       std   min    1%   50%   99%   max
age         395.0  16.696203  1.276043  15.0  15.0  17.0  20.0  22.0
Medu        395.0   2.749367  1.094735   0.0   1.0   3.0   4.0   4.0
Fedu        395.0   2.521519  1.088201   0.0   1.0   2.0   4.0   4.0
traveltime  395.0   1.448101  0.697505   1.0   1.0   1.0   4.0   4.0
studytime   395.0   2.035443  0.839240   1.0   1.0   2.0   4.0   4.0
failures    395.0   0.334177  0.743651   0.0   0.0   0.0   3.0   3.0
famrel      395.0   3.944304  0.896659   1.0   1.0   4.0   5.0   5.0
freetime    395.0   3.235443  0.998862   1.0   1.0   3.0   5.0   5.0
goout       395.0   3.108861  1.113278   1.0   1.0   3.0   5.0   5.0
Dalc        395.0   1.481013  0.890741   1.0   1.0   1.0   5.0   5.0
Walc        395.0   2.291139  1.287897   1.0   1.0   2.0   5.0   5.0
health      395.0   3.554430  1.390303   1.0   1.0   4.0   5.0   5.0
import seaborn as sns
import matplotlib.pyplot as plt

sns.set()
plt.figure(figsize=(12, 4))
samples.notnull().mean().plot(kind='bar')
<AxesSubplot:>

[bar chart: fraction of non-null values per column; every bar sits at 1.0, i.e. no missing values]
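The chart only visualizes what info() already reported, but it is easy to double-check numerically; a one-liner not in the original:

samples.isnull().sum().sum()  # 0: no missing values anywhere in the frame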
data = ['男', '女', '男']
map_dict = {'男': 1, '女': 0}
Series(data).map(map_dict)
0 1
1 0
2 1
dtype: int64
f = lambda x: (x == '男') * 1
Series(data).map(f)
0 1
1 0
2 1
dtype: int64
from sklearn.preprocessing import LabelEncoder, OneHotEncoder

OneHotEncoder().fit_transform(LabelEncoder().fit_transform(data).reshape(-1, 1)).toarray()
array([[0., 1.],
[1., 0.],
[0., 1.]])
m_data = OneHotEncoder().fit_transform(LabelEncoder().fit_transform(data).reshape(-1, 1)).toarray()
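The LabelEncoder-then-OneHotEncoder chain is the scikit-learn route; pandas can produce the same indicator matrix in one call. A minimal sketch on the same toy data (not part of the original notebook; in recent pandas the dummies come back as booleans rather than 0/1 floats):

pd.get_dummies(Series(data))  # columns '女' and '男', one indicator column per category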
X = samples.iloc[:, :-1].copy()
y = samples.iloc[:, -1].copy()
for col_name in X:
    data = X[col_name]
    if data.dtype == 'object':
        print(col_name, '---->', data.unique())
school ----> ['GP' 'MS']
sex ----> ['F' 'M']
address ----> ['U' 'R']
famsize ----> ['GT3' 'LE3']
Pstatus ----> ['A' 'T']
Mjob ----> ['at_home' 'health' 'other' 'services' 'teacher']
Fjob ----> ['teacher' 'other' 'services' 'health' 'at_home']
reason ----> ['course' 'other' 'home' 'reputation']
guardian ----> ['mother' 'father' 'other']
schoolsup ----> ['yes' 'no']
famsup ----> ['no' 'yes']
paid ----> ['no' 'yes']
activities ----> ['no' 'yes']
nursery ----> ['yes' 'no']
higher ----> ['yes' 'no']
internet ----> ['no' 'yes']
romantic ----> ['no' 'yes']
for col_name in X:
    data = X[col_name]
    if data.dtype == 'object':
        X[col_name] = LabelEncoder().fit_transform(data)
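One caveat with this loop: each fitted LabelEncoder is thrown away, so the integer codes cannot be mapped back to the original strings or applied to new data. A variant that keeps the encoders around (a sketch, not what the notebook does):

encoders = {}  # hypothetical registry: column name -> fitted LabelEncoder
for col_name in X:
    if X[col_name].dtype == 'object':
        encoders[col_name] = LabelEncoder()
        X[col_name] = encoders[col_name].fit_transform(X[col_name])
# later: encoders['Mjob'].inverse_transform([0, 1]) recovers the label strings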
X
     school  sex  age  address  famsize  Pstatus  Medu  Fedu  Mjob  Fjob  ...  nursery  higher  internet  romantic  famrel  freetime  goout  Dalc  Walc  health
0         0    0   18        1        0        0     4     4     0     4  ...        1       1         0         0       4         3      4     1     1       3
1         0    0   17        1        0        1     1     1     0     2  ...        0       1         1         0       5         3      3     1     1       3
2         0    0   15        1        1        1     1     1     0     2  ...        1       1         1         0       4         3      2     2     3       3
3         0    0   15        1        0        1     4     2     1     3  ...        1       1         1         1       3         2      2     1     1       5
4         0    0   16        1        0        1     3     3     2     2  ...        1       1         0         0       4         3      2     1     2       5
..      ...  ...  ...      ...      ...      ...   ...   ...   ...   ...  ...      ...     ...       ...       ...     ...       ...    ...   ...   ...     ...
390       1    1   20        1        1        0     2     2     3     3  ...        1       1         0         0       5         5      4     4     5       4
391       1    1   17        1        1        1     3     1     3     3  ...        0       1         1         0       2         4      5     3     4       2
392       1    1   21        0        0        1     1     1     2     2  ...        0       1         0         0       5         5      3     3     3       3
393       1    1   18        0        1        1     3     2     3     2  ...        0       1         1         0       4         4      1     3     4       5
394       1    1   19        1        1        1     1     1     2     0  ...        1       1         1         0       3         2      3     3     3       5

395 rows × 29 columns
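Worth flagging here: LabelEncoder has just given the multi-class nominal columns (Mjob, Fjob, reason, guardian) arbitrary integer codes, which a linear model will read as an ordering. A hedged alternative, not used in this notebook, is to expand just those columns into indicators instead:

nominal_cols = ['Mjob', 'Fjob', 'reason', 'guardian']  # truly unordered categories
X_expanded = pd.get_dummies(X, columns=nominal_cols)   # 29 columns become more, but order-free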
from sklearn.preprocessing import StandardScaler

ss_X = StandardScaler().fit_transform(X)
ss_X = DataFrame(data=ss_X, columns=X.columns)
ss_X
       school       sex       age   address   famsize   Pstatus      Medu      Fedu      Mjob      Fjob  ...   nursery   higher  internet  romantic    famrel  freetime     goout      Dalc      Walc    health
0   -0.363050 -0.948176  1.023046  0.535392 -0.636941 -2.938392  1.143856  1.360371 -1.769793  1.993149  ...  0.507899  0.23094 -2.232677 -0.708450  0.062194 -0.236010  0.801479 -0.540699 -1.003789 -0.399289
1   -0.363050 -0.948176  0.238380  0.535392 -0.636941  0.340322 -1.600009 -1.399970 -1.769793 -0.325831  ... -1.968894  0.23094  0.447893 -0.708450  1.178860 -0.236010 -0.097908 -0.540699 -1.003789 -0.399289
2   -0.363050 -0.948176 -1.330954  0.535392  1.570004  0.340322 -1.600009 -1.399970 -1.769793 -0.325831  ...  0.507899  0.23094  0.447893 -0.708450  0.062194 -0.236010 -0.997295  0.583385  0.551100 -0.399289
3   -0.363050 -0.948176 -1.330954  0.535392 -0.636941  0.340322  1.143856 -0.479857 -0.954077  0.833659  ...  0.507899  0.23094  0.447893  1.411533 -1.054472 -1.238419 -0.997295 -0.540699 -1.003789  1.041070
4   -0.363050 -0.948176 -0.546287  0.535392 -0.636941  0.340322  0.229234  0.440257 -0.138362 -0.325831  ...  0.507899  0.23094 -2.232677 -0.708450  0.062194 -0.236010 -0.997295 -0.540699 -0.226345  1.041070
..        ...       ...       ...       ...       ...       ...       ...       ...       ...       ...  ...       ...      ...       ...       ...       ...       ...       ...       ...       ...       ...
390  2.754443  1.054656  2.592380  0.535392  1.570004 -2.938392 -0.685387 -0.479857  0.677354  0.833659  ...  0.507899  0.23094 -2.232677 -0.708450  1.178860  1.768808  0.801479  2.831553  2.105989  0.320890
391  2.754443  1.054656  0.238380  0.535392  1.570004  0.340322  0.229234 -1.399970  0.677354  0.833659  ... -1.968894  0.23094  0.447893 -0.708450 -2.171138  0.766399  1.700867  1.707469  1.328545 -1.119469
392  2.754443  1.054656  3.377047 -1.867789 -0.636941  0.340322 -1.600009 -1.399970 -0.138362 -0.325831  ... -1.968894  0.23094 -2.232677 -0.708450  1.178860  1.768808 -0.097908  1.707469  0.551100 -0.399289
393  2.754443  1.054656  1.023046 -1.867789  1.570004  0.340322  0.229234 -0.479857  0.677354 -0.325831  ... -1.968894  0.23094  0.447893 -0.708450  0.062194  0.766399 -1.896683  1.707469  1.328545  1.041070
394  2.754443  1.054656  1.807713  0.535392  1.570004  0.340322 -1.600009 -1.399970 -0.138362 -2.644812  ...  0.507899  0.23094  0.447893 -0.708450 -1.054472 -1.238419 -0.097908  1.707469  0.551100  1.041070

395 rows × 29 columns
ss_X.describe([0.01, 0.99]).T
            count          mean       std       min        1%        50%       99%       max
school      395.0  1.866299e-16  1.001268 -0.363050 -0.363050 -0.363050  2.754443  2.754443
sex         395.0 -4.834389e-17  1.001268 -0.948176 -0.948176 -0.948176  1.054656  1.054656
age         395.0  1.411529e-15  1.001268 -1.330954 -1.330954  0.238380  2.592380  4.161713
address     395.0  6.998621e-17  1.001268 -1.867789 -1.867789  0.535392  0.535392  0.535392
famsize     395.0  1.281675e-16  1.001268 -0.636941 -0.636941 -0.636941  1.570004  1.570004
Pstatus     395.0 -1.503720e-16  1.001268 -2.938392 -2.938392  0.340322  0.340322  0.340322
Medu        395.0  8.432074e-18  1.001268 -2.514630 -1.600009  0.229234  1.143856  1.143856
Fedu        395.0 -1.264811e-16  1.001268 -2.320084 -1.399970 -0.479857  1.360371  1.360371
Mjob        395.0 -1.158707e-16  1.001268 -1.769793 -1.769793 -0.138362  1.493069  1.493069
Fjob        395.0 -1.607715e-16  1.001268 -2.644812 -2.644812 -0.325831  1.993149  1.993149
reason      395.0  4.384678e-17  1.001268 -1.040599 -1.040599 -0.211896  1.445509  1.445509
guardian    395.0  3.091760e-17  1.001268 -1.591714 -1.591714  0.273945  2.139603  2.139603
traveltime  395.0 -2.203582e-16  1.001268 -0.643249 -0.643249 -0.643249  3.663251  3.663251
studytime   395.0 -2.709506e-16  1.001268 -1.235351 -1.235351 -0.042286  2.343844  2.343844
failures    395.0 -2.599889e-16  1.001268 -0.449944 -0.449944 -0.449944  3.589323  3.589323
schoolsup   395.0 -1.577360e-15  1.001268 -0.385040 -0.385040 -0.385040  2.597133  2.597133
famsup      395.0  8.207218e-17  1.001268 -1.257656 -1.257656  0.795130  0.795130  0.795130
paid        395.0  3.766326e-17  1.001268 -0.919671 -0.919671 -0.919671  1.087346  1.087346
activities  395.0  5.846238e-17  1.001268 -1.017881 -1.017881  0.982433  0.982433  0.982433
nursery     395.0 -8.966105e-17  1.001268 -1.968894 -1.968894  0.507899  0.507899  0.507899
higher      395.0  8.923945e-17  1.001268 -4.330127 -4.330127  0.230940  0.230940  0.230940
internet    395.0  1.767925e-16  1.001268 -2.232677 -2.232677  0.447893  0.447893  0.447893
romantic    395.0 -1.062441e-16  1.001268 -0.708450 -0.708450 -0.708450  1.411533  1.411533
famrel      395.0 -1.410967e-16  1.001268 -3.287804 -3.287804  0.062194  1.178860  1.178860
freetime    395.0  1.028713e-16  1.001268 -2.240828 -2.240828 -0.236010  1.768808  1.768808
goout       395.0 -2.062345e-17  1.001268 -1.896683 -1.896683 -0.097908  1.700867  1.700867
Dalc        395.0  8.769357e-17  1.001268 -0.540699 -0.540699 -0.540699  3.955638  3.955638
Walc        395.0 -3.091760e-17  1.001268 -1.003789 -1.003789 -0.226345  2.105989  2.105989
health      395.0  1.169248e-16  1.001268 -1.839649 -1.839649  0.320890  1.041070  1.041070
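The describe() output confirms the standardization (mean ≈ 0, std ≈ 1 for every column), but note that the scaler was fitted on all 395 rows before any train/test split, so test-set statistics leak into the transform. A leakage-free sketch using a pipeline (an aside, not the notebook's approach):

from sklearn.pipeline import make_pipeline

X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=1)
pipe = make_pipeline(StandardScaler(), LogisticRegression())
pipe.fit(X_tr, y_tr)     # scaler statistics come from the training fold only
pipe.score(X_te, y_te)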
X_train, X_test, y_train, y_test = train_test_split(ss_X, y, test_size=0.2, random_state=1)
knn = KNeighborsClassifier()
knn.fit(X_train, y_train)
knn.score(X_test, y_test)
0.6962025316455697
lr = LogisticRegression()
lr.fit(X_train, y_train)
lr.score(X_test, y_test)
0.7088607594936709
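DecisionTreeClassifier was imported at the top but never actually exercised; for completeness, a sketch fitting it on the same split (its score is not part of the original output):

dt = DecisionTreeClassifier(random_state=1)
dt.fit(X_train, y_train)
dt.score(X_test, y_test)  # a third baseline alongside KNN and logistic regression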
lr.coef_
array([[ 0.03520344, 0.22443684, -0.36334047, 0.03790903, 0.23429034,
-0.14052023, 0.1929704 , 0.0183496 , -0.24143 , 0.09296691,
0.19759787, -0.17086501, -0.05668839, 0.28625693, -0.52088123,
-0.29399977, -0.33164972, 0.12570786, -0.17226849, -0.10907161,
0.35411433, 0.07470418, -0.19945335, 0.11104188, 0.16833114,
-0.41107393, -0.05262516, 0.29150008, -0.02790423]])
plt.figure(figsize=(12, 4))
np.abs(Series(data=lr.coef_[0], index=X.columns)).sort_values(ascending=False).plot(kind='bar')
<AxesSubplot:>

[bar chart: absolute value of each logistic-regression coefficient, sorted in descending order]
np.abs(Series(data=lr.coef_[0], index=X.columns)).mean()
0.190936289299829
lr = LogisticRegression(penalty='l1', solver='liblinear')
lr.fit(X_train, y_train)
LogisticRegression(penalty='l1', solver='liblinear')
lr.score(X_test, y_test)
0.7088607594936709
condition = np.abs(Series(data=lr.coef_[0], index=X.columns)) > np.abs(Series(data=lr.coef_[0], index=X.columns)).mean()
np.abs(Series(data=lr.coef_[0], index=X.columns))[condition].index
Index(['sex', 'age', 'famsize', 'Mjob', 'reason', 'studytime', 'failures',
'schoolsup', 'famsup', 'higher', 'romantic', 'goout', 'Walc'],
dtype='object')
lr = LogisticRegression()
lr.fit(X_train, y_train)
lr.score(X_test, y_test)
0.7088607594936709
condition = np.abs(Series(data=lr.coef_[0], index=X.columns)) > np.abs(Series(data=lr.coef_[0], index=X.columns)).mean()
np.abs(Series(data=lr.coef_[0], index=X.columns))[condition].index
Index(['sex', 'age', 'famsize', 'Medu', 'Mjob', 'reason', 'studytime',
'failures', 'schoolsup', 'famsup', 'higher', 'romantic', 'goout',
'Walc'],
dtype='object')
l1_index = np.array(['sex', 'age', 'famsize', 'Mjob', 'reason', 'studytime', 'failures',
                     'schoolsup', 'famsup', 'higher', 'romantic', 'goout', 'Walc'])
l2_index = np.array(['sex', 'age', 'famsize', 'Medu', 'Mjob', 'reason', 'studytime',
                     'failures', 'schoolsup', 'famsup', 'higher', 'romantic', 'goout',
                     'Walc'])
importance_feature = list(set(l1_index) & set(l2_index))
importance_feature
['schoolsup',
'goout',
'studytime',
'sex',
'famsup',
'age',
'Mjob',
'famsize',
'failures',
'romantic',
'higher',
'reason',
'Walc']
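Since Python sets are unordered, the list above comes out in arbitrary order, which is harmless here but makes the column order of good_X non-deterministic across runs. An order-preserving variant (a small aside, not the notebook's code):

importance_feature = [c for c in l1_index if c in set(l2_index)]  # keeps l1_index order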
good_X = X[importance_feature]
good_X
     schoolsup  goout  studytime  sex  famsup  age  Mjob  famsize  failures  romantic  higher  reason  Walc
0            1      4          2    0       0   18     0        0         0         0       1       0     1
1            0      3          2    0       1   17     0        0         0         0       1       0     1
2            1      2          2    0       0   15     0        1         3         0       1       2     3
3            0      2          3    0       1   15     1        0         0         1       1       1     1
4            0      2          2    0       1   16     2        0         0         0       1       1     2
..         ...    ...        ...  ...     ...  ...   ...      ...       ...       ...     ...     ...   ...
390          0      4          2    1       1   20     3        1         2         0       1       0     5
391          0      5          1    1       0   17     3        1         0         0       1       0     4
392          0      3          1    1       0   21     2        0         3         0       1       0     3
393          0      1          1    1       0   18     3        1         0         0       1       0     4
394          0      3          1    1       0   19     2        1         0         0       1       0     3

395 rows × 13 columns
X_train, X_test, y_train, y_test = train_test_split(good_X, y, test_size=0.2, random_state=1)
lr = LogisticRegression()
lr.fit(X_train, y_train)
D:\software\anaconda\lib\site-packages\sklearn\linear_model\_logistic.py:763: ConvergenceWarning: lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.
Increase the number of iterations (max_iter) or scale the data as shown in:
https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
n_iter_i = _check_optimize_result(
LogisticRegression()
lr.score(X_test, y_test)
0.7341772151898734
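The ConvergenceWarning above is expected: good_X holds the raw (unscaled) integer features, and lbfgs struggles with their ranges. Two quick remedies, sketched here rather than taken from the notebook:

# option 1: let the solver run longer
lr = LogisticRegression(max_iter=1000).fit(X_train, y_train)

# option 2: select the same columns from the standardized frame instead
good_ss_X = ss_X[importance_feature]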
Hyperparameter tuning
from sklearn.model_selection import KFold
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import GridSearchCV

kfold = KFold(n_splits=3)
kfold.split(good_X, y)
<generator object _BaseKFold.split at 0x000001BF8CD4EE40>
generator = kfold.split(good_X, y)
for g in generator:
    print(g)
(array([132, 133, 134, ..., 392, 393, 394]), array([  0,   1,   2, ..., 129, 130, 131]))
(array([  0,   1, ..., 131, 264, ..., 394]), array([132, 133, 134, ..., 261, 262, 263]))
(array([  0,   1,   2, ..., 261, 262, 263]), array([264, 265, 266, ..., 392, 393, 394]))
X.shape
(395, 29)
lr = LogisticRegression()
generator = kfold.split(good_X, y)
for g in generator:
    X_train = good_X.loc[g[0]]
    X_test = good_X.loc[g[1]]
    y_train = y[g[0]]
    y_test = y[g[1]]
    lr.fit(X_train, y_train)
    print(lr.score(X_test, y_test))
0.7954545454545454
0.6439393939393939
0.5801526717557252
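KFold does not shuffle by default, so each test fold is a contiguous block of rows, and the widening gap between the fold scores (0.80 down to 0.58) hints that the rows are not randomly ordered. A shuffled, class-balanced splitter usually gives more comparable folds; a sketch, not from the original:

from sklearn.model_selection import StratifiedKFold

skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=1)
cross_val_score(LogisticRegression(), good_X, y, cv=skf)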
lr = LogisticRegression()
result = cross_val_score(lr, good_X, y, cv=5)
result
array([0.72151899, 0.79746835, 0.65822785, 0.70886076, 0.6835443 ])
lr = LogisticRegression()
result = cross_val_score(lr, good_X, y, cv=3)
result
array([0.71212121, 0.71212121, 0.57251908])
result.mean()
0.7139240506329114
result.std()
0.04709132971579312
param_grid = {
    'C': [0.01, 0.1, 1, 10, 20],
    'penalty': ['l1', 'l2'],
    'max_iter': [50, 100, 200, 300, 500, 1000]
}
param_grid
param_grid
{'C': [0.01, 0.1, 1, 10, 20],
'penalty': ['l1', 'l2'],
'max_iter': [50, 100, 200, 300, 500, 1000]}
# Note: with the default lbfgs solver the penalty='l1' candidates cannot be fitted;
# recent scikit-learn scores those combinations as NaN and keeps the valid 'l2' fits.
gscv = GridSearchCV(estimator=LogisticRegression(), param_grid=param_grid, cv=10)
gscv.fit(good_X, y)
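The source stops at fit; the usual next step is to read off the search results. These attributes are standard GridSearchCV API:

gscv.best_params_               # best (C, penalty, max_iter) combination found
gscv.best_score_                # its mean cross-validated accuracy
best_lr = gscv.best_estimator_  # already refitted on all of good_X (refit=True by default)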