简单demo带你将tensorflow2.x的自定义keras模型转为tflite格式并部署到安卓端

环境:

windows  10

CUDA  10.1

cudnn  7.6.4

tensorflow-gpu  2.1

androidstudio  3.6

基本都是目前比较新的环境。

因为tensorflow2.0后,我特别喜欢使用keras自定义模型,所以我想找一种方法来保存模型用来部署。pb格式的模型我还没太看懂它保存的具体是哪一个方法,还是说所有方法都保存了,所以我暂时不考虑使用pb模型部署。tflite就简单多了,只会保存call方法下的流程。转换过程还是相当简单的,直接进入主题。

首先是python部分:

1.自定义一个多输入多输出的简单模型。

class test_model2(tf.keras.Model):
    def __init__(self, name="test_model2"):
        super(test_model2, self).__init__(name=name)
        self.conv1 = tf.keras.layers.Conv2D(filters=1, kernel_size=2, kernel_initializer=tf.ones, name=self.name + "/conv1")

    @tf.function
    def call(self, inputs):
        output1 = self.conv1(inputs[0])
        output1 = tf.squeeze(output1)
        output1 = tf.reshape(output1, (1,))
        output2 = self.conv1(inputs[1])
        output2 = tf.squeeze(output2)
        output2 = tf.reshape(output2, (1,))
        return output1, output2
model = test_model2()
test_input1 = tf.ones((1, 2, 2, 1))
test_input2 = tf.zeros((1, 2, 2, 1))
input_list = [test_input1, test_input2]
test_output1, test_output2 = model(input_list)
print(test_output1)
print(test_output2)

运行后会打印

tf.Tensor([4.], shape=(1,), dtype=float32)
tf.Tensor([0.], shape=(1,), dtype=float32)

这是一个相当简单的模型。

2.下面是转换模型为tflite格式:

如果是使用自定义的训练循环而不是使用fit()函数,就需要手动设置模型的输入大小。这个例子中,我们视为自定义训练流程,设置一次输入大小就可以了。输入的值可以随机,主要是shape相符。因为默认转换的是call函数,所以如果训练与测试不是通过一个函数,推荐将训练函数不要与call同名。

test_input1 = tf.ones((1, 2, 2, 1))
test_input2 = tf.zeros((1, 2, 2, 1))
input_list = [test_input1, test_input2]
model._set_inputs(input_list)

最后转换并保存模型:

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
open("./save/converted_model.tflite", "wb").write(tflite_model)

这是就可以在save文件夹中找到保存的tflite文件了。

之后是AndroidStudio部分:

1.新建一个工程

2.修改build.gradle,添加以下内容

android {
    ...
    defaultConfig {
        ...
        ndk {
            abiFilters 'armeabi-v7a', 'arm64-v8a'
        }
        ...
    }
    aaptOptions {
        noCompress "tflite"
    }
    ...
}
dependencies {
    ...
    implementation 'org.tensorflow:tensorflow-lite:2.1.0'
}

3.放入tflite文件并读取

在app\src\main里创建一个assets文件夹,并将converted_model.tflite放入。这个文件路径不是唯一,放到assets里面只是为了读取方便而已。

读取代码如下所示:

String MODEL_FILE = "converted_model.tflite";
Interpreter tfLite = null;
try {
    tfLite = new Interpreter(loadModelFile(getAssets(), MODEL_FILE));
}catch(IOException e){
    e.printStackTrace();
}

其中 loadModelFile 函数为:

MappedByteBuffer loadModelFile(AssetManager assets, String modelFilename)
            throws IOException {
        AssetFileDescriptor fileDescriptor = assets.openFd(modelFilename);
        FileInputStream inputStream = new                 
        FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }

4.建立输入tensor

tflite的输入是ByteBuffer格式:

int net_input_sz = 2;

ByteBuffer inputData1;
inputData1 = ByteBuffer.allocateDirect(net_input_sz * net_input_sz * 4);//4表示一个浮点占4byte
inputData1.order(ByteOrder.nativeOrder());
inputData1.rewind();
inputData1.putFloat(1.0f);
inputData1.putFloat(1.0f);
inputData1.putFloat(1.0f);
inputData1.putFloat(1.0f);

ByteBuffer inputData2;
inputData2 = ByteBuffer.allocateDirect(net_input_sz * net_input_sz * 4);//4表示一个浮点占4byte
inputData2.order(ByteOrder.nativeOrder());
inputData2.rewind();
inputData2.putFloat(0.0f);
inputData2.putFloat(0.0f);
inputData2.putFloat(0.0f);
inputData2.putFloat(0.0f);

Object[] inputArray = {inputData1, inputData2};
 

ByteBuffer开辟的空间大小就是网络输入大小的总和乘精度所占byte,比如本例中设定的输入shape是1x2x2x1所以它是4,浮点数占4byte,所以是4x4的大小。

5.构建输出tensor

float[] output1, output2;
output1 = new float[1];
output2 = new float[1];
Map<Integer, Object> outputMap = new HashMap<>();
outputMap.put(0, output1);
outputMap.put(1, output2);

我们例子中网络的输出shape是[1,],所以这里直接构建一个1大小的浮点数组就可以了。如果有二维甚至三维的输出,比如[2,3,4],则需要构建多维数组new float[2][3][4]。不过我不喜欢构建多维数组的方法,因为不方便传到Native层做处理,所以我一般都会把输出做reshape(output_tensor,[-1])。

6.执行推理并打印输出

tfLite.runForMultipleInputsOutputs(inputArray, outputMap);
Log.e("1111","output1:" + output1[0]);
Log.e("1111","output2:" + output2[0]);

一句就可以解决。可以得到如下的输出

2020-02-25 16:27:55.569 22585-22585/com.stars.tflite_test E/1111: output1:4.0
2020-02-25 16:27:55.569 22585-22585/com.stars.tflite_test E/1111: output2:0.0

至此就完成了整个模型的转换和部署。

嘿嘿,是不是相当简单。

附上我的java部分的包名,哪些是必须的我忘了,就全加上来吧:

package com.stars.tflite_test;

import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
import android.os.Bundle;

import com.google.android.material.floatingactionbutton.FloatingActionButton;
import com.google.android.material.snackbar.Snackbar;

import androidx.appcompat.app.AppCompatActivity;
import androidx.appcompat.widget.Toolbar;

import android.util.Log;
import android.view.View;
import android.view.Menu;
import android.view.MenuItem;

import org.tensorflow.lite.Interpreter;

import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.HashMap;
import java.util.Map;

遇到的坑:

1.官方的BN层我没办法转到tflite中(不知道是什么原因),所以我自己从官方的BN层魔改了一个没有那么多功能的BN层,成功转换tflite。

2.转换后的tflite模型在GPU上运行得到的结果不正确,不知道怎么解决,我已经在github上提了问题了,但是还没有得到答案。问题地址https://github.com/tensorflow/tensorflow/issues/38825,有懂的也可以和我说说,我试了用session模式写的模型在GPU上能够得到正确的结果,但是用2.x的模式写的模型就不行。

猜你喜欢

转载自blog.csdn.net/qq_19313495/article/details/104498442