Java 通过 gRPC 向 TensorFlow Serving 发送数据:提升调用速度

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/boom_man/article/details/86223341

官方:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/tensor.proto

原作者在构建 TensorProto 对象时,是把数据以 List<Float> 的形式逐个放入(即 float_val 字段);而官方推荐的做法是把原始字节放入 tensor_content 字段。两种方式的调用速度相差约 2 倍。

官方描述是这样的

 Serialized raw tensor content from either Tensor::AsProtoTensorContent or
 memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation
can be used for all tensor types. The purpose of this representation is to
reduce serialization overhead during RPC call by avoiding serialization of
many repeated small items.

代码:

       // Build the input shape: [batch = 1, height, width, channels].
        TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
        tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
        tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(mat.height()));
        tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(mat.width()));
        // Use the Mat's own channel count rather than a hard-coded 3: the raw payload
        // must contain exactly height * width * channels bytes for a DT_UINT8 tensor.
        tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(mat.channels()));

        TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
        tensorProtoBuilder.setDtype(DataType.DT_UINT8);
        tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());

        // Ship all pixel values as one raw byte blob (tensor_content) instead of many
        // repeated small items; this is the representation tensor.proto recommends for RPC.
        tensorProtoBuilder.setTensorContent(ByteString.copyFrom(OpenCVUtils.mat2Content(mat)));

        // usePlaintext() and useTransportSecurity() are mutually exclusive and the last
        // call wins, so the original chain already ran in plaintext. Only the effective
        // setting is kept to make that explicit.
        ManagedChannel channel = ManagedChannelBuilder.forAddress(TENSOR_FLOW_URL, TENSOR_FLOW_PORT).usePlaintext().build();
        try {
            PredictionServiceGrpc.PredictionServiceFutureStub predictionServiceFutureStub = PredictionServiceGrpc.newFutureStub(channel);

            // Create the request.
            Predict.PredictRequest.Builder request = Predict.PredictRequest.newBuilder();
            // Model name and signature name.
            Model.ModelSpec.Builder modelSpace = Model.ModelSpec.newBuilder();
            modelSpace.setName("ssd_hand");
            modelSpace.setSignatureName("serving_default");
            // No version is set here. NOTE(review): to pin a specific model version, set it on
            // the ModelSpec builder (modelSpace.setVersion(...)), not via
            // TensorProto.version_number, which describes the tensor encoding - confirm against
            // the TensorFlow Serving model.proto in use.
            request.setModelSpec(modelSpace);

            // Attach the tensor under the signature's input key.
            request.putInputs("input", tensorProtoBuilder.build());

            // Start the clock BEFORE issuing the call: the future stub launches the RPC as
            // soon as predict(...) is invoked, so timing only get(...) under-reports the cost.
            long t = System.currentTimeMillis();
            ListenableFuture<Predict.PredictResponse> predict = predictionServiceFutureStub.predict(request.build());
            try {
                Predict.PredictResponse response = predict.get(50000, TimeUnit.MILLISECONDS);
                System.out.println("cost time: " + (System.currentTimeMillis() - t));
            } catch (InterruptedException e) {
                // Restore the interrupt flag so callers can observe the interruption,
                // and abandon the in-flight RPC.
                Thread.currentThread().interrupt();
                predict.cancel(true);
                e.printStackTrace();
            } catch (ExecutionException e) {
                e.printStackTrace();
            } catch (TimeoutException e) {
                // Nobody will read the result any more; cancel so the call is released.
                predict.cancel(true);
                e.printStackTrace();
            }
        } finally {
            // The channel is created locally, so release its resources here. Long-lived
            // clients should create one channel up front and reuse it across requests.
            channel.shutdown();
        }
// Converts a Mat into a flat byte array holding all of its element data.
    /**
     * Copies the data of {@code mat}, starting at element (0, 0), into a newly
     * allocated byte array.
     *
     * <p>The buffer is sized as {@code rows * cols * channels}, i.e. one byte per
     * channel value. NOTE(review): this assumes an 8-bit Mat depth; OpenCV's
     * {@code Mat.get(int, int, byte[])} is expected to reject other depths - confirm
     * against the OpenCV version in use.
     *
     * @param mat source matrix to read from
     * @return a new array containing the bytes read from {@code mat}
     */
    public static byte[] mat2Content(Mat mat) {
        // Size the buffer from the Mat's real channel count instead of assuming 3,
        // so grayscale or 4-channel images do not get a wrongly sized buffer.
        byte[] content = new byte[mat.rows() * mat.cols() * mat.channels()];
        mat.get(0, 0, content);
        return content;
    }

Maven 依赖:

  <!-- TensorFlow Serving client classes. The logging binding and the gRPC
       protobuf/stub artifacts are excluded here; gRPC is declared explicitly
       below via grpc-all. -->
  <dependency>
    <groupId>com.yesup.oss</groupId>
    <artifactId>tensorflow-client</artifactId>
    <version>1.4-2</version>
    <exclusions>
      <exclusion>
        <groupId>org.slf4j</groupId>
        <artifactId>slf4j-log4j12</artifactId>
      </exclusion>
      <exclusion>
        <groupId>io.grpc</groupId>
        <artifactId>grpc-protobuf</artifactId>
      </exclusion>
      <exclusion>
        <groupId>io.grpc</groupId>
        <artifactId>grpc-stub</artifactId>
      </exclusion>
    </exclusions>
  </dependency>

  <!-- Image-processing library. -->
  <dependency>
    <groupId>net.coobird</groupId>
    <artifactId>thumbnailator</artifactId>
    <version>0.4.8</version>
  </dependency>

  <!-- gRPC runtime. protobuf-java is excluded so that only one copy of it
       (presumably the one pulled in by tensorflow-client) ends up on the
       classpath. -->
  <dependency>
    <groupId>io.grpc</groupId>
    <artifactId>grpc-all</artifactId>
    <version>1.17.1</version>
    <exclusions>
      <exclusion>
        <groupId>com.google.protobuf</groupId>
        <artifactId>protobuf-java</artifactId>
      </exclusion>
    </exclusions>
  </dependency>

  <!-- Native TLS provider for grpc-netty; only needed when the channel uses
       transport security. The version must match the grpc-netty release:
       gRPC-Java's SECURITY.md pairs grpc-netty 1.16.x-1.17.x with
       netty-tcnative-boringssl-static 2.0.17.Final (the previous 2.0.7.Final
       belongs to older gRPC releases). -->
  <dependency>
    <groupId>io.netty</groupId>
    <artifactId>netty-tcnative-boringssl-static</artifactId>
    <version>2.0.17.Final</version>
  </dependency>

csdn:https://blog.csdn.net/shin627077/article/details/78592729

官方:https://github.com/tensorflow/serving

参考:小米云 http://docs.api.xiaomi.com/cloud-ml/modelservice/0903_use_java_client.html

总结:

gRPC 调用的核心是发送 tensor.proto 中定义的 TensorProto 消息,其定义如下:

syntax = "proto3";

package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "TensorProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "tensorflow/core/framework/resource_handle.proto";
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";

// Protocol buffer representing a tensor.
//
// A tensor is described by its data type (`dtype`), its shape
// (`tensor_shape`) and one value representation: either the raw
// `tensor_content` bytes or the typed `xxx_val` field that corresponds to
// `dtype`.
message TensorProto {
  // Data type of the tensor. It determines which of the type-specific
  // representations below may be set.
  DataType dtype = 1;

  // Shape of the tensor.  TODO(touts): sort out the 0-rank issues.
  TensorShapeProto tensor_shape = 2;

  // Only one of the representations below is set: either "tensor_content" or
  // one of the "xxx_val" attributes.  We are not using oneof because oneofs
  // cannot contain repeated fields, so it would require an extra set of
  // messages.

  // Version number.
  //
  // In version 0, if the "repeated xxx" representations contain only one
  // element, that element is repeated to fill the shape.  This makes it easy
  // to represent a constant Tensor with a single value.
  int32 version_number = 3;

  // Serialized raw tensor content from either Tensor::AsProtoTensorContent or
  // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation
  // can be used for all tensor types. The purpose of this representation is to
  // reduce serialization overhead during RPC call by avoiding serialization of
  // many repeated small items.
  bytes tensor_content = 4;

  // Type specific representations that make it easy to create tensor protos in
  // all languages.  Only the representation corresponding to "dtype" can
  // be set.  The values hold the flattened representation of the tensor in
  // row major order.
  //
  // Note: this file uses proto3 syntax, where repeated scalar numeric fields
  // are packed by default; the explicit [packed = true] options below are
  // therefore redundant but harmless.

  // DT_HALF. Note that since protobuf has no int16 type, we'll have some
  // pointless zero padding for each value here.
  repeated int32 half_val = 13 [packed = true];

  // DT_FLOAT.
  repeated float float_val = 5 [packed = true];

  // DT_DOUBLE.
  repeated double double_val = 6 [packed = true];

  // DT_INT32, DT_INT16, DT_INT8, DT_UINT8.
  repeated int32 int_val = 7 [packed = true];

  // DT_STRING.
  repeated bytes string_val = 8;

  // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are the real
  // and imaginary parts of the i-th single precision complex number.
  repeated float scomplex_val = 9 [packed = true];

  // DT_INT64.
  repeated int64 int64_val = 10 [packed = true];

  // DT_BOOL.
  repeated bool bool_val = 11 [packed = true];

  // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are the real
  // and imaginary parts of the i-th double precision complex number.
  repeated double dcomplex_val = 12 [packed = true];

  // DT_RESOURCE.
  repeated ResourceHandleProto resource_handle_val = 14;

  // DT_VARIANT.
  repeated VariantTensorDataProto variant_val = 15;
};

// Protocol buffer representing the serialization format of DT_VARIANT tensors.
//
// Used as the element type of TensorProto.variant_val. The message is
// recursive: a serialized object may itself contain TensorProto values.
message VariantTensorDataProto {
  // Name of the type of objects being serialized.
  string type_name = 1;
  // Portions of the object that are not Tensors.
  bytes metadata = 2;
  // Tensors contained within objects being serialized.
  repeated TensorProto tensors = 3;
}

猜你喜欢

转载自blog.csdn.net/boom_man/article/details/86223341