【libsvm】Python下使用libsvm识别数据集mnist

一、libsvm下载和配置

下载网址:https://www.csie.ntu.edu.tw/~cjlin/libsvm/

把下载的压缩包解压到Python安装目录下的Lib\site-packages中,并把解压出的文件夹命名为libsvm(与下面import语句中的包名保持一致),然后用下面的import语句测试一下是否安装成功:

from libsvm.python.svmutil import *

from libsvm.python.svm import *

二、MNIST数据集的导入

  从网站上下载的MNIST数据集是IDX二进制格式(.idx3-ubyte / .idx1-ubyte),而libsvm要求的是每行形如“标签 特征序号:特征值 ...”的稀疏文本格式,两者不同,因此需要先把数据转化成libsvm的格式。数据导入和转换的代码如下:

import os
import struct
import numpy as np
from libsvm.python.commonutil import svm_read_problem

# Magic numbers that open MNIST IDX files (big-endian uint32 at offset 0).
_IDX_LABEL_MAGIC = 2049
_IDX_IMAGE_MAGIC = 2051


def _load_idx_pair(images_path, labels_path):
    """Read one MNIST image/label file pair stored in IDX format.

    The image dimensions and sample count are taken from the file headers
    instead of being hard-coded, so any ``rows x cols`` IDX image set works.

    Args:
        images_path: Path to an ``idx3-ubyte`` image file.
        labels_path: Path to the matching ``idx1-ubyte`` label file.

    Returns:
        ``(images, labels)`` where ``images`` is a ``uint8`` array of shape
        ``(num, rows * cols)`` and ``labels`` is a ``uint8`` array of length
        ``num``.

    Raises:
        ValueError: If a magic number is wrong, a file is shorter than its
            header announces, or the two files disagree on the sample count.
    """
    with open(labels_path, 'rb') as lbpath:
        magic, n = struct.unpack('>II', lbpath.read(8))
        if magic != _IDX_LABEL_MAGIC:
            raise ValueError(
                "%s is not an IDX label file (magic %d)" % (labels_path, magic))
        # Read exactly the number of labels the header announces.
        labels = np.fromfile(lbpath, dtype=np.uint8, count=n)
    if len(labels) != n:
        raise ValueError(
            "%s is truncated: expected %d labels, found %d"
            % (labels_path, n, len(labels)))

    with open(images_path, 'rb') as imgpath:
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        if magic != _IDX_IMAGE_MAGIC:
            raise ValueError(
                "%s is not an IDX image file (magic %d)" % (images_path, magic))
        if num != n:
            raise ValueError(
                "image/label count mismatch: %d images vs %d labels" % (num, n))
        expected = num * rows * cols
        pixels = np.fromfile(imgpath, dtype=np.uint8, count=expected)
    if pixels.size != expected:
        raise ValueError(
            "%s is truncated: expected %d pixel bytes, found %d"
            % (images_path, expected, pixels.size))
    # One flattened image per row.
    return pixels.reshape(num, rows * cols), labels


def load_data(images_path="mnist/train-images.idx3-ubyte",
              labels_path="mnist/train-labels.idx1-ubyte"):
    """Load the MNIST training set.

    Called without arguments it reads the same files as before; the paths
    can be overridden to load the data from another location.

    Returns:
        ``(images, labels)`` as described in ``_load_idx_pair``.

    Raises:
        ValueError: If the files are not a consistent IDX image/label pair.
    """
    return _load_idx_pair(images_path, labels_path)


def load_data_test(images_path="mnist/t10k-images.idx3-ubyte",
                   labels_path="mnist/t10k-labels.idx1-ubyte"):
    """Load the MNIST test set.

    Called without arguments it reads the same files as before; the paths
    can be overridden to load the data from another location.

    Returns:
        ``(images, labels)`` as described in ``_load_idx_pair``.

    Raises:
        ValueError: If the files are not a consistent IDX image/label pair.
    """
    return _load_idx_pair(images_path, labels_path)



def _write_libsvm_file(images, labels, path):
    """Write ``images``/``labels`` to ``path`` in libsvm sparse text format.

    Each output line is ``<label> <index>:<value> <index>:<value> ...`` with
    1-based feature indices, zero-valued pixels omitted and every value
    divided by 255 and printed with seven decimals.

    The data is written to ``path + ".tmp"`` first and moved into place only
    once complete, so an interrupted run never leaves a truncated file behind
    that a later run would mistake for a finished cache.

    Assumes every element of ``images`` is a 1-D sequence of pixel values.
    """
    tmp_path = path + ".tmp"
    try:
        with open(tmp_path, "w") as svmfile:
            for i, row in enumerate(images):
                pixels = np.asarray(row)
                nonzero = np.flatnonzero(pixels)
                scaled = pixels[nonzero] / 255
                # Build the whole line in memory: one write per sample
                # instead of one write per pixel.
                features = "".join(
                    "%d:%.7f " % (j + 1, value)
                    for j, value in zip(nonzero.tolist(), scaled.tolist()))
                svmfile.write(str(labels[i]) + " " + features + "\n")
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort cleanup of the partial temp file; the original error
        # is what the caller needs to see.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def _load_or_build_libsvm(images, labels, path, found_msg, missing_msg):
    """Return ``svm_read_problem(path)``, generating ``path`` first if needed.

    Only the existence probe is guarded: an error raised while *reading* an
    existing file propagates instead of silently triggering a regeneration.
    """
    try:
        with open(path):
            pass
    except IOError:
        print(missing_msg)
        _write_libsvm_file(images, labels, path)
    else:
        print(found_msg)
    return svm_read_problem(path)


def data_to_libsvm(images, labels):
    """Load the training set in libsvm form, converting it on first use.

    If ``mnist/images_libsvm_form`` cannot be opened it is generated from
    ``images`` and ``labels``; the file is then parsed with
    ``svm_read_problem``.

    Args:
        images: Sequence of flattened images (one row of pixel values each).
        labels: Sequence of labels, indexed in step with ``images``.

    Returns:
        Whatever ``svm_read_problem`` returns for the file; callers in this
        file unpack it as ``(labels, features)``.
    """
    return _load_or_build_libsvm(
        images, labels, "mnist/images_libsvm_form",
        "libsvm格式的mnist训练集文件已生成,读取数据中。。。",
        "libsvm格式的训练集文件未生成,开始生成数据。")


def testdata_to_libsvm(images, labels):
    """Load the test set in libsvm form, converting it on first use.

    If ``mnist/test_images_libsvm_form`` cannot be opened it is generated
    from ``images`` and ``labels``; the file is then parsed with
    ``svm_read_problem``.

    Args:
        images: Sequence of flattened images (one row of pixel values each).
        labels: Sequence of labels, indexed in step with ``images``.

    Returns:
        Whatever ``svm_read_problem`` returns for the file; callers in this
        file unpack it as ``(labels, features)``.
    """
    return _load_or_build_libsvm(
        images, labels, "mnist/test_images_libsvm_form",
        "libsvm格式的mnist测试集文件已生成,读取数据中。。。",
        "libsvm格式的测试集文件未生成,开始生成数据。")

三、训练数据集

svm类:svm_class.py

from libsvm.python.svmutil import *
from libsvm.python.svm import *


class Svm:
    """Bundle a libsvm training set and test set for train-then-predict runs."""

    def __init__(self, svm_label, svm_images, svm_test_label, svm_test_images):
        """Store the training and test data exactly as given.

        Args:
            svm_label: Training labels.
            svm_images: Training samples, aligned with ``svm_label``.
            svm_test_label: Test labels.
            svm_test_images: Test samples, aligned with ``svm_test_label``.
        """
        self.svm_label, self.svm_images = svm_label, svm_images
        self.svm_test_label, self.svm_test_images = (
            svm_test_label, svm_test_images)

    def train(self, numToClassfy, numToTrain, args):
        """Train on a prefix of the training data and predict a test prefix.

        Args:
            numToClassfy: How many leading test samples to classify.
            numToTrain: How many leading training samples to train on.
            args: Option string forwarded unchanged to ``svm_train``.

        Returns:
            The first element of the ``svm_predict`` result (the predicted
            labels); the other two elements are discarded.
        """
        train_y = self.svm_label[:numToTrain]
        train_x = self.svm_images[:numToTrain]
        model = svm_train(train_y, train_x, args)

        eval_y = self.svm_test_label[:numToClassfy]
        eval_x = self.svm_test_images[:numToClassfy]
        predicted, _accuracy, _values = svm_predict(eval_y, eval_x, model)
        return predicted

测试代码:


numToTrain = 60000    # number of training samples to use
numToClassfy = 10000  # number of test samples to classify

print("开始读取MNIST数据:")
images, label = load_data()
test_images, test_label = load_data_test()

# Load the same data in libsvm format (converted and cached on first run).
svm_label, svm_images = data_to_libsvm(images, label)
svm_test_label, svm_test_images = testdata_to_libsvm(test_images, test_label)

# Train and evaluate one model per kernel configuration. Each entry pairs the
# progress banner with the option string handed to svm_train: the first run
# uses libsvm's default kernel, the second selects kernel type 3 via "-t 3".
# NOTE(review): per the libsvm option list the default should be RBF and
# type 3 sigmoid -- confirm against the installed libsvm version.
for banner, train_args in (
        ("SVM 1训练中...", "-q -m 1000"),
        ("SVM 2训练中...", "-q -m 1000 -t 3"),
):
    print(banner)
    svm = Svm(svm_label, svm_images, svm_test_label, svm_test_images)
    svm.train(numToClassfy, numToTrain, train_args)

完整代码:https://github.com/Dakod/Mnist_Learn

猜你喜欢

转载自blog.csdn.net/darord/article/details/88795471
今日推荐