SVM(支持向量机)是一种分类器,下面通过手写数字(0-9)识别的例子对其进行介绍。
1.首先准备训练使用的手写字体
如图所示,将手写字体图片按数字分类,分别放在不同的文件夹中。
2.读取图片
//每种数字个数
const int count[10] = {5923,6742,5958,6131,5842,5421,5918,6265,5851,5949};
string filename = "shouxieziti/";
vector<Mat> imgin;
vector<int> number;
int sum = 0;
for(int i = 0; i < 10; i++){
string s;
stringstream ss;
ss<<i;
ss>>s;
for(int j = 1; j < count[i]+1; j++){
string s1;
stringstream ss1;
ss1<<j;
ss1>>s1;
if(j<10){
s1 = s+"_0000"+s1;
}else if(j < 100){
s1 = s+"_000"+s1;
}else if(j < 1000){
s1 = s+"_00"+s1;
}else{
s1 = s+"_0"+s1;
}
string in = filename + s + "/" + s1 +".jpg";
Mat img = imread(in,IMREAD_GRAYSCALE);
// imshow(in,img);
imgin.push_back(img);
number.push_back(i);
cout<<in<<" ok"<<" "<<img.channels()<<" "<<number[sum + j - 1]<<" "<<sum<<endl;
}
sum += count[i];
}
cout<<imgin.size()<<" "<<imgin[0].size()<<"have been read"<<endl;
图片的读取方式需要根据自己的存储方式(目录结构和文件命名规则)进行相应调整。
3.生成opencv中SVM需要的形式
Mat imgtrain((int)imgin.size(), 28*28, CV_32FC1);
Mat imglabel((int)imgin.size(), 1, CV_32SC1);
// cout<<imgtrain.channels()<<" "<<imglabel.channels()<<endl;
cout<<"creat train data..."<<endl;
for(int i = 0; i < (int)imgin.size(); i++){
Mat_<float>::iterator trainbegin = imgtrain.begin<float>() + 28*28*i;
Mat_<int>::iterator labelbegin = imglabel.begin<int>();
Mat_<uchar>::iterator inbegin = imgin[i].begin<uchar>();
for(int j = 0; j < 28*28; j++){
float data = (float)*(inbegin+j);
*(trainbegin+j) = (data+0.0)/255.0;
// if(data > 200){
// cout<<*(trainbegin+j)<<" "<<*(labelbegin+j);
// }
}
*(labelbegin+i) = number[i];
cout<<*(labelbegin+i)<<" ";
}
其中训练数据是CV_32FC1类型;label数据是CV_32SC1类型。
另外,需要对数据进行归一化:读取的是灰度图,像素值在0-255范围之内,所以将每个数据除以255就可以得到0-1之间的数据。
4.利用SVM进行训练
//设置SVM参数
Ptr<ml::SVM> svm = ml::SVM::create();
svm->setType(ml::SVM::C_SVC);
svm->setKernel(ml::SVM::RBF);
svm->setGamma(0.01);
svm->setC(10.0);
svm->setTermCriteria(TermCriteria(CV_TERMCRIT_ITER, 1000,FLT_EPSILON));
//进行训练
cout<<"trainning..."<<endl;
bool f = svm->train(imgtrain,ml::ROW_SAMPLE,imglabel);
// Ptr<ml::TrainData> traindata = ml::TrainData::create(imgtrain,ml::ROW_SAMPLE,imglabel);
// bool f = svm->trainAuto(traindata, 10);
// cout<<f<<endl;
//保存训练好的数据
cout<<"saving..."<<endl;
svm->save("train1.xml");
cout<<"save done..."<<endl;
5.读取生成的train1.xml进行预测
Ptr<ml::SVM> svm = ml::StatModel::load<ml::SVM>("train1.xml");
cout<<"predicting..."<<endl;
vector<float> result;
int right = 0, wrong = 0;
Mat_<int>::iterator labelbegin = imglabel.begin<int>();
for(int i = 0; i < (int)imgtrain.rows; i++){
Mat sample = imgtrain.row(i);
result.push_back(svm->predict(sample));
cout<<result[i]<<endl;
if(abs(result[i] - *(labelbegin+i)) < 0.001){
right++;
}else{
wrong++;
}
}
cout<<"predict done... "<<right<<" right "<<wrong<<" wrong"<<endl;
cout<<"right rate "<<(float)right/(float)(right+wrong)<<endl;
cout<<"wrong rate "<<(float)wrong/(float)(right+wrong)<<endl;
6.通过训练60000个样本,能实现非常高的正确率。下图是识别了10000个测试数据的结果
7.补充
学习过程中主要参考了如下链接:
https://www.cnblogs.com/cheermyang/p/5624333.html
手写字体是由mnist手写字体图像数据库生成的,参考下列链接:
http://m.blog.csdn.net/fengbingchun/article/details/49611549