Getting started with TensorRT (3): official sample sampleOnnxMNIST

0. Preface

  • What this article covers is the same task as the earlier sampleMNIST and sampleMNISTAPI notes.

    • The sample in this article has exactly the same input and output as the samples in Note 1 and Note 2. The difference is the way the model is created:
      • sampleMNIST imports a Caffe model and converts it into TensorRT form.
      • sampleMNISTAPI builds the model layer by layer directly through TensorRT's C++ API and imports the Caffe weights into the created network.
      • sampleOnnxMNIST builds the model from an ONNX model.
  • A small doubt: TensorRT should be able to use an ONNX model in two ways. One is to convert the ONNX model directly inside the program, as this sample does; the other is to first convert the ONNX model into an engine file with the official tool (e.g. trtexec) and then load it. I have not yet figured out the exact difference between these two methods; a rough sketch of the second path follows.
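  • A minimal sketch of that second path (not part of the sample; the helper names saveEngine / loadEngine and the file path are made up here, and the TensorRT 7 style API used by the sample itself is assumed): serialize the engine that build() produces to a file once, then deserialize it later without any ONNX parsing or network building.
#include <fstream>
#include <iterator>
#include <string>
#include "NvInfer.h"
#include "logger.h" // for sample::gLogger, as used elsewhere in the sample

// Serialize an already-built engine (e.g. mEngine from SampleOnnxMNIST::build()) and write it to disk.
void saveEngine(nvinfer1::ICudaEngine& engine, const std::string& path)
{
    nvinfer1::IHostMemory* blob = engine.serialize();
    std::ofstream out(path, std::ios::binary);
    out.write(static_cast<const char*>(blob->data()), blob->size());
    blob->destroy();
}

// Later (e.g. in a deployment program): read the file back and deserialize it.
// No ONNX parser, network definition, or builder is needed at this point.
nvinfer1::ICudaEngine* loadEngine(const std::string& path)
{
    std::ifstream in(path, std::ios::binary);
    std::string data((std::istreambuf_iterator<char>(in)), std::istreambuf_iterator<char>());
    nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(sample::gLogger.getTRTLogger());
    return runtime->deserializeCudaEngine(data.data(), data.size());
}
  • The official trtexec tool essentially does the first half of this (build an engine from an ONNX file and save the serialized result) from the command line, which is presumably what the "convert first, load later" workflow refers to.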

1. ONNX model conversion

  • The rest of the code is not covered here; we only look at the SampleOnnxMNIST::build() function.

1.1. The build() function in detail

  • The process of building the network is basically:
    • Create the builder
    • Create an empty network object
    • Create the builder config (IBuilderConfig)
    • Create the ONNX model parser
    • Parse the model and store its structure in the network object through the parser
    • Set some build options (such as precision / quantization)
    • Validate the result
bool SampleOnnxMNIST::build()
{
    // Create the builder
    auto builder = SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger()));
    if (!builder)
    {
        return false;
    }

    // Create an empty network object (explicit batch is required for ONNX models)
    const auto explicitBatch = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
    auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(explicitBatch));
    if (!network)
    {
        return false;
    }

    // Create the BuilderConfig - I'm not sure exactly what it is for
    // (it holds build options such as workspace size and precision flags)
    auto config = SampleUniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
    if (!config)
    {
        return false;
    }

    // Create the ONNX model parser, bound to the network object
    auto parser
        = SampleUniquePtr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, sample::gLogger.getTRTLogger()));
    if (!parser)
    {
        return false;
    }

    // Build the model: parse the ONNX file through the parser and import the result into the network
    auto constructed = constructNetwork(builder, network, config, parser);
    if (!constructed)
    {
        return false;
    }

    // Build the optimized engine from the populated network and the config
    mEngine = std::shared_ptr<nvinfer1::ICudaEngine>(
        builder->buildEngineWithConfig(*network, *config), samplesCommon::InferDeleter());
    if (!mEngine)
    {
        return false;
    }

    // Validate the result: MNIST has one 4-D input (NCHW) and one 2-D output
    assert(network->getNbInputs() == 1);
    mInputDims = network->getInput(0)->getDimensions();
    assert(mInputDims.nbDims == 4);
    assert(network->getNbOutputs() == 1);
    mOutputDims = network->getOutput(0)->getDimensions();
    assert(mOutputDims.nbDims == 2);

    return true;
}
  • The core of the previous step is constructNetwork, which parses the model through the parser and stores the result in the network object.
//!
//! \brief Uses an ONNX parser to create the ONNX MNIST network and marks the
//!        output layers
//!
//! \param network Pointer to the network that will be populated with the Onnx MNIST network
//!
//! \param builder Pointer to the engine builder
//!
bool SampleOnnxMNIST::constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder,
    SampleUniquePtr<nvinfer1::INetworkDefinition>& network, SampleUniquePtr<nvinfer1::IBuilderConfig>& config,
    SampleUniquePtr<nvonnxparser::IParser>& parser)
{
    // Note: the network object was already passed in when the parser was created,
    // so parsing the ONNX file directly populates that network
    auto parsed = parser->parseFromFile(locateFile(mParams.onnxFileName, mParams.dataDirs).c_str(),
        static_cast<int>(sample::gLogger.getReportableSeverity()));
    if (!parsed)
    {
        return false;
    }

    // Precision settings (FP16/INT8) - I don't know yet how this differs from the onnx_tensorrt tool
    config->setMaxWorkspaceSize(16_MiB);
    if (mParams.fp16)
    {
        config->setFlag(BuilderFlag::kFP16);
    }
    if (mParams.int8)
    {
        config->setFlag(BuilderFlag::kINT8);
        // Without a calibrator, set dummy dynamic ranges for all tensors so INT8 can run
        samplesCommon::setAllTensorScales(network.get(), 127.0f, 127.0f);
    }

    // DLA here stands for Deep Learning Accelerator
    // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#dla_layers
    samplesCommon::enableDLA(builder.get(), config.get(), mParams.dlaCore);

    return true;
}
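
  • For context, the sample's main() drives these functions roughly as follows. This is a simplified sketch from memory of the sample's overall structure: initializeSampleParams() is the sample's own helper that fills the ONNX file name and data directories from the command-line arguments, and the error handling here is approximate; see the actual sampleOnnxMNIST.cpp for the real code.
// Simplified sketch of the sample's main(), assuming it lives in
// sampleOnnxMNIST.cpp next to the SampleOnnxMNIST class shown above.
#include <cstdlib>

int main(int argc, char** argv)
{
    samplesCommon::Args args;
    samplesCommon::parseArgs(args, argc, argv);

    SampleOnnxMNIST sample(initializeSampleParams(args));

    if (!sample.build())   // parse the ONNX model and build the engine (the code above)
    {
        return EXIT_FAILURE;
    }
    if (!sample.infer())   // run inference; input/output handling matches the previous notes
    {
        return EXIT_FAILURE;
    }
    return EXIT_SUCCESS;
}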

Origin: blog.csdn.net/irving512/article/details/113701879