opencv学习笔记六十一:Googlenet模型实现图像分类

DNN模块可以实现的功能:

图像分类

对象检测

实时对象检测

图像分割

预测(年龄,性别等)

视频对象跟踪

Googlenet 模型基于100万张图像实现1000个分类,首先下载模型bvlc_googlenet .caffemodel(二进制文件),模型下载地址为

http://dl.caffe.berkeleyvision.org/,然后找到描述文件bvlc_googlenet.prototxt和类别文件synset_words.txt,在opencv安装目录下E:\anzhuang\opencv3.4.1\opencv\sources\samples\data\dnn

方法步骤:

  1. 导入深度神经网络模块dnn(depth neural network)
  2. 读入类别文件
  3. 读入模型文件和描述文件
  4. 将图像转为网络输入的对象
  5. 前向传播得到分类结果
#include<opencv2\opencv.hpp>
using namespace cv;
using namespace dnn;
using namespace std;
String model_file = "bvlc_googlenet .caffemodel";
String model_txtfile = "bvlc_googlenet.prototxt";
String labels_file = "synset_words.txt";
vector<String>readLabels();

int main(int arc, char** argv) { 
	Mat src = imread("4.jpg");
	namedWindow("input", CV_WINDOW_AUTOSIZE);
	imshow("input", src);

	//读取模型的类别(文本)
	vector<String> labels = readLabels();

	//读取google_net的模型和描述文件
	Net net = readNetFromCaffe(model_txtfile,model_file);	
	if (net.empty()) {
		printf("read caffee model data failure\n");
		return -1;
	}
	//将图像转为google_net网络输入的对象,由描述文件可知,图像尺寸统一为224*224
	Mat inputBlob = blobFromImage(src, 1.0, Size(224, 224), Scalar(104, 117, 123));

	//进行前向传播,由描述文件可知,第一层用了10个卷积层,提取图像10种不同的特征
	Mat prob;
	for (int i = 0; i < 10; i++) {
		net.setInput(inputBlob, "data");
		prob = net.forward("prob");//最后一层的输出为“prob”
	}
	
	//输出
	//printf("m = %d,n = %d", prob.rows, prob.cols);//得到的概率值为1行1000列的
	Point classLoc;
	double classProb;
	minMaxLoc(prob, NULL, &classProb, NULL, &classLoc);
	printf("current image classification: %s,probablity %f\n", labels.at(classLoc.x).c_str(), classProb);
	putText(src, labels.at(classLoc.x), Point(20, 20), FONT_HERSHEY_COMPLEX, 1.0, Scalar(0, 0, 255), 2);
	imshow("output", src);
	waitKey(0);
	return 0;
}

//读取模型的类别(文本)
vector<String>readLabels() {
	vector<String>classNames;
	ifstream fp(labels_file);//打开文件
	if (!fp.is_open()) {//文件没打开
		printf("could not open the file ");
		exit(-1);
	}
	string name;
	while (!fp.eof()) {//文件没读到结尾
		getline(fp,name);//得到每一行,放到name中
		if (name.length()) {//非空行
			classNames.push_back(name.substr(name.find(' ') + 1));//
		}
	}
	fp.close();
	return classNames;
}

 

猜你喜欢

转载自blog.csdn.net/qq_24946843/article/details/82840791