CNN Handwritten Digit Recognition
1. Introduction
- Handwriting recognition means analyzing an image to identify the handwritten characters it contains.
- This project focuses on handwritten digit recognition and uses a CNN as the model.
1.1 Steps
- Dataset: the MNIST handwritten digit dataset, which contains 60,000 training images and 10,000 test images. Each image is 28×28 pixels, and there are 10 classes (see the sketch after this list).
- Framework: the model is implemented, trained, and tested in Python with PyTorch.
- Recognition accuracy: about 98% on the test set.
- Model conversion: the trained PyTorch model is exported to ONNX format so it can be used on Android.
- Java inference: the ONNX model is run from Java, enabling handwritten digit recognition on Android or in other environments.
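As a quick check of the dataset facts above, here is a minimal sketch (not part of the project scripts; the `./data` download path is an assumption) that loads MNIST via torchvision:

```python
import torchvision
import torchvision.transforms as transforms

# Download MNIST (the './data' path is an assumption) and inspect its size and shape.
train_dataset = torchvision.datasets.MNIST(root='./data', train=True,
                                           transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False,
                                          transform=transforms.ToTensor(), download=True)

image, label = train_dataset[0]
print(len(train_dataset), len(test_dataset))  # 60000 10000
print(image.shape)                            # torch.Size([1, 28, 28])
print(len(train_dataset.classes))             # 10
```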
1.2 Project Structure
```
.
├── DNS_tunnel_detect
│   ├── DNS_tunnel_detect.iml
│   ├── README.md
│   ├── bin
│   ├── lib
│   ├── out
│   ├── source
│   └── src
├── cnn_py
│   ├── data
│   ├── main.py
│   └── model
├── model2onnx
│   ├── model
│   ├── model2onnx.py
│   └── test_onnx_model.py
└── 第3集: java落地AI项目案例:cnn手写字体识别.md
```
1.3 Model Structure
The first layer consists of convolution, batch normalization, ReLU activation, and max pooling; the second layer has the same structure but with 32 output channels; the fully connected layer flattens the previous layer's output and produces the class scores. After two 2×2 max-pooling steps, the 28×28 input is reduced to 7×7, so the flattened feature vector has 7*7*32 = 1568 elements, which is the input size of the `nn.Linear` layer.
```python
import torch
import torch.nn as nn


# Convolutional neural network (two convolutional layers)
class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        # Layer 1: 1 input channel -> 16 channels, 28x28 -> 14x14 after pooling
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        # Layer 2: 16 -> 32 channels, 14x14 -> 7x7 after pooling
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        # Classifier: flattened 7*7*32 features -> num_classes scores
        self.fc = nn.Linear(7*7*32, num_classes)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out
```
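A quick sanity check of the flattening arithmetic described above; this short sketch is not part of the original script and assumes the `ConvNet` class and `torch` import from the block above:

```python
# Sanity check: one grayscale 28x28 image should yield 10 class scores.
model = ConvNet(num_classes=10)
dummy = torch.rand(1, 1, 28, 28)
print(model(dummy).shape)  # torch.Size([1, 10])
```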
2. Training
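The training loop below assumes that the device, hyperparameters, and MNIST data loaders already exist. Here is a minimal setup sketch; the concrete values (`num_epochs = 5`, `batch_size = 100`, `learning_rate = 0.001`) are assumptions, not values taken from the project:

```python
import torch
import torchvision
import torchvision.transforms as transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters (assumed values)
num_epochs = 5
num_classes = 10
batch_size = 100
learning_rate = 0.001

# MNIST dataset and data loaders
train_dataset = torchvision.datasets.MNIST(root='./data', train=True,
                                           transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False,
                                          transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
```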
```python
model = ConvNet(num_classes).to(device)
print(model)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        # print(images.size())  # debug output: shape of the current batch

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch+1, num_epochs, i+1, total_step, loss.item()))
```
3. Testing the Model
```python
# Test the model
model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch statistics)
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

# Save the model checkpoint
torch.save(model.state_dict(), './model/model.ckpt')
```
4. Model Conversion
4.1 Exporting the PyTorch Model to ONNX
```python
import os
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn


# Same network definition as in the training script
class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out


device = torch.device("cpu")
num_classes = 10
model = ConvNet(num_classes).to(device)
print(model)

# Load the trained weights
model.load_state_dict(torch.load('../cnn_py/model/model.ckpt', map_location=device))

# Run a forward pass on a random sample to confirm the model works
sample_input = torch.rand((1, 1, 28, 28)).to(device)
print(sample_input)
model.eval()
with torch.no_grad():
    outputs = model(sample_input)
    print("output:", outputs)
    _, predicted = torch.max(outputs.data, 1)
    print("predicted:", predicted)

# Export to ONNX
torch.onnx.export(model,
                  sample_input,
                  './model/model.onnx',
                  input_names=["input"],
                  output_names=["output"],
                  export_params=True,        # store the trained parameters in the ONNX file
                  do_constant_folding=True)  # apply constant-folding optimization

torch.cuda.empty_cache()
```
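Optionally, the exported file can be checked for structural validity with the `onnx` package before it is shipped. A small sketch, assuming the export path above was used:

```python
import onnx

# Load the exported model and run ONNX's structural validity check.
onnx_model = onnx.load('./model/model.onnx')
onnx.checker.check_model(onnx_model)
print("ONNX model is structurally valid")
```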
4.2 Testing the ONNX Model with onnxruntime
```python
import os
import warnings
warnings.filterwarnings('ignore')

import onnxruntime
import torch

# Random input with the same shape as one MNIST batch: (N, C, H, W) = (1, 1, 28, 28)
input_data = torch.rand(1, 1, 28, 28)

session = onnxruntime.InferenceSession("./model/model.onnx")
input_name = session.get_inputs()[0].name
result = session.run([], {input_name: input_data.numpy()})
print("result: ", result)
print(result[0][0])

# The predicted class is the index of the largest of the 10 scores
max_value = max(list(result[0][0]))
predict = list(result[0][0]).index(max_value)
print(predict)
```
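The same prediction can be read off more idiomatically with NumPy; this small follow-up reuses the `result` variable from the snippet above:

```python
import numpy as np

# np.argmax picks the index of the largest class score, i.e. the predicted digit.
scores = result[0][0]
print(int(np.argmax(scores)))
```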
5. Using the ONNX Model for Prediction in Java
- Requires the ONNX Runtime Java library (Maven artifact `com.microsoft.onnxruntime:onnxruntime`; use `onnxruntime-android` on Android).
```java
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class App {
    public static void main(String[] args) throws Exception {
        String model_path = "./source/model.onnx";
        System.out.println(model_path);

        // Build a dummy 1x1x28x28 input tensor; in a real application this array
        // would hold the normalized pixel values of a handwritten digit image.
        float[][][][] feature = new float[1][1][28][28];
        for (int i = 0; i < 1; i++) {
            for (int j = 0; j < 1; j++) {
                for (int k = 0; k < 28; k++) {
                    for (int l = 0; l < 28; l++) {
                        feature[i][j][k][l] = (i + 1) * (j + 1) * (k + 1) * (l + 1);
                    }
                }
            }
        }
        System.out.println(Arrays.deepToString(feature));

        OrtEnvironment env = OrtEnvironment.getEnvironment();
        try (OrtSession session = env.createSession(model_path)) {
            Map<String, OnnxTensor> container = new HashMap<>();
            OnnxTensor inputTensor = OnnxTensor.createTensor(env, feature);
            container.put("input", inputTensor);  // must match the exported input name

            try (OrtSession.Result result = session.run(container)) {
                OnnxTensor outputTensor = (OnnxTensor) result.get(0);
                float[][] scores = (float[][]) outputTensor.getValue();  // shape: [1][10]
                System.out.println(Arrays.deepToString(scores));
                for (int i = 0; i < scores.length; i++) {
                    for (int j = 0; j < scores[i].length; j++) {
                        System.out.println(scores[i][j]);
                    }
                }

                // The predicted digit is the index of the largest score
                int predicted = 0;
                for (int j = 1; j < scores[0].length; j++) {
                    if (scores[0][j] > scores[0][predicted]) {
                        predicted = j;
                    }
                }
                System.out.println("predicted: " + predicted);
            }
            OnnxValue.close(container);
        } catch (OrtException e) {
            throw new RuntimeException(e);
        } finally {
            System.out.println("all done");
        }
    }
}
```
6. Summary
- Completed the Python scripts for training and testing the handwritten digit model.
- Completed the conversion of the PyTorch model to ONNX.
- Completed Java-side prediction with the ONNX model.