机器学习实战——基于LightGBM的手写数字识别(附完整代码和结果图)
关于作者
作者:小白熊
作者简介:精通python、matlab、c#语言,擅长机器学习,深度学习,机器视觉,目标检测,图像分类,姿态识别,语义分割,路径规划,智能优化算法,数据分析,各类创新融合等等。
联系邮箱:[email protected]
科研辅导、知识付费答疑、个性化定制以及其他合作需求请联系作者~
1 引言
在本文中,我们将介绍如何利用LightGBM(Light Gradient Boosting Machine)进行手写数字识别任务。我们将使用scikit-learn
中的手写数字数据集,经过数据预处理、模型训练、评估和可视化,完整地展示整个流程。
2 数据集介绍
手写数字数据集(Digits Dataset)是一个经典的机器学习数据集,广泛用于图像分类和模式识别的研究。该数据集包含1797张8x8像素的灰度图像,表示数字0到9。每个图像都对应一个目标标签,表示图像中所包含的数字。数据集特点:
- 特征数量: 每张图像由64个像素点(8x8)组成,因此每个样本有64个特征(像素值)。
- 目标变量: 目标变量为0到9的整数,代表数字的类别。数据集中有10个不同的类别。
- 样本数量: 数据集总共有1797个样本,适合用于小规模实验和学习。
该数据集常用于机器学习模型的训练和测试,特别是在手写数字识别、图像处理和深度学习领域。
3 模型理论
LightGBM是由微软开发的高效、快速的梯度提升框架。与其他传统的梯度提升算法相比,LightGBM在处理大规模数据集和高维数据时具有更好的性能。LightGBM的主要特点:
- 高效性: 采用基于直方图的算法,将连续特征离散化,以减少内存占用和加速计算。
- 高并行性: 支持特征并行和数据并行,使其在多核处理器上表现优秀。
- 自动处理缺失值: LightGBM能够自动处理缺失值,而无需用户额外预处理。
- 支持类别特征: 可以直接处理类别特征,无需进行独热编码,简化了数据预处理的复杂性。
4 代码流程
4.1 导入库
import numpy as np
import pandas as pd
import lightgbm as lgb
from sklearn.model_selection import train_test_split, cross_val_score, KFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import seaborn as sns
from sklearn.datasets import load_digits
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
# 设置中文字体为SimHei
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
4.2 加载数据集
我们使用load_digits
函数加载手写数字数据集,并将其转换为Pandas的DataFrame
,以便于处理和分析。
# 加载数据集
data = load_digits()
# 数据转换
df = pd.DataFrame(data.data, columns=data.feature_names)
df['Target'] = data.target
# 检查数据的缺失情况
missing_values = df.isnull().sum()
print("缺失值检测结果:\n", missing_values)
4.3 数据预处理
我们将数据分为特征和目标变量,然后使用StandardScaler
对特征进行标准化处理,以确保每个特征都在相同的尺度范围内。
# 数据提取
X = df.drop('Target', axis=1)
y = df['Target']
# 数据标准化
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
4.4 划分训练集和验证集
接下来,我们将数据集划分为训练集和验证集,比例为80%和20%。这有助于评估模型在未见数据上的表现。
# 划分训练集和验证集
X_train, X_val, y_train, y_val = train_test_split(X_scaled, y, test_size=0.2, random_state=42)
4.5 构建模型
我们使用LightGBM构建分类模型,并通过十折交叉验证评估模型的表现。
# 构建模型
model = lgb.LGBMClassifier(verbosity=-1)
# 十折交叉验证
kf = KFold(n_splits=10, shuffle=True, random_state=42)
cv_scores = cross_val_score(model, X_train, y_train, cv=kf, scoring='accuracy')
print(f"平均准确率: {
np.mean(cv_scores)}")
4.6 模型训练和预测
我们使用训练集训练模型,然后在验证集上进行预测。
# 训练模型
model.fit(X_train, y_train)
# 模型预测
y_val_pred = model.predict(X_val)
4.7 模型评估
在模型预测完成后,我们计算准确率、精确率、召回率、F1分数等评估指标,以全面了解模型性能。
# 计算混淆矩阵
conf_matrix = confusion_matrix(y_val, y_val_pred)
# 计算分类评估指标
accuracy_val = accuracy_score(y_val, y_val_pred)
precision_val = precision_score(y_val, y_val_pred, average='weighted')
recall_val = recall_score(y_val, y_val_pred, average='weighted')
f1_val = f1_score(y_val, y_val_pred, average='weighted')
print(f"准确率: {
accuracy_val}")
print(f"精确率: {
precision_val}")
print(f"召回率: {
recall_val}")
print(f"F1分数: {
f1_val}")
4.8 可视化混淆矩阵
最后,我们通过混淆矩阵可视化模型的分类效果,直观展示模型在每个类别上的表现。
# 可视化混淆矩阵
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=np.arange(10), yticklabels=np.arange(10))
plt.title(f'混淆矩阵')
plt.xlabel("预测值")
plt.ylabel("真实值")
plt.show()
5 完整代码
import numpy as np
import pandas as pd
import lightgbm as lgb
from sklearn.model_selection import train_test_split, cross_val_score, KFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import seaborn as sns
from sklearn.datasets import load_digits
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
# 设置中文字体为SimHei
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# 加载数据集
data = load_digits()
# 数据转换
df = pd.DataFrame(data.data, columns=data.feature_names)
df['Target'] = data.target
# 缺失值检测
missing_values = df.isnull().sum()
print("缺失值检测结果:\n", missing_values)
# 数据提取
X = df.drop('Target', axis=1)
y = df['Target']
# 数据标准化
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
# 划分训练集和验证集
X_train, X_val, y_train, y_val = train_test_split(X_scaled, y, test_size=0.2, random_state=42)
# 构建模型
model = lgb.LGBMClassifier(verbosity=-1)
# 十折交叉验证
kf = KFold(n_splits=10, shuffle=True, random_state=42)
cv_scores = cross_val_score(model, X_train, y_train, cv=kf, scoring='accuracy')
print(f"平均准确率: {
np.mean(cv_scores)}")
# 训练模型
model.fit(X_train, y_train)
# 模型预测
y_val_pred = model.predict(X_val)
# 计算混淆矩阵
conf_matrix = confusion_matrix(y_val, y_val_pred)
# 计算分类评估指标
accuracy_val = accuracy_score(y_val, y_val_pred)
precision_val = precision_score(y_val, y_val_pred, average='weighted')
recall_val = recall_score(y_val, y_val_pred, average='weighted')
f1_val = f1_score(y_val, y_val_pred, average='weighted')
print(f"准确率: {
accuracy_val}")
print(f"精确率: {
precision_val}")
print(f"召回率: {
recall_val}")
print(f"F1分数: {
f1_val}")
# 可视化混淆矩阵
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=data.target_names, yticklabels=data.target_names)
plt.title(f'混淆矩阵')
plt.xlabel("预测值")
plt.ylabel("真实值")
plt.show()
小结
本文展示了如何使用LightGBM进行手写数字识别的完整流程,包括数据加载、预处理、模型构建、训练、评估和可视化。LightGBM作为一种高效的机器学习框架,在处理大规模数据时表现出色,适合用于各类分类和回归问题。