一、libsvm下载和配置
下载网址:https://www.csie.ntu.edu.tw/~cjlin/libsvm/
把下载的压缩包解压到 Python 安装目录下的 Lib\site-packages 中(解压后的文件夹需命名为 libsvm,否则下面的 import 路径无法找到),然后用下面两行 import 语句测试是否安装成功:
from libsvm.python.svmutil import *
from libsvm.python.svm import *
二、MNIST数据集的导入
从网站上下载的MNIST数据集的格式和libsvm要求的格式不同,因此需要把格式转化成libsvm的格式。数据导入和转换的代码如下:
import os
import struct
import numpy as np
from libsvm.python.commonutil import svm_read_problem
def load_data(images_path="mnist/train-images.idx3-ubyte",
              labels_path="mnist/train-labels.idx1-ubyte"):
    """Load the MNIST training images and labels from raw IDX files.

    Args:
        images_path: Path to the IDX3 image file. Defaults to the training
            image file under ``mnist/``.
        labels_path: Path to the IDX1 label file. Defaults to the training
            label file under ``mnist/``.

    Returns:
        A tuple ``(images, labels)``. ``images`` is a uint8 array of shape
        ``(num_images, rows * cols)`` with one flattened image per row;
        ``labels`` is a 1-D uint8 array with one entry per image.

    Raises:
        ValueError: If a file does not carry the expected IDX magic number,
            or if the item counts in the headers and payloads disagree.
    """
    with open(labels_path, 'rb') as lbpath:
        # Label header: big-endian magic number followed by the item count.
        magic, n = struct.unpack('>II', lbpath.read(8))
        if magic != 2049:
            raise ValueError(
                "%s is not an IDX1 label file (magic=%d)" % (labels_path, magic))
        labels = np.fromfile(lbpath, dtype=np.uint8)

    with open(images_path, 'rb') as imgpath:
        # Image header: magic number, image count, rows and columns per image.
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        if magic != 2051:
            raise ValueError(
                "%s is not an IDX3 image file (magic=%d)" % (images_path, magic))
        pixels = np.fromfile(imgpath, dtype=np.uint8)

    # Catch truncated or mismatched files up front instead of failing later
    # with an opaque reshape error (or silently mis-aligning labels).
    if len(labels) != n or num != n or pixels.size != num * rows * cols:
        raise ValueError(
            "inconsistent MNIST files: %d labels (header says %d), "
            "%d images of %dx%d declared, %d pixel bytes found"
            % (len(labels), n, num, rows, cols, pixels.size))

    # Use the dimensions from the header rather than a hard-coded 784.
    images = pixels.reshape(num, rows * cols)
    return images, labels
def load_data_test(images_path="mnist/t10k-images.idx3-ubyte",
                   labels_path="mnist/t10k-labels.idx1-ubyte"):
    """Load the MNIST test images and labels from raw IDX files.

    Args:
        images_path: Path to the IDX3 image file. Defaults to the t10k
            image file under ``mnist/``.
        labels_path: Path to the IDX1 label file. Defaults to the t10k
            label file under ``mnist/``.

    Returns:
        A tuple ``(images, labels)``. ``images`` is a uint8 array of shape
        ``(num_images, rows * cols)`` with one flattened image per row;
        ``labels`` is a 1-D uint8 array with one entry per image.

    Raises:
        ValueError: If a file does not carry the expected IDX magic number,
            or if the item counts in the headers and payloads disagree.
    """
    with open(labels_path, 'rb') as lbpath:
        # Label header: big-endian magic number followed by the item count.
        magic, n = struct.unpack('>II', lbpath.read(8))
        if magic != 2049:
            raise ValueError(
                "%s is not an IDX1 label file (magic=%d)" % (labels_path, magic))
        labels = np.fromfile(lbpath, dtype=np.uint8)

    with open(images_path, 'rb') as imgpath:
        # Image header: magic number, image count, rows and columns per image.
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        if magic != 2051:
            raise ValueError(
                "%s is not an IDX3 image file (magic=%d)" % (images_path, magic))
        pixels = np.fromfile(imgpath, dtype=np.uint8)

    # Catch truncated or mismatched files up front instead of failing later
    # with an opaque reshape error (or silently mis-aligning labels).
    if len(labels) != n or num != n or pixels.size != num * rows * cols:
        raise ValueError(
            "inconsistent MNIST files: %d labels (header says %d), "
            "%d images of %dx%d declared, %d pixel bytes found"
            % (len(labels), n, num, rows, cols, pixels.size))

    # Use the dimensions from the header rather than a hard-coded 784.
    images = pixels.reshape(num, rows * cols)
    return images, labels
def data_to_libsvm(images, labels):
    """Return the training set in libsvm form, generating the cache if needed.

    If ``mnist/images_libsvm_form`` already exists it is read directly.
    Otherwise every sample is written as one line in libsvm sparse text
    format (``<label> <index>:<value> ...``): indices are 1-based, zero
    pixels are omitted, and each pixel value is divided by 255.

    Args:
        images: Sequence of samples, each a sequence of pixel values.
        labels: Sequence of labels, indexed in step with ``images``.

    Returns:
        Whatever ``svm_read_problem(path)`` returns for the cache file.
    """
    path = "mnist/images_libsvm_form"
    if os.path.isfile(path):
        print("libsvm格式的mnist训练集文件已生成,读取数据中。。。")
        return svm_read_problem(path)

    print("libsvm格式的训练集文件未生成,开始生成数据。")
    # Write to a temporary file first and rename it afterwards, so that an
    # interrupted run never leaves a truncated file that later runs would
    # mistake for a complete cache.
    tmp_path = path + ".tmp"
    try:
        with open(tmp_path, "w") as svmfile:
            for i, row in enumerate(images):
                fields = [str(labels[i]) + " "]
                fields.extend(
                    str(j + 1) + ":" + "%.7f" % (pixel / 255) + " "
                    for j, pixel in enumerate(row) if pixel != 0)
                fields.append("\n")
                # One write per sample instead of one per non-zero pixel.
                svmfile.write("".join(fields))
    except BaseException:
        # Do not leave a half-written temporary file behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    os.replace(tmp_path, path)
    return svm_read_problem(path)
def testdata_to_libsvm(images, labels):
    """Return the test set in libsvm form, generating the cache if needed.

    If ``mnist/test_images_libsvm_form`` already exists it is read directly.
    Otherwise every sample is written as one line in libsvm sparse text
    format (``<label> <index>:<value> ...``): indices are 1-based, zero
    pixels are omitted, and each pixel value is divided by 255.

    Args:
        images: Sequence of samples, each a sequence of pixel values.
        labels: Sequence of labels, indexed in step with ``images``.

    Returns:
        Whatever ``svm_read_problem(path)`` returns for the cache file.
    """
    path = "mnist/test_images_libsvm_form"
    if os.path.isfile(path):
        print("libsvm格式的mnist测试集文件已生成,读取数据中。。。")
        return svm_read_problem(path)

    print("libsvm格式的测试集文件未生成,开始生成数据。")
    # Write to a temporary file first and rename it afterwards, so that an
    # interrupted run never leaves a truncated file that later runs would
    # mistake for a complete cache.
    tmp_path = path + ".tmp"
    try:
        with open(tmp_path, "w") as svmfile:
            for i, row in enumerate(images):
                fields = [str(labels[i]) + " "]
                fields.extend(
                    str(j + 1) + ":" + "%.7f" % (pixel / 255) + " "
                    for j, pixel in enumerate(row) if pixel != 0)
                fields.append("\n")
                # One write per sample instead of one per non-zero pixel.
                svmfile.write("".join(fields))
    except BaseException:
        # Do not leave a half-written temporary file behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    os.replace(tmp_path, path)
    return svm_read_problem(path)
三、训练数据集
svm类:svm_class.py
from libsvm.python.svmutil import *
from libsvm.python.svm import *
class Svm:
    """Bundle a libsvm-format training set with a test set for one experiment."""

    def __init__(self, svm_label, svm_images, svm_test_label, svm_test_images):
        """Keep references to the training and test labels/features."""
        self.svm_label, self.svm_images = svm_label, svm_images
        self.svm_test_label = svm_test_label
        self.svm_test_images = svm_test_images

    def train(self, numToClassfy, numToTrain, args):
        """Train on a prefix of the training set, then predict a test prefix.

        Args:
            numToClassfy: Number of leading test samples to classify.
            numToTrain: Number of leading training samples to fit on.
            args: Option string handed unchanged to ``svm_train``.

        Returns:
            The first element of the ``svm_predict`` result (the predicted
            labels); the other two elements are discarded.
        """
        train_y = self.svm_label[:numToTrain]
        train_x = self.svm_images[:numToTrain]
        model = svm_train(train_y, train_x, args)

        test_y = self.svm_test_label[:numToClassfy]
        test_x = self.svm_test_images[:numToClassfy]
        predicted, _accuracy, _values = svm_predict(test_y, test_x, model)
        return predicted
测试代码:
# Number of training samples to fit on / test samples to classify.
numToTrain = 60000
numToClassfy = 10000

print("开始读取MNIST数据:")
images, label = load_data()
test_images, test_label = load_data_test()

# Convert to (or load the cached) libsvm-format data.
svm_label, svm_images = data_to_libsvm(images, label)
svm_test_label, svm_test_images = testdata_to_libsvm(test_images, test_label)

# Train and evaluate one model per kernel setting. "-q" silences libsvm's
# output and "-m 1000" sets the kernel cache size in MB. The first run uses
# libsvm's default kernel (RBF); "-t 3" in the second selects the sigmoid
# kernel.
for banner, svm_args in (
        ("SVM 1训练中...", "-q -m 1000"),
        ("SVM 2训练中...", "-q -m 1000 -t 3"),
):
    print(banner)
    svm = Svm(svm_label, svm_images, svm_test_label, svm_test_images)
    svm.train(numToClassfy, numToTrain, svm_args)