版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
from sklearn.datasets import load_digits
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC
digits=load_digits()
x_train=digits.data[:1348]#训练图片
y_train=digits.target[:1348]#训练图片对应的数字
x_test=digits.data[1348:]#测试图片,不包含起始图片
y_test=digits.target[1348:]#测试图片对应的数字
ss=StandardScaler()#标准化数据
x_train=ss.fit_transform(x_train)
x_test=ss.transform(x_test)
lsvc=LinearSVC()
lsvc.fit(x_train,y_train)#训练
y_predict=lsvc.predict(x_test)#预测
while 1:
test=int(input("输入预测第几张图片;"))
test_1=test+1348
print("对{}图片进行训练,预测第{}图片,预测数字为{}".format(len(x_train),test,y_predict[test]))
if y_predict[test]==y_test[test]:
print("预测正确")
else:
print('错误')
plt.imshow(digits.images[test_1],cmap=plt.cm.gray_r,interpolation='nearest')
plt.show()