Deploying AI models with Java: CNN handwritten digit recognition

CNN handwritten digit recognition

1. Introduction

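  1. Handwriting recognition means recognizing handwritten characters in an image, i.e. determining which character each handwritten symbol represents.
  2. This project focuses on handwritten digit recognition and uses a convolutional neural network (CNN) as the model.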

1.1 Steps

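  1. Dataset: the MNIST handwritten digit dataset, with 60,000 training images and 10,000 test images. Each image is a 28*28 grayscale image, and there are 10 classes (digits 0 to 9).
  2. Framework: PyTorch is used on the Python side for both training and testing.
  3. Test accuracy: 98%.
  4. Model conversion: the PyTorch model is converted to ONNX format so that it can be used on Android.
  5. Inference: the ONNX model is run from Java code, enabling handwritten digit recognition on Android or in other Java environments.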

1.2 Project structure

.
├── DNS_tunnel_detect
│   ├── DNS_tunnel_detect.iml
│   ├── README.md
│   ├── bin
│   ├── lib
│   ├── out
│   ├── source
│   └── src
├── cnn_py
│   ├── data
│   ├── main.py
│   └── model
├── model2onnx
│   ├── model
│   ├── model2onnx.py
│   └── test_onnx_model.py
└── 第3集: java落地AI项目案例:cnn手写字体识别.md

1.3 Model architecture

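The first layer consists of a 5x5 convolution (1 input channel to 16 output channels, stride 1, padding 2), batch normalization, ReLU activation and 2x2 max pooling, producing a 16x14x14 feature map;
the second layer has the same structure but outputs 32 channels, producing a 32x7x7 feature map;
the fully connected layer flattens this output (7*7*32 = 1568 features) and maps it to the 10 class scores.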

import torch
import torch.nn as nn

# Convolutional neural network (two convolutional layers)
class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)
    
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out
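The 7*7*32 input size of nn.Linear follows from the layer settings above. The Java sketch below (the class name ConvShapes is hypothetical, not part of the project) applies the standard output-size formula floor((in + 2*padding - kernel) / stride) + 1 to each convolution and pooling step, assuming a 28x28 input:

```java
// Sketch: trace the spatial size of the feature maps through the ConvNet above.
// Assumes a square 28x28 single-channel input (MNIST); ConvShapes is a hypothetical helper.
final class ConvShapes {
    // Output size of a conv or pooling layer along one spatial dimension:
    // floor((in + 2*padding - kernel) / stride) + 1 (integer division floors non-negative values)
    static int outSize(int in, int kernel, int stride, int padding) {
        return (in + 2 * padding - kernel) / stride + 1;
    }

    // Returns {channels, height, width} after layer1 and layer2.
    static int[] afterLayer2(int inputSize) {
        int s = outSize(inputSize, 5, 1, 2); // conv1: 5x5, stride 1, padding 2 keeps the size
        s = outSize(s, 2, 2, 0);             // maxpool 2x2, stride 2 halves it
        s = outSize(s, 5, 1, 2);             // conv2: size kept again
        s = outSize(s, 2, 2, 0);             // maxpool halves it again
        return new int[] {32, s, s};         // layer2 has 32 output channels
    }

    public static void main(String[] args) {
        int[] shape = afterLayer2(28);
        int flat = shape[0] * shape[1] * shape[2];
        System.out.println("layer2 output: " + shape[0] + "x" + shape[1] + "x" + shape[2]); // 32x7x7
        System.out.println("fc input features: " + flat);                                    // 1568
    }
}
```

Changing the input size or a padding value changes the flattened size, which is why nn.Linear(7*7*32, num_classes) only works for 28x28 inputs.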

2. Training

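import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters: example values (assumed, not specified in this article)
num_epochs = 5
num_classes = 10
batch_size = 100
learning_rate = 0.001

# MNIST dataset; ToTensor() scales pixel values to [0, 1]
train_dataset = torchvision.datasets.MNIST(root='./data', train=True,
                                           transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False,
                                          transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# ConvNet is the class defined in section 1.3
model = ConvNet(num_classes).to(device)
print(model)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    model.train()  # training mode: batchnorm uses mini-batch statistics
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch+1, num_epochs, i+1, total_step, loss.item()))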
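nn.CrossEntropyLoss applies log-softmax to the raw logits and then takes the negative log-probability of the target class, averaging over the batch by default (reduction='mean'). The Java sketch below (hypothetical class CrossEntropy, not part of the project, and ignoring class weights and label smoothing) reproduces that computation for plain arrays, subtracting the maximum logit for numerical stability:

```java
// Sketch of what nn.CrossEntropyLoss computes, for plain Java arrays.
// loss = -log(softmax(logits)[target]) = logsumexp(logits) - logits[target]
// CrossEntropy is a hypothetical helper class used only for illustration.
final class CrossEntropy {
    static double loss(double[] logits, int target) {
        // Subtract the max logit before exponentiating to avoid overflow.
        double max = Double.NEGATIVE_INFINITY;
        for (double z : logits) max = Math.max(max, z);
        double sumExp = 0.0;
        for (double z : logits) sumExp += Math.exp(z - max);
        double logSumExp = max + Math.log(sumExp);
        return logSumExp - logits[target];
    }

    // With the default reduction='mean', the batch loss is the mean per-sample loss.
    static double meanLoss(double[][] batchLogits, int[] targets) {
        double total = 0.0;
        for (int i = 0; i < batchLogits.length; i++) {
            total += loss(batchLogits[i], targets[i]);
        }
        return total / batchLogits.length;
    }

    public static void main(String[] args) {
        double[] uniform = new double[10]; // all-zero logits: every digit equally likely
        System.out.println(loss(uniform, 3)); // ln(10), roughly 2.3026
    }
}
```

An untrained network therefore starts with a loss near ln(10), about 2.30, and the printed loss should fall well below that as training progresses.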

3. Testing the model

# Test the model
model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

# Save the model checkpoint (create ./model first if it does not exist)
import os
os.makedirs('./model', exist_ok=True)
torch.save(model.state_dict(), './model/model.ckpt')
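The test loop takes the index of the largest logit in each row (torch.max(outputs.data, 1)) and counts how many of those indices match the labels. The same bookkeeping in Java, as a sketch with a hypothetical Accuracy class and made-up logits:

```java
// Sketch of the evaluation bookkeeping: argmax per row, then count matches.
// Accuracy is a hypothetical helper class; the logits in main() are made up.
final class Accuracy {
    // Index of the largest value in one row of logits (this sketch keeps the first maximum on ties).
    static int argmax(float[] logits) {
        int best = 0;
        for (int i = 1; i < logits.length; i++) {
            if (logits[i] > logits[best]) best = i;
        }
        return best;
    }

    // Percentage of rows whose argmax equals the label, like 100 * correct / total.
    static double accuracyPercent(float[][] batchLogits, int[] labels) {
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (argmax(batchLogits[i]) == labels[i]) correct++;
        }
        return 100.0 * correct / labels.length;
    }

    public static void main(String[] args) {
        float[][] logits = {
            {0.1f, 2.5f, -1.0f},
            {3.0f, 0.2f, 0.4f},
            {0.0f, 0.1f, 0.9f},
            {1.2f, 1.1f, 0.3f}
        };
        int[] labels = {1, 0, 2, 2};
        System.out.println(accuracyPercent(logits, labels) + " %"); // 75.0 %
    }
}
```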


4. Model conversion

4.1 Converting the PyTorch model to ONNX

import os
import warnings
warnings.filterwarnings('ignore')
import torch
import torch.nn as nn

class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)
        
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out

device = torch.device("cpu")
num_classes = 10
model = ConvNet(num_classes).to(device)
print(model)

model.load_state_dict(torch.load('../cnn_py/model/model.ckpt',map_location=device))

sample_input = torch.rand((1,1,28,28)).to(device)
print(sample_input)

model.eval()
with torch.no_grad():
    outputs = model(sample_input)
    print("output:",outputs)
    _, predicted = torch.max(outputs.data, 1)
    print("predicted:",predicted)
    
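# Make sure the output directory exists before exporting
os.makedirs('./model', exist_ok=True)

torch.onnx.export(model,
                  sample_input,              # example input; fixes the input shape to (1, 1, 28, 28)
                  './model/model.onnx',
                  input_names=["input"],     # name used to feed the model from onnxruntime / Java
                  output_names=["output"],
                  export_params=True,        # store the trained parameters in the ONNX file
                  do_constant_folding=True)  # apply constant-folding optimization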
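Because no dynamic_axes argument is passed to torch.onnx.export, the exported graph keeps the fixed input shape of sample_input, (1, 1, 28, 28) in NCHW order. If pixels arrive in Java as a flat, row-major array of 784 values, they have to be packed into that layout first. A minimal sketch (the NchwPacker class is hypothetical); ONNX Runtime's Java API can alternatively take a flat FloatBuffer plus a shape through OnnxTensor.createTensor(env, buffer, new long[] {1, 1, 28, 28}):

```java
// Sketch: pack a flat row-major 28x28 pixel array into the [1][1][28][28] NCHW
// layout the exported model expects. NchwPacker is a hypothetical helper class.
final class NchwPacker {
    static final int H = 28;
    static final int W = 28;

    static float[][][][] toNchw(float[] flat) {
        if (flat.length != H * W) {
            throw new IllegalArgumentException("expected " + (H * W) + " values, got " + flat.length);
        }
        float[][][][] out = new float[1][1][H][W]; // batch N=1, channel C=1
        for (int h = 0; h < H; h++) {
            for (int w = 0; w < W; w++) {
                out[0][0][h][w] = flat[h * W + w]; // row-major: row h, column w
            }
        }
        return out;
    }

    public static void main(String[] args) {
        float[] flat = new float[H * W];
        flat[H * W - 1] = 1.0f; // mark the bottom-right pixel
        float[][][][] nchw = toNchw(flat);
        System.out.println(nchw.length + "x" + nchw[0].length + "x"
                + nchw[0][0].length + "x" + nchw[0][0][0].length); // 1x1x28x28
        System.out.println(nchw[0][0][27][27]);                     // 1.0
    }
}
```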


4.2 Verifying the ONNX model with onnxruntime

import os
import warnings
warnings.filterwarnings('ignore')

import onnxruntime
import torch

input_data = torch.rand(1,1,28,28)
session = onnxruntime.InferenceSession("./model/model.onnx")
input_name = session.get_inputs()[0].name
result = session.run([], {input_name: input_data.numpy()})
print("result: ",result)
print(result[0][0])
max_value = max(list(result[0][0]))
predict = list(result[0][0]).index(max_value)
print(predict)
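The values printed above are raw logits: the network ends with nn.Linear and has no softmax layer, so the argmax gives the predicted digit but not a probability. A Java sketch (hypothetical Softmax class, not part of the project) that turns one row of logits into probabilities:

```java
// Sketch: convert one row of raw logits into probabilities with softmax.
// Softmax is a hypothetical helper class; the logits in main() are made up.
final class Softmax {
    static double[] softmax(float[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (float z : logits) max = Math.max(max, z);
        double[] p = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            p[i] = Math.exp(logits[i] - max); // shift by the max logit to avoid overflow
            sum += p[i];
        }
        for (int i = 0; i < p.length; i++) {
            p[i] /= sum; // normalize so the probabilities sum to 1
        }
        return p;
    }

    public static void main(String[] args) {
        float[] logits = {1.0f, 3.0f, 0.5f};
        double[] p = softmax(logits);
        int best = 0;
        for (int i = 1; i < p.length; i++) {
            if (p[i] > p[best]) best = i;
        }
        System.out.printf("predicted %d with probability %.3f%n", best, p[best]); // predicted 1 with probability 0.821
    }
}
```

Softmax does not change which index is largest, so the predicted digit is the same as with the plain argmax; it only adds a confidence value.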


5. Running predictions with the ONNX model in Java

  • Requires the ONNX Runtime Java library (Maven artifact com.microsoft.onnxruntime:onnxruntime)

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

public class App {
    public static void main(String[] args) throws Exception {
        String modelPath = "./source/model.onnx";
        System.out.println(modelPath);

        // Dummy input with the model's fixed shape [1, 1, 28, 28].
        // Random values in [0, 1), like torch.rand in the Python test; real use needs
        // grayscale pixels scaled to [0, 1] the same way as during training.
        float[][][][] feature = new float[1][1][28][28];
        Random random = new Random();
        for (int k = 0; k < 28; k++) {
            for (int l = 0; l < 28; l++) {
                feature[0][0][k][l] = random.nextFloat();
            }
        }
        // deepToString prints the contents of nested arrays (toString would print references)
        System.out.println(Arrays.deepToString(feature));

        OrtEnvironment env = OrtEnvironment.getEnvironment();
        try (OrtSession session = env.createSession(modelPath)) {
            Map<String, OnnxTensor> container = new HashMap<>();
            // The key "input" must match input_names used in torch.onnx.export
            OnnxTensor inputTensor = OnnxTensor.createTensor(env, feature);
            container.put("input", inputTensor);

            try (OrtSession.Result result = session.run(container)) {
                // Output shape is [1, 10]: one row of raw logits (no softmax)
                OnnxTensor outputTensor = (OnnxTensor) result.get(0);
                float[][] logits = (float[][]) outputTensor.getValue();
                System.out.println(Arrays.deepToString(logits));

                // The predicted digit is the index of the largest logit
                int predicted = 0;
                for (int j = 1; j < logits[0].length; j++) {
                    if (logits[0][j] > logits[0][predicted]) {
                        predicted = j;
                    }
                }
                System.out.println("predicted: " + predicted);
            } finally {
                // Release the native memory held by the input tensor, even if run() fails
                OnnxValue.close(container);
            }
        } catch (OrtException e) {
            throw new RuntimeException(e);
        } finally {
            System.out.println("all done");
        }
    }
}
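The Java example feeds random numbers; a real picture has to be preprocessed the way the training data was. The sketch below rests on stated assumptions: the training images were only scaled to [0, 1] by transforms.ToTensor() (no Normalize step; if one was used, apply the same mean and std here), and the input picture shows a dark digit on a light background, so it is inverted to match MNIST's light digits on a dark background. The class MnistPreprocess is hypothetical, and it relies on java.awt, which is not available on Android (there the pixels would come from android.graphics.Bitmap instead):

```java
import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;

// Sketch only: converts an image into the [1][1][28][28] float input for model.onnx.
// Assumptions (not from the original project): ToTensor()-style scaling to [0, 1]
// without Normalize, and a dark-on-light input that gets inverted when invert == true.
final class MnistPreprocess {
    static float[][][][] toInput(BufferedImage src, boolean invert) {
        // Resize to 28x28 and convert to 8-bit grayscale with a single drawImage call
        BufferedImage gray = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
        Graphics2D g = gray.createGraphics();
        g.drawImage(src, 0, 0, 28, 28, null);
        g.dispose();

        float[][][][] input = new float[1][1][28][28];
        for (int y = 0; y < 28; y++) {
            for (int x = 0; x < 28; x++) {
                // getSample returns 0..255; dividing by 255 gives the [0, 1] range of ToTensor()
                float v = gray.getRaster().getSample(x, y, 0) / 255.0f;
                input[0][0][y][x] = invert ? 1.0f - v : v;
            }
        }
        return input;
    }

    public static void main(String[] args) {
        // Synthetic 56x56 picture: white background with a black vertical stroke
        BufferedImage img = new BufferedImage(56, 56, BufferedImage.TYPE_INT_RGB);
        Graphics2D g = img.createGraphics();
        g.setColor(Color.WHITE);
        g.fillRect(0, 0, 56, 56);
        g.setColor(Color.BLACK);
        g.fillRect(26, 8, 4, 40);
        g.dispose();

        float[][][][] input = toInput(img, true);
        // Print the middle row: stroke pixels should be near 1.0, background near 0.0
        StringBuilder row = new StringBuilder();
        for (int x = 0; x < 28; x++) {
            row.append(String.format("%.1f ", input[0][0][14][x]));
        }
        System.out.println(row.toString().trim());
    }
}
```

The returned array can be passed to OnnxTensor.createTensor(env, input) in place of the random feature array. The sketch simply stretches the picture to 28x28; MNIST digits are size-normalized and centered, so cropping the input around the digit first usually helps.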

6. Summary

  1. Trained and tested the handwritten digit model with a Python script
  2. Converted the model to ONNX
  3. Ran predictions with the ONNX model in Java


Reposted from blog.csdn.net/weixin_32393347/article/details/142650680