自iphone X以来,各品牌的手机旗舰机型都开始支持Face ID。与之前的指纹识别相比,人脸识别还是更方便一些。近期在国内大火的抖音,也与人脸识别技术密不可分。各手机厂家还纷纷推出人工智能芯片,希望能让人工智能应用在手机、单片机上运行的更加流畅。可以看到,将在PC上已经大获成功的机器学习模型部署在移动设备上是大势所趋,又能引发一轮新的可能性。手机的精准定位带来了一大波O2O的应用,产生了包括美团、滴滴这样的巨型独角兽。那么人工智能在手机上的广泛应用又将带来什么变革呢?但这个问题不是本文讨论的重点,下面我来从技术角度出发,谈谈如何在Android上部署Tensorflow的图像识别模型。
Android中支持的Tensorflow模型文件为pb格式。pb相比于训练时的ckpt格式,会保存网络结构,并具有语言独立性,适合在移动端使用。
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output"])
with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
f.write(constant_graph.SerializeToString())
同时,由于在Android中读取pb文件时要通过变量名来获取输入层和输出层,所以在保存pb文件前,我分别将输入和输出层变量名设为“input”和“output”。
train_batch = tf.add(train_batch, tf.zeros([IMG_W, IMG_H, 3]), name="input")
output = tf.add(train_logits, tf.zeros(N_CLASSES), name="output")
App界面采用最常见的底部导航栏结构,使用BottomNavigationView方式实现。拥有从相册中选图片识别和摄像头实时识别两个Fragment,页面间的切换采用FragmentTransaction的show/hide方式。
fragmentManager = getSupportFragmentManager();
FragmentTransaction transaction = fragmentManager.beginTransaction();
transaction.add(R.id.fragment, mPhotoFragment = new PhotoFragment());
transaction.add(R.id.fragment, mCameraFragment = new CameraFragment());
transaction.commit();
对于摄像头Fragment,在onResume和onPause时分别Open、Release Camera;由于画面会不停移动需要重新对焦,还需要实现摄像头的自动对焦功能。
@Override
public void onResume() {
super.onResume();
onHiddenChanged(false);
}
@Override
public void onPause() {
onHiddenChanged(true);
super.onPause();
}
@Override
public void surfaceChanged(SurfaceHolder surfaceHolder, int i, int i1, int i2) {
camera.autoFocus(new Camera.AutoFocusCallback() {
@Override
public void onAutoFocus(boolean success, Camera camera) {
if(success){
focus();//实现相机的参数初始化
camera.cancelAutoFocus();
}
}
});
}
成功获取图片后,调用Classifier对图片进行识别。在Classifier中,将输入的bitmap图片转换成float数组,创建并调用TensorFlowInferenceInterface进行推断。根据识别结果对猫狗置信度的大小,在页面上展示猫或狗的头像。
inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);
inferenceInterface.run(outputNames, logStats);
inferenceInterface.fetch(outputName, outputs);
Classifier.Recognition answer = null;
for (final Classifier.Recognition recognition : results) {
if (answer == null || recognition.getConfidence() > answer.getConfidence()) {
answer = recognition;
}
}
int resId = answer.getTitle().equals("cat") ? R.mipmap.cat : R.mipmap.dog;
view.setImageResource(resId);
项目地址:https://github.com/wlkdb/TensorflowInAndroid