Learning record - a small example of porting a Pytorch model to Android

Reminder: keep the timeliness of this article in mind (written 2022.04.02).


Foreword

Recently I have been porting an image classification model to Android. I originally planned to use Tensorflow, but the blog posts I found on Baidu were rather old (from 2017 or 2018). Then I found Tensorflow's official examples, and the first one turned out to be deprecated and replaced, while the README of the new example does not explain how to handle your own model. The Tensorflow website also often showed "Service Unavailable", and the results of the model I implemented with Tensorflow were very strange. For Pytorch, on the other hand, I could find a newer example:

So I resolutely abandoned Tensorflow, switched to Pytorch, followed the official example, and got the model running.
Here is a brief record of the implementation process and of some errors I ran into.
Enough preamble; on to the main text.


Part 0. Environment

Stage                Component                      Version
Train the model      Python                         3.7.3
                     Pytorch                        1.11.0
Export the model     Python                         3.9
                     Pytorch                        1.9.0
Android deployment   Android Studio                 4.1.1
                     pytorch_android_lite           1.9.0
                     pytorch_android_torchvision    1.9.0

If there is an error like this:

No toolchains found in the NDK toolchains folder for ABI with prefix: arm-linux-androideabi

This is probably an NDK problem: either the NDK is not installed, or it is installed but lacks the corresponding toolchain. You can follow this blog post to install it (perfect solution No toolchains found in the NDK toolchains folder for ABI with prefix: mips64el-linux-android, CodeForCoffee's blog, CSDN Blog). If the official NDK download URL cannot be reached, you can download it here (AndroidDevTools - Android Development Tools Android SDK Download Android Studio Download Gradle Download SDK Tools Download).
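The error message says Gradle looked inside the NDK's toolchains folder and found no sub-folder for the requested ABI prefix. The minimal Python sketch below repeats that same check, which can help confirm which NDK install is incomplete before downloading another one. It assumes the NDK keeps its per-ABI toolchains as sub-folders of `<ndk>/toolchains`; the function name `toolchains_with_prefix` is hypothetical and not part of any Android tool.

```python
import os


def toolchains_with_prefix(ndk_dir, prefix):
    """List toolchain folders under <ndk_dir>/toolchains whose names start with prefix.

    Assumption: per-ABI toolchains live as sub-folders of <ndk_dir>/toolchains,
    which is the folder the Gradle error message refers to.
    """
    toolchains_dir = os.path.join(ndk_dir, "toolchains")
    if not os.path.isdir(toolchains_dir):
        # No toolchains folder at all: the NDK is missing or incomplete.
        return []
    return sorted(
        name
        for name in os.listdir(toolchains_dir)
        if name.startswith(prefix)
        and os.path.isdir(os.path.join(toolchains_dir, name))
    )
```

For example, `toolchains_with_prefix('/path/to/ndk', 'arm-linux-androideabi')` (a hypothetical path) returning an empty list matches the missing-toolchain situation described above.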

Part 1. Model preparation

1. Export the model

Following the referenced blog posts and the official tutorial, you need to export your own model. I also tried the method from the blog posts, but in the end I got it working by adapting the official example, as follows:

import torch
from torch.utils.mobile_optimizer import optimize_for_mobile
from model_v3 import mobilenet_v3_large  # import your own model definition

model_pth = './MobileNetV3-20220330-01.pth'  # path to the parameter file produced by training
mobile_ptl = './mobilenetV3large.ptl'  # path of the exported file that Android will load
model = mobilenet_v3_large(num_classes=7)  # instantiate the model
pre_weights = torch.load(model_pth, map_location='cpu')  # read the parameters
model.load_state_dict(pre_weights, strict=True)  # load the parameters into the model
device = torch.device('cpu')  # device that torch.Tensors are allocated on: 'cpu' or 'cuda'
model.to(device)  # move the model to that device
model.eval()  # switch the model to evaluation (inference) mode
example = torch.rand(1, 3, 224, 224)  # example input: one 3-channel 224*224 image
# Everything above prepares the model; the conversion follows
traced_script_module = torch.jit.trace(model, example)
traced_script_module_optimized = optimize_for_mobile(traced_script_module)
traced_script_module_optimized._save_for_lite_interpreter(mobile_ptl)

The official Pytorch example uses a pre-trained MobileNetV2 loaded directly from torchvision:

……
import torchvision
……
model = torchvision.models.mobilenet_v2(pretrained=True)
……

2. Error record

2.1 Load the complete model (network structure + weights)

If the file contains only the parameters and you treat what torch.load returns as the model, an error is reported:

AttributeError: 'collections.OrderedDict' object has no attribute 'eval' ……

Conversely, instantiating only the network without loading the trained weights is equivalent to not training at all: the model keeps its initial, untrained parameters.
So there are generally two ways to save a model file:

  1. Save only the model parameters (the officially recommended way). If only the weights were saved during training, instantiate the model and load the weights into it when loading:
# Save:
torch.save(model.state_dict(), PATH)
# Load:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
  2. Save the entire model (network structure + parameters):
# Save:
torch.save(model, PATH)
# Load:
# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()

More specific instructions can be found in the official documentation ( SAVING AND LOADING MODELS )
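Both cases can be handled by checking what `torch.load` actually returned. The minimal sketch below relies only on the fact visible in the error above: a state_dict is an `OrderedDict`, i.e. a mapping, and has no `.eval()`. The helper `materialize_model` is hypothetical, not a PyTorch API.

```python
from collections.abc import Mapping


def materialize_model(checkpoint, build_model):
    """Turn whatever torch.load() returned into a model in eval mode.

    checkpoint  -- the object returned by torch.load(PATH)
    build_model -- zero-argument callable that instantiates the network
                   (hypothetical, e.g. lambda: mobilenet_v3_large(num_classes=7))
    """
    if isinstance(checkpoint, Mapping):
        # Only the parameters were saved: an OrderedDict has no .eval(),
        # so build the architecture first and copy the weights into it.
        model = build_model()
        model.load_state_dict(checkpoint, strict=True)
    else:
        # The whole model was saved: structure and weights are already there.
        model = checkpoint
    model.eval()
    return model
```

With the export script, this would be called as `materialize_model(torch.load(model_pth, map_location='cpu'), lambda: mobilenet_v3_large(num_classes=7))`. A checkpoint that wraps the weights in a larger dict (for example together with an epoch counter) would need to be unwrapped first.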

2.2 Exported model file format

Although the reference blog posts all export the model as a .pt file, running a .pt file exported from the complete model produced this error:

java.lang.RuntimeException: Unable to start activity ComponentInfo{org.pytorch.helloworld/org.pytorch.helloworld.MainActivity}: com.facebook.jni.CppException: PytorchStreamReader failed locating file bytecode.pkl: file not found ()
    Exception raised from valid at ../caffe2/serialize/inline_container.cc:157 (most recent call first):
    (no backtrace available)

Following the official example and exporting to a .ptl file (with _save_for_lite_interpreter) instead, the app runs successfully.
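The error names the missing file: the lite interpreter looks for `bytecode.pkl` inside the model archive. TorchScript files are zip archives, so a quick sanity check before copying a file into `assets` needs only the standard library. This is a heuristic sketch with a hypothetical function name, assuming that files written by `_save_for_lite_interpreter()` contain a `bytecode.pkl` entry under the archive's top-level folder while a plain `.pt` does not.

```python
import zipfile


def is_lite_interpreter_file(path):
    """Heuristic: does this TorchScript archive contain a bytecode.pkl entry?

    LiteModuleLoader.load() needs bytecode.pkl, which is why a plain .pt file
    fails with "PytorchStreamReader failed locating file bytecode.pkl".
    """
    if not zipfile.is_zipfile(path):
        return False
    with zipfile.ZipFile(path) as archive:
        # Entries look like "<top-level folder>/bytecode.pkl"; compare basenames.
        return any(name.rsplit("/", 1)[-1] == "bytecode.pkl"
                   for name in archive.namelist())
```

Under that assumption, `is_lite_interpreter_file('mobilenetV3large.ptl')` should be True for the file produced by the export script above.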

Part 2. Android deployment

This part follows the Android deployment section of this blog post (How to deploy the pytorch model to Android; its implementation is similar to the official example). I had followed that post at the very beginning, but it did not run successfully then.
So let's go through its author's steps again.

1. Create a new project

Create a new project directly from the Empty Activity template and click Next.
[Figure: New Project]

2. Fill in the project information

Name it myModel, keep the other settings at their defaults, and click Finish.
[Figure: Fill in the project information]

3. Import package (add dependency library)

Import the pytorch_android_lite package (this differs from pytorch_android: the two load models with different methods).

//Pytorch
implementation 'org.pytorch:pytorch_android_lite:1.9.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'

[Figure: Add dependent library]
The complete build.gradle (:app) is as follows:

plugins {
    id 'com.android.application'
}

android {
    compileSdkVersion 30
    buildToolsVersion "30.0.3"

    defaultConfig {
        applicationId "com.test.mymodel"
        minSdkVersion 23
        targetSdkVersion 30
        versionCode 1
        versionName "1.0"

        testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
    }

    buildTypes {
        release {
            minifyEnabled false
            proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
        }
    }
    compileOptions {
        sourceCompatibility JavaVersion.VERSION_1_8
        targetCompatibility JavaVersion.VERSION_1_8
    }
}

dependencies {
    implementation 'androidx.appcompat:appcompat:1.2.0'
    implementation 'com.google.android.material:material:1.2.1'
    implementation 'androidx.constraintlayout:constraintlayout:2.0.4'
    testImplementation 'junit:junit:4.+'
    androidTestImplementation 'androidx.test.ext:junit:1.1.2'
    androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0'
    //Pytorch
    implementation 'org.pytorch:pytorch_android_lite:1.9.0'
    implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'
}

Note: if the Pytorch version used to export the model differs from the version of the pytorch_android_lite package used by the Android project, an error is reported:

java.lang.RuntimeException: Unable to start activity ComponentInfo{org.pytorch.helloworld/org.pytorch.helloworld.MainActivity}: com.facebook.jni.CppException: Lite Interpreter verson number does not match. The model version must be between 3 and 5But the model version is 7 ()
    Exception raised from parseMethods at ../torch/csrc/jit/mobile/import.cpp:320 (most recent call first):
    (no backtrace available)

I trained the model with Pytorch 1.11.0, and exporting with that version produced the error above. Exporting with Pytorch 1.9.0, the same version as the library used on Android, made it run.
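Since the fix was to export with the same Pytorch version as the Android library, a small check before exporting can compare the two. The sketch below is a hypothetical helper, not part of Pytorch: it pulls the `org.pytorch` artifact versions out of `build.gradle` text with a regular expression and compares only major.minor against a desktop version string such as `torch.__version__`, which may carry a suffix like `+cpu`. Matching major.minor is a conservative assumption based on this post's experience, not an official compatibility rule.

```python
import re

# Matches e.g. org.pytorch:pytorch_android_lite:1.9.0 in build.gradle text
_ARTIFACT = re.compile(
    r"org\.pytorch:(pytorch_android(?:_lite|_torchvision)?):([\w.\-]+)"
)


def android_pytorch_versions(gradle_text):
    """Map each org.pytorch artifact declared in build.gradle text to its version."""
    return dict(_ARTIFACT.findall(gradle_text))


def _major_minor(version):
    # '1.9.0+cpu' -> ['1', '9']
    return version.split("+")[0].split(".")[:2]


def same_minor(torch_version, android_version):
    """True when both versions share major.minor (e.g. '1.9.0+cpu' and '1.9.0')."""
    return _major_minor(torch_version) == _major_minor(android_version)
```

For this post, `same_minor('1.11.0', '1.9.0')` is False (the setup that failed), while exporting under 1.9.0 passes the check.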

4. Page layout

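The layout has one TextView to display the text result and one ImageView to display the picture.
[Figure: Page Layout]

The full activity_main.xml file is as follows (android:contentDescription refers to a string resource named iv_text, which needs to be defined in res/values/strings.xml):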

<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:orientation="vertical"
    tools:context=".MainActivity">

    <TextView
        android:id="@+id/tv"
        android:layout_weight="1"
        android:layout_width="match_parent"
        android:layout_height="0dp"
        android:layout_margin="10dp"
        android:layout_gravity="center"
        android:text="Hello World!"
        android:textSize="50sp"
        android:textAlignment="center"
        android:textStyle="bold"/>

    <ImageView
        android:id="@+id/iv"
        android:layout_weight="4"
        android:layout_width="match_parent"
        android:layout_height="0dp"
        android:layout_margin="10dp"
        android:background="#f0f0f0"
        android:layout_gravity="center"
        android:contentDescription="@string/iv_text" />

</LinearLayout>

5. Add result categories

Create a new class file EmotionClasses.java. This is facial expression classification with seven classes; list them in the same order as the training labels. (If the order is wrong, the results will be mislabeled.)
[Figure: Create a new class for storing result categories]

package com.test.mymodel;

public class EmotionClasses {
    public static String[] EMOTION_CLASSES = new String[]{
            "anger",
            "disgust",
            "fear",
            "happy",
            "normal",
            "sad",
            "surprised"
    };
}
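The seven labels above happen to be in alphabetical order. If the training set was loaded with torchvision's `ImageFolder` (an assumption: the post does not show its training code), the label indices are the sorted sub-folder names, so the Java list can be generated instead of typed by hand. A minimal sketch with hypothetical helper names:

```python
import os


def imagefolder_classes(train_dir):
    """Class names in the index order ImageFolder would use: sorted sub-folder names."""
    return sorted(e.name for e in os.scandir(train_dir) if e.is_dir())


def emotion_classes_java(classes, package="com.test.mymodel"):
    """Render EmotionClasses.java with the labels in training-index order."""
    items = ",\n".join(f'            "{c}"' for c in classes)
    return (
        f"package {package};\n\n"
        "public class EmotionClasses {\n"
        "    public static String[] EMOTION_CLASSES = new String[]{\n"
        f"{items}\n"
        "    };\n"
        "}\n"
    )
```

For example, `print(emotion_classes_java(imagefolder_classes('data/train')))`, where `data/train` is a hypothetical folder holding one sub-folder per emotion.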

6. Add model files and pictures

Create a new folder named assets under the main folder, and put the model .ptl file and the picture to be recognized into it. (The picture needs to have the size of the example input used when exporting the model, here a 224*224 color picture.)
[Figure: Put in model files and test images]
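The Java code below passes the bitmap to the model at its decoded size, so the test picture has to be prepared at 224*224 beforehand. One common way, sketched here as an assumption (the post simply uses a ready-made 224*224 picture), is to scale the shorter side to 224 and then crop the centre. The hypothetical helper only computes the geometry:

```python
def resize_then_center_crop(width, height, size=224):
    """Return (new_size, crop_box) for: scale the shorter side to `size`,
    then cut a size x size square from the centre.

    crop_box is (left, top, right, bottom), the order Pillow's crop() expects.
    """
    scale = size / min(width, height)
    new_w, new_h = round(width * scale), round(height * scale)
    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return (new_w, new_h), (left, top, left + size, top + size)
```

With Pillow, for example, `img.resize(new_size).crop(box)` would apply the result; a 640*480 photo is scaled to 299*224 and cropped at (37, 0, 261, 224).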

7. Call the model

In MainActivity.java, load the model and then classify the image.

package com.test.mymodel;

import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Bundle;
import android.util.Log;
import android.widget.ImageView;
import android.widget.TextView;

import androidx.appcompat.app.AppCompatActivity;

import org.pytorch.IValue;
import org.pytorch.LiteModuleLoader;
import org.pytorch.MemoryFormat;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

public class MainActivity extends AppCompatActivity {

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        Bitmap bitmap = null;
        Module module = null;
        try {
            // decode the test picture packaged as app/src/main/assets/happy01.jpg
            bitmap = BitmapFactory.decodeStream(getAssets().open("happy01.jpg"));
            // load the lite-interpreter model packaged as app/src/main/assets/mobilenetV3large.ptl
            module = LiteModuleLoader.load(assetFilePath(this, "mobilenetV3large.ptl"));
        } catch (IOException e) {
            Log.e("PytorchHelloWorld", "Error reading assets", e);
            finish();
            return; // finish() does not end onCreate(); stop before using null objects
        }

        // showing image on UI
        ImageView imageView = findViewById(R.id.iv);
        imageView.setImageBitmap(bitmap);

        // preparing input tensor (scaled to 0..1, then normalized with the torchvision mean/std)
        final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
                TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB, MemoryFormat.CHANNELS_LAST);

        // running the model
        final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();

        // getting tensor content as java array of floats
        final float[] scores = outputTensor.getDataAsFloatArray();

        // searching for the index with maximum score
        float maxScore = -Float.MAX_VALUE;
        int maxScoreIdx = -1;
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] > maxScore) {
                maxScore = scores[i];
                maxScoreIdx = i;
            }
        }

        String className = EmotionClasses.EMOTION_CLASSES[maxScoreIdx];

        // showing className on UI
        TextView textView = findViewById(R.id.tv);
        textView.setText(className);
    }

    /**
     * Copies specified asset to the file in /files app directory and returns this file absolute path.
     *
     * @return absolute file path
     */
    public static String assetFilePath(Context context, String assetName) throws IOException {
        File file = new File(context.getFilesDir(), assetName);
        if (file.exists() && file.length() > 0) {
            return file.getAbsolutePath();
        }

        try (InputStream is = context.getAssets().open(assetName)) {
            try (OutputStream os = new FileOutputStream(file)) {
                byte[] buffer = new byte[4 * 1024];
                int read;
                while ((read = is.read(buffer)) != -1) {
                    os.write(buffer, 0, read);
                }
                os.flush();
            }
            return file.getAbsolutePath();
        }
    }
}
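`bitmapToFloat32Tensor` scales each channel to 0..1 and then normalizes it with `TORCHVISION_NORM_MEAN_RGB` and `TORCHVISION_NORM_STD_RGB`, the usual ImageNet mean and standard deviation used by torchvision. The per-pixel arithmetic is sketched below in Python. The point is an assumption worth checking: the phone only sees the same inputs as training did if the model was trained with `transforms.Normalize` using these same values (the post does not show its training transforms).

```python
# Values of TensorImageUtils.TORCHVISION_NORM_MEAN_RGB / _STD_RGB (ImageNet statistics)
NORM_MEAN_RGB = (0.485, 0.456, 0.406)
NORM_STD_RGB = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Mirror bitmapToFloat32Tensor's per-pixel maths for one (R, G, B) pixel in 0..255."""
    return tuple(
        (channel / 255.0 - mean) / std
        for channel, mean, std in zip(rgb, NORM_MEAN_RGB, NORM_STD_RGB)
    )
```

A black pixel (0, 0, 0), for instance, becomes roughly (-2.12, -2.04, -1.80).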

Note: if you depend on pytorch_android_lite but load the model with Module.load(), an error is reported saying that the libpytorch_jni.so library cannot be found; load the model with LiteModuleLoader.load() instead. Someone raised this in the official issue couldn't find "libpytorch_jni.so".

java.lang.UnsatisfiedLinkError: dlopen failed: library "libpytorch_jni.so" not found

8. Running results

The running result is as follows:
[Figure: Running result]
If the class order is wrong, the recognition result is wrong as well. In the figure below, anger has been moved to the fourth position, and the same image is now recognized as anger.
[Figure: Misplaced order of categories]
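Why a wrong order gives a wrong label is easy to see in isolation: the model only outputs scores, and the label is whatever string sits at the index of the highest score. The sketch below mirrors the Java loop in Python; the scores are made up for illustration and are not the output of the author's model.

```python
EMOTION_CLASSES = ["anger", "disgust", "fear", "happy", "normal", "sad", "surprised"]


def predict_label(scores, classes):
    """Same logic as the Java loop: take the index of the largest score
    (first one on ties) and look it up in the class list."""
    if len(scores) != len(classes):
        raise ValueError("model outputs and class list have different lengths")
    best = max(range(len(scores)), key=lambda i: scores[i])
    return classes[best]


scores = [0.1, -1.0, 0.3, 4.2, 0.0, -0.5, 1.1]  # hypothetical model output
moved = ["disgust", "fear", "happy", "anger", "normal", "sad", "surprised"]
print(predict_label(scores, EMOTION_CLASSES))  # happy
print(predict_label(scores, moved))            # anger: same index 3, different string
```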


Part 3. Summary

  • In fact, if you just want a simple image classifier, you can basically get it running by swapping the model and the class list in the HelloWorldApp of the official android-demo-app.
  • These frameworks are updated very quickly, so many articles go out of date fast. Once the version changes, or even a single method changes, all kinds of errors pop up, and the only remedy is a lot of searching.
  • Other blogs are useful as references, but the main process should follow the official tutorial. If you run into a problem, search the project's issues for similar ones; you may get inspiration from them.
  • Finally, the working demo is on Gitee, here.
    (A summary of nothing much, really.)
    If this was useful, please give it a thumbs up.
    If you find anything wrong, corrections are welcome.
    Friendly comments and peaceful discussion, please.


Origin blog.csdn.net/weixin_44438341/article/details/123897165