使用sklearn.linear_model.SGDClassifier增量模型进行学习的记录

数据集下载链接是Human Activity Recognition Using Smartphones

train、test文件夹中分别包含训练和测试的文件,这里使用train中的数据进行增量学习模型,test中的数据用来测试
首先读取数据:

import numpy as np
from sklearn.linear_model import SGDClassifier
import matplotlib.pyplot as plt

plt.rcParams['font.sans-serif'] = ['SimHei']  # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False  # 解决保存图像是负号'-'显示为方块的问题

x_train=[]
f=open(r'./X_train.txt','r')
for i in f.readlines():
    i=i.strip()
    temp=i.split(' ')
    while '' in temp:
        temp.remove('')
    x_train.append(temp)
f.close()
x_train=np.array(x_train)

x_test=[]
f=open(r'./X_test.txt','r')
for i in f.readlines():
    i=i.strip()
    temp=i.split(' ')
    while '' in temp:
        temp.remove('')
    x_test.append(temp)
f.close()
x_test=np.array(x_test)

y_train=[]
f = open(r'./y_train.txt', 'r')
for i in f.readlines():
    i = i.strip()
    y_train.append(i)
f.close()

y_test=[]
f = open(r'./y_test.txt', 'r')
for i in f.readlines():
    i = i.strip()
    y_test.append(i)
f.close()

print(x_train.shape,end=' ')
print(x_test.shape,end=' ')
print(len(y_train),len(y_test),set(y_train+y_test))

输出结果为:

(7352, 561) (2947, 561) 7352 2947 {'4', '5', '3', '1', '2', '6'}

开始增量训练:

x_train=x_train.astype(np.float32)
x_test=x_test.astype(np.float32)

classes=np.unique(y_train+y_test)

interval=100
start=0
sgd_clf = SGDClassifier()
x_axis=[]
y_axis=[]
for i in np.arange(1,(x_train.shape[0]//interval+1),1):
    end=min([i*interval,x_train.shape[0]])
    X=x_train[start:end]
    Y=y_train[start:end]
    sgd_clf.partial_fit(X,Y,classes=classes) #
    start=end
    # print("{} time".format(i))  # 当前次数
    score=sgd_clf.score(x_test, y_test)
    # print("{} score".format(score))  # 在测试集上看效果
    x_axis.append(i)
    y_axis.append(score)

0.3505259586019681 score
0.49338310145911096 score
0.4832032575500509 score
0.497794367153037 score
0.46623685103495077 score
0.4696301323379708 score
...

绘制迭代次数-score图:

plt.figure()
plt.plot(x_axis,y_axis)
plt.xlabel('迭代的次数')
plt.ylabel('score')
plt.tight_layout()
plt.savefig('./score.png',bbox_inches='tight')
plt.show()


参考内容:
使用sklearn进行增量学习
http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDClassifier.html

猜你喜欢

转载自blog.csdn.net/shiheyingzhe/article/details/82316616
今日推荐