Iris classification and prediction with sklearn. Model: LogisticRegression

1. Load the dataset and import the packages

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

iris=load_iris()
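
As a quick sanity check (an addition, not in the original post), the returned Bunch object carries the feature matrix, the labels, and their names:

# Optional inspection of the loaded Bunch object
print(iris.data.shape)           # (150, 4): 150 samples, 4 features
print(iris.feature_names)        # sepal/petal length and width in cm
print(iris.target_names)         # ['setosa' 'versicolor' 'virginica']
print(np.bincount(iris.target))  # [50 50 50], the classes are balanced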

2. Split the dataset

x = iris.data
y = iris.target
# Split the data: test_size=0.2 keeps 20% for testing, random_state fixes the random seed
# (any value works), and stratify=y keeps the class proportions identical in both splits
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=0, stratify=y)
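
To see what stratify=y actually does (a small check added here, not part of the original post), count the labels in each split; the 50/50/50 class balance is preserved in both parts:

# Verify the stratified split: each class keeps its proportion in both splits
print(np.bincount(y_train))  # e.g. [40 40 40] in the 120-sample training set
print(np.bincount(y_test))   # e.g. [10 10 10] in the 30-sample test set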

3. Build the model

# Logistic regression classification model
lr = LogisticRegression()
lr.fit(x_train,y_train)

Output:

LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
          intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,
          penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
          verbose=0, warm_start=False)
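
After fitting, the learned parameters can be inspected (an extra sketch, not shown in the original post). With 3 classes and 4 features, the model stores one row of coefficients and one intercept per class:

# Learned parameters of the fitted classifier
print(lr.coef_.shape)   # (3, 4): one weight vector per class
print(lr.intercept_)    # three intercepts, one per class
print(lr.classes_)      # [0 1 2]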

4. View the prediction results

result = lr.predict(x_test)
print('Predicted labels:', result)
print('Actual labels:', y_test)

Output:

Predicted labels: [0 1 0 2 0 1 2 0 0 1 2 1 1 2 1 2 2 1 1 0 0 2 2 1 0 1 1 2 0 0]
Actual labels: [0 1 0 2 0 1 2 0 0 1 2 1 1 2 1 2 2 2 0 1 1 2 0 0]
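
The two arrays differ in exactly one position. A confusion matrix from sklearn.metrics makes such mismatches easier to read; this check is an addition, not part of the original post:

# Compare predictions against the true labels
from sklearn.metrics import confusion_matrix
print(confusion_matrix(y_test, result))
# Rows are true classes, columns are predicted classes;
# off-diagonal entries are the misclassified samples.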

5. View the default parameters and score the model

# Default parameters
params = lr.get_params()
print(params)
# Model score: accuracy
s1 = lr.score(x_train, y_train)
s2 = lr.score(x_test, y_test)
print('Accuracy on the training set:', s1)
print('Accuracy on the test set:', s2)
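
For a classifier, score() is simply the accuracy, i.e. the fraction of correctly predicted samples. The same number can be reproduced by hand (an added check, not in the original):

# lr.score(x_test, y_test) is equivalent to plain accuracy
from sklearn.metrics import accuracy_score
print(accuracy_score(y_test, lr.predict(x_test)))  # same value as s2
print(np.mean(lr.predict(x_test) == y_test))       # manual computation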

6. Predicted probability of each class

# Predicted probability of each class
result = lr.predict_proba(x_test)

Each row contains three values, the probabilities of the three iris classes; the class with the largest probability is the predicted class.
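
Each row of predict_proba sums to 1, and taking the arg-max along the rows reproduces the output of predict(). A small sketch added here for illustration:

# Each row is a probability distribution over the 3 classes
print(result[:3])                 # first three rows
print(result.sum(axis=1)[:3])     # each row sums to 1
print(np.argmax(result, axis=1))  # index of the largest probability per row
print(lr.predict(x_test))         # matches the arg-max above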

7. Hyperparameter search

# Grid of candidate hyperparameters
from sklearn.model_selection import GridSearchCV
penaltys = ['l1', 'l2']  # l1 or l2 regularization
cs = [1.0, 1.1, 1.2, 1.3, 1.4, 1.5]
param_grid = {'penalty': penaltys, 'C': cs}
# liblinear supports both the l1 and l2 penalties
gsc = GridSearchCV(LogisticRegression(solver='liblinear'), param_grid)
gsc.fit(x_train, y_train)

print('Best cross-validation score:', gsc.best_score_)
print('Best parameters:')
best_params = gsc.best_estimator_.get_params()
print(best_params)
for param_name in sorted(param_grid.keys()):
    print(param_name, ':', best_params[param_name])

Output:

Best cross-validation score: 0.9583333333333334
Best parameters:
{'C': 1.5, 'class_weight': None, 'dual': False, 'fit_intercept': True, 'intercept_scaling': 1, 'max_iter': 100, 'multi_class': 'ovr', 'n_jobs': 1, 'penalty': 'l2', 'random_state': None, 'solver': 'liblinear', 'tol': 0.0001, 'verbose': 0, 'warm_start': False}
C : 1.5
penalty : l2
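
The best_score_ above is the mean cross-validation accuracy measured on the training data; the selected model can then be evaluated on the held-out test set (an extra step, not in the original post):

# Evaluate the refit best estimator on the held-out test set
best_lr = gsc.best_estimator_
print('Accuracy on the test set:', best_lr.score(x_test, y_test))
# gsc.cv_results_['mean_test_score'] lists the CV score of every (penalty, C) pair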


Reposted from blog.csdn.net/qq_33361080/article/details/82620208