python使用libsvm

一、libsvm的安装与测试

1. 安装libsvm

官方网站:[LIBSVM -- A Library for Support Vector Machines](https://www.csie.ntu.edu.tw/~cjlin/libsvm/)(ntu.edu.tw)

pip install -U libsvm-official

2. 测试libsvm 

测试libsvm是否可以使用:

注意:通过 pip 安装后,svmutil 位于 libsvm 包内,因此导入时不能直接写

from svmutil import *

而要写

from libsvm.svmutil import *

代码如下:

#! /usr/bin/env python3.10
# encoding: utf-8
# Smoke test for libsvm: train on the first 200 samples of the bundled
# heart_scale data set and report accuracy on the remaining samples.
import sys

# Only needed when using the unpacked libsvm source tree instead of the pip
# package (adjust the libsvm-3.31/python path to your machine). The directory
# must be on sys.path *before* the import below; appending it after the
# import has no effect on that import.
path = 'C:/Users/30296/Documents/libsvm-3.31/python'
sys.path.append(path)

from libsvm.svmutil import *

train_label, train_pixel = svm_read_problem('C:/Users/30296/Documents/libsvm-3.31/heart_scale')
# Train on the first 200 samples with cost parameter C = 4 ('-c 4').
model = svm_train(train_label[:200], train_pixel[:200], '-c 4')
print("result:")
# Evaluate on the held-out samples. p_acc is the tuple
# (accuracy, mean squared error, squared correlation coefficient).
p_label, p_acc, p_val = svm_predict(train_label[200:], train_pixel[200:], model)
print(p_acc)

得到如下结果(`svm_predict` 返回的 `p_acc` 依次为准确率、均方误差和平方相关系数):

*.*
optimization finished, #iter = 257
nu = 0.351161
obj = -225.628984, rho = 0.636110
nSV = 91, nBSV = 49
Total nSV = 91
result:
Accuracy = 84.2857% (59/70) (classification)
(84.28571428571429, 0.6285714285714286, 0.463744141163496)

二、使用libsvm

1. 数据预处理:用 Python 转换 txt 数据文件

libsvm 要求的数据格式如下(每行一个样本:先写标签,再写若干"特征序号:特征值",序号从 1 开始):

label  1:value1  2:value2  3:value3  …
 1     1:1       2:2       3:3
-1     1:1       2:2       3:3

而我们手头的数据一般是每行若干特征值、最后一列为标签(注意:下面的转换代码按逗号分隔读取各列):

value1  value2  value3  …  label
1       2       3          1
1       2       3          -1

因此需要先把数据转换成 libsvm 格式:

"""
original data form:
value1 value2 value3 value4 ......label
target data form:
label index1:value1 index2:value2 ......
 
"""
import numpy as np

def readFromTxt(filename):
    """Load a comma-separated numeric text file and interleave LIBSVM indices.

    Each input row is ``value1,value2,...,valueN,label``: the last column is
    the label. Each output row is laid out as
    ``[label, 1, value1, 2, value2, ..., N, valueN]``, so a writer can emit
    ``label 1:value1 2:value2 ...`` directly.

    Parameters
    ----------
    filename : str
        Path to the comma-delimited text file.

    Returns
    -------
    tuple
        ``(new_data, lines, 2*columns-1)``. ``new_data`` is an object-dtype
        ndarray of shape ``(lines, 2*columns-1)``. ``lines`` is the number of
        rows. The last item is the number of entries per output row.
    """
    # ndmin=2 (NumPy >= 1.23) keeps a single-line file 2-D; without it
    # genfromtxt returns a 1-D array and ``shape[1]`` raises IndexError.
    o_data = np.genfromtxt(filename, delimiter=',', ndmin=2)
    lines, columns = o_data.shape
    print(lines, columns)
    width = 2 * columns - 1
    new_data = np.zeros([lines, width], dtype=object)
    for i, row in enumerate(o_data):
        new_data[i][0] = row[columns - 1]  # label is the last input column
        for j in range(1, columns):
            new_data[i][2 * j - 1] = j  # 1-based LIBSVM feature index
            new_data[i][2 * j] = row[j - 1]  # corresponding feature value
    return new_data, lines, width
 
# Replace these with your own file names.
filename = "collision_results_-1.900000.txt"
newfilename = "new-collision_results_-1.900000.txt"


new_data, newlines, newcolumns = readFromTxt(filename)

# Each row of new_data is [label, 1, v1, 2, v2, ...]; write it as the
# LIBSVM text line "label 1:v1 2:v2 ...". The context manager guarantees
# the file is closed even if a row fails to convert.
with open(newfilename, "w", encoding="utf-8") as f:
    for row in new_data[:newlines]:
        # if classification --> label shall be int
        label = str(int(row[0]))
        # if regression --> label shall be real
        # label = str(row[0])
        pairs = [f"{row[j]}:{row[j + 1]}" for j in range(1, newcolumns, 2)]
        f.write(" ".join([label, *pairs]) + "\n")

2. 学习训练

#! /usr/bin/env python3.10
# encoding: utf-8
import sys
from libsvm.svmutil import *
import numpy as np

# When libsvm is used from its source tree rather than the pip package, add
# 'C:/Users/30296/Documents/libsvm-3.31/python' to sys.path before the
# libsvm import above.

train_label, train_pixel = svm_read_problem(
    'C:/Users/30296/Documents/libsvm-3.31/heart_scale'
)

# Split point: the first 80% of the samples train the model, the rest test it.
per_80 = int(np.floor(0.8 * len(train_label)))
print(per_80)
print(train_label[:per_80])

fit_y, fit_x = train_label[:per_80], train_pixel[:per_80]
eval_y, eval_x = train_label[per_80:], train_pixel[per_80:]

model = svm_train(fit_y, fit_x, '-c 4')
print("result:")
p_label, p_acc, p_val = svm_predict(eval_y, eval_x, model)

print(p_label, '下一个', p_acc, '下一个', p_val)

得到输出:

216
[1.0, -1.0, 1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0, 1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, 1.0, -1.0, -1.0, -1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, -1.0, 1.0, -1.0, 1.0, 1.0, 1.0, -1.0, 1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, -1.0, -1.0, 1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, 1.0, -1.0, -1.0, 1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, -1.0, 1.0, -1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0, -1.0, 1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0, -1.0, 1.0, -1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, 1.0, -1.0, 1.0, -1.0, -1.0, 1.0, -1.0, -1.0]
*.*
optimization finished, #iter = 236
nu = 0.350163
obj = -245.091761, rho = 0.725105
nSV = 96, nBSV = 57
Total nSV = 96
result:
Accuracy = 83.3333% (45/54) (classification)
[-1.0, 1.0, 1.0, -1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0, 1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0] 下一个 (83.33333333333334, 0.6666666666666666, 0.4376577840112202) 下一个 [[-1.6124757728862074], [0.0886991748765581], [0.4663204240189208], [-1.809307063660131], [1.789713658958068], [1.4263019450015681], [-1.3657682586413893], [1.9107841719202647], [-1.396542713524683], [-0.1876978331664505], [0.15542694054439632], [0.7089826475868414], [-1.4615037457023727], [-1.130433148293036], [0.3205752603012352], [0.415261195248798], [-0.976906868341664], [0.9029986716445745], [1.956591639825957], [0.7847881477664195], [-1.1472189675661113], [1.076648795606011], [-0.8582718654160085], [-1.8068212282248122], [0.6840602464863493], [-1.101835886542706], [-1.2402235520592604], [-0.5678741375735478], [-0.8835802514321], [1.0834647106286988], [1.5292489100848137], [-1.4540902207078812], [-1.313175746493155], [2.6895483406734924], [1.1512582225370713], [-1.639392753733577], [0.11443306053238411], [-1.0278685255180595], [-1.2251953985069255], [-1.033632513651953], [0.4120354239632241], [1.9451408840796356], [-1.315608704047866], [-0.9240953214882549], [-2.108813048050079], [2.5980107155518404], [-1.732737578667451], [-1.71260757809163], [-0.5738742753188778], [-1.0641364858922406], [-0.8518389292541279], [-1.8430304080540507], [-0.001770943242147327], [2.450684766756411]]

猜你喜欢

转载自blog.csdn.net/weixin_45226065/article/details/131561808