ONNX is a format for representing machine learning models, and TensorRT is a high-performance inference engine for running models on NVIDIA GPUs. A custom plugin implements an operator that TensorRT does not support natively (built-in layers such as convolution and ReLU need no plugin), so that the whole model can still be executed efficiently by the engine. The detailed steps for converting an ONNX model to TensorRT and using a custom plugin are as follows:
- Define the custom plugin
First, define a plugin class that derives from TensorRT's plugin interface (IPluginV2DynamicExt in current releases) and implement its virtual functions, such as the forward computation (enqueue), the output shape and data-type queries, and serialization. Plugins are normally written in C++. The plugin is then made known to TensorRT through the plugin registry, either with the REGISTER_TENSORRT_PLUGIN macro or by calling registerCreator() on the registry.
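Below is a minimal, illustrative sketch of such a plugin against TensorRT 8's IPluginV2DynamicExt interface; it simply copies its input to its output, and IdentityPlugin / IdentityPluginCreator are placeholder names rather than part of any real library (a real plugin would launch its own CUDA kernel in enqueue()):

#include <NvInfer.h>
#include <cuda_runtime_api.h>
#include <string>

using namespace nvinfer1;

// Identity plugin: copies its input to its output unchanged. Illustrative only;
// a real plugin would launch a custom CUDA kernel in enqueue().
class IdentityPlugin : public IPluginV2DynamicExt
{
public:
    IdentityPlugin() {}
    IdentityPlugin(const void*, size_t) {} // deserialization constructor (no state here)
    // IPluginV2DynamicExt
    IPluginV2DynamicExt* clone() const noexcept override { return new IdentityPlugin(); }
    DimsExprs getOutputDimensions(int, const DimsExprs* inputs, int, IExprBuilder&) noexcept override { return inputs[0]; }
    bool supportsFormatCombination(int pos, const PluginTensorDesc* io, int, int) noexcept override
    { return io[pos].type == DataType::kFLOAT && io[pos].format == TensorFormat::kLINEAR; }
    void configurePlugin(const DynamicPluginTensorDesc*, int, const DynamicPluginTensorDesc*, int) noexcept override {}
    size_t getWorkspaceSize(const PluginTensorDesc*, int, const PluginTensorDesc*, int) const noexcept override { return 0; }
    int enqueue(const PluginTensorDesc* inDesc, const PluginTensorDesc*, const void* const* inputs,
                void* const* outputs, void*, cudaStream_t stream) noexcept override
    {
        size_t count = 1;
        for (int i = 0; i < inDesc[0].dims.nbDims; ++i) count *= inDesc[0].dims.d[i];
        cudaMemcpyAsync(outputs[0], inputs[0], count * sizeof(float), cudaMemcpyDeviceToDevice, stream);
        return 0;
    }
    // IPluginV2Ext
    DataType getOutputDataType(int, const DataType* inputTypes, int) const noexcept override { return inputTypes[0]; }
    // IPluginV2
    const char* getPluginType() const noexcept override { return "IdentityPlugin"; }
    const char* getPluginVersion() const noexcept override { return "1"; }
    int getNbOutputs() const noexcept override { return 1; }
    int initialize() noexcept override { return 0; }
    void terminate() noexcept override {}
    size_t getSerializationSize() const noexcept override { return 0; }
    void serialize(void*) const noexcept override {}
    void destroy() noexcept override { delete this; }
    void setPluginNamespace(const char* ns) noexcept override { mNamespace = ns; }
    const char* getPluginNamespace() const noexcept override { return mNamespace.c_str(); }
private:
    std::string mNamespace;
};

// Creator: tells TensorRT how to build and deserialize the plugin by name and version.
class IdentityPluginCreator : public IPluginCreator
{
public:
    const char* getPluginName() const noexcept override { return "IdentityPlugin"; }
    const char* getPluginVersion() const noexcept override { return "1"; }
    const PluginFieldCollection* getFieldNames() noexcept override { return &mFC; }
    IPluginV2* createPlugin(const char*, const PluginFieldCollection*) noexcept override { return new IdentityPlugin(); }
    IPluginV2* deserializePlugin(const char*, const void* data, size_t len) noexcept override { return new IdentityPlugin(data, len); }
    void setPluginNamespace(const char* ns) noexcept override { mNamespace = ns; }
    const char* getPluginNamespace() const noexcept override { return mNamespace.c_str(); }
private:
    PluginFieldCollection mFC{};
    std::string mNamespace;
};

// Static registration with TensorRT's global plugin registry.
REGISTER_TENSORRT_PLUGIN(IdentityPluginCreator);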
- Convert the ONNX model to TensorRT
Use the TensorRT C++ API (or the Python API) to convert the ONNX model into a TensorRT engine. This is done in the following steps:
  - Create a Builder object (nvinfer1::IBuilder), which drives engine construction.
  - From the Builder, create a Network object (INetworkDefinition) that holds the network structure.
  - Create an ONNX parser (nvonnxparser::IParser) attached to the Network and parse the ONNX model into it.
  - Create an IBuilderConfig object, which carries build settings such as precision flags and the workspace size.
  - Build the ICudaEngine (the TensorRT engine) from the Builder, the Network, and the config; a condensed sketch of this flow follows the list.
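As a reference point, here is a condensed sketch of that build flow against a recent TensorRT 8.x release, where buildSerializedNetwork and IBuilderConfig::setMemoryPoolLimit replace the deprecated buildEngineWithConfig / setMaxWorkspaceSize calls used in the full example at the end of this section; buildEngineFromOnnx is just an illustrative helper name:

#include <NvInfer.h>
#include <NvOnnxParser.h>
#include <fstream>
#include <memory>
#include <string>

// Condensed ONNX -> TensorRT build flow (TensorRT 8.x style, explicit batch).
bool buildEngineFromOnnx(const std::string& onnxFile, const std::string& engineFile, nvinfer1::ILogger& logger)
{
    auto builder = std::unique_ptr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(logger));
    const auto flags = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
    auto network = std::unique_ptr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(flags));
    auto parser = std::unique_ptr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, logger));
    if (!parser->parseFromFile(onnxFile.c_str(), static_cast<int>(nvinfer1::ILogger::Severity::kWARNING)))
        return false;
    auto config = std::unique_ptr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
    config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, 1ULL << 30); // 1 GiB workspace
    // config->setFlag(nvinfer1::BuilderFlag::kFP16);  // optionally enable FP16
    auto serialized = std::unique_ptr<nvinfer1::IHostMemory>(builder->buildSerializedNetwork(*network, *config));
    if (!serialized)
        return false;
    std::ofstream out(engineFile, std::ios::binary);
    out.write(static_cast<const char*>(serialized->data()), serialized->size());
    return out.good();
}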
- Apply the custom plugin and run inference
Use the TensorRT C++ API (or the Python API) to make the custom plugin available to the engine and run inference. The steps are:
  - Make sure the custom plugin is registered with TensorRT's global plugin registry (obtained via getPluginRegistry()) before the engine is built or deserialized, either through REGISTER_TENSORRT_PLUGIN inside the plugin library or by calling registerCreator() explicitly (see the sketch after this list).
  - When building from ONNX, the parser maps the unsupported operator onto the registered plugin while populating the INetworkDefinition.
  - From the ICudaEngine, create an IExecutionContext, which is used to execute inference.
  - Allocate device buffers for the input and output tensors and pass them to the IExecutionContext as bindings.
  - Execute inference (for example, enqueueV2() on a CUDA stream).
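If the plugin library does not self-register through REGISTER_TENSORRT_PLUGIN, a creator instance can be registered explicitly before the engine is built or deserialized. A minimal sketch, reusing the illustrative IdentityPluginCreator from the plugin sketch above (its definition must be visible in this translation unit):

#include <NvInferRuntime.h>

// Explicit runtime registration; call this before building or deserializing the engine.
static IdentityPluginCreator gIdentityCreator;

void registerCustomPlugins()
{
    // "" registers the creator under the default (empty) plugin namespace.
    getPluginRegistry()->registerCreator(gIdentityCreator, "");
}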
When implementing these steps, pay attention to the TensorRT version and the system configuration; a recent TensorRT release on an NVIDIA GPU is recommended for the best performance and feature coverage. The C++ example below ties the steps together: it loads a plugin shared library, builds a TensorRT engine from an ONNX model (or loads a cached one), and runs a single inference.
// Required headers; gLogger is assumed to be an nvinfer1::ILogger instance defined elsewhere.
#include <NvInfer.h>
#include <NvOnnxParser.h>
#include <cuda_runtime_api.h>
#include <dlfcn.h>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

int onnx_with_plugin_create_engine(std::string root_dir)
{
    std::string onnx_file = root_dir + "model.onnx";
    std::string modeltrt = root_dir + "model.trt";
    std::string plugin_file = root_dir + "libvit_plugin.so";
    std::fstream trtCache(modeltrt, std::ifstream::in);
    nvinfer1::ICudaEngine* engine_ = nullptr;
    // Load the plugin shared library; REGISTER_TENSORRT_PLUGIN inside it registers
    // the plugin creator with the global registry as a side effect of dlopen().
    void* pluginLibrary = dlopen(plugin_file.c_str(), RTLD_LAZY);
    if (!pluginLibrary) {
        std::cerr << "ERROR: Could not load plugin dynamic library" << std::endl;
        return EXIT_FAILURE;
    }
    // Verify that the plugin creator is now visible to TensorRT; the ONNX parser
    // and the engine deserializer look it up by name and version.
    auto creator = getPluginRegistry()->getPluginCreator("TransformerPlugin", "1");
    if (!creator) {
        std::cerr << "Failed to find plugin creator." << std::endl;
        return EXIT_FAILURE;
    }
    if (!trtCache.is_open())
    {
        std::cout << "Building TRT engine." << std::endl;
        // Create the builder
        auto builder = nvinfer1::createInferBuilder(gLogger);
        // Create an explicit-batch network definition
        const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
        auto network = builder->createNetworkV2(explicitBatch);
        // Create the ONNX parser attached to the network and parse the model
        auto parser = nvonnxparser::createParser(*network, gLogger);
        if (!parser->parseFromFile(onnx_file.data(), static_cast<int>(nvinfer1::ILogger::Severity::kWARNING)))
        {
            std::cerr << ": failed to parse onnx model file, please check the onnx version and trt support op!"
                      << std::endl;
            exit(-1);
        }
        // Create the builder config
        auto networkConfig = builder->createBuilderConfig();
        // Optionally enable FP16
        // networkConfig->setFlag(nvinfer1::BuilderFlag::kFP16);
        // std::cout << "Enable fp16!" << std::endl;
        // Max batch size (ignored for explicit-batch networks; kept for older TensorRT versions)
        builder->setMaxBatchSize(1);
        // Max workspace size: 1 GiB
        networkConfig->setMaxWorkspaceSize(size_t(1) << 30);
        engine_ = builder->buildEngineWithConfig(*network, *networkConfig);
        if (engine_ == nullptr)
        {
            std::cerr << ": engine init null." << std::endl;
            exit(-1);
        }
        // Serialize the engine to disk, then release the build-time objects
        auto trtModelStream = engine_->serialize();
        std::fstream trtOut(modeltrt, std::ios::out | std::ios::binary);
        if (!trtOut.is_open())
        {
            std::cerr << "can't store trt engine.\n";
            exit(-1);
        }
        trtOut.write((char*)trtModelStream->data(), trtModelStream->size());
        trtOut.close();
        trtModelStream->destroy();
        networkConfig->destroy();
        parser->destroy();
        network->destroy();
        builder->destroy();
        std::cout << "build engine done." << std::endl;
    }
    else
    {
        std::cout << "Load engine: " << modeltrt << std::endl;
        std::ifstream engineFile(modeltrt, std::ios::binary);
        long int fsize = 0;
        engineFile.seekg(0, engineFile.end);
        fsize = engineFile.tellg();
        engineFile.seekg(0, engineFile.beg);
        std::vector<char> engineString(fsize);
        engineFile.read(engineString.data(), fsize);
        if (engineString.size() == 0)
        {
            std::cout << "Failed getting serialized engine!" << std::endl;
            exit(-1);
        }
        std::cout << "Succeeded getting serialized engine." << std::endl;
        // The runtime must outlive the engine; it is intentionally not destroyed in this example.
        nvinfer1::IRuntime* runtime{nvinfer1::createInferRuntime(gLogger)};
        // safe::IRuntime* runtime{safe::createInferRuntime(gLogger)}; // use the safe runtime instead
        engine_ = runtime->deserializeCudaEngine(engineString.data(), fsize);
        if (engine_ == nullptr)
        {
            std::cerr << "Failed loading engine." << std::endl;
            exit(-1);
        }
        std::cout << "Succeeded loading engine." << std::endl;
        engineFile.close();
    }
    // Inference
    // Allocate host buffers sized for the model's input (1x3x1152x1152) and output (1x144x144)
    int inputSize = 1 * 3 * 1152 * 1152;
    int outputSize = 1 * 144 * 144;
    std::vector<float> inputBuffer(inputSize);
    std::vector<int32_t> outputBuffer(outputSize);
    // Select the GPU, force CUDA context creation, then create a stream and device buffers
    cudaSetDevice(0);
    cudaFree(0); // no-op free that initializes the CUDA context
    cudaStream_t stream;
    cudaStreamCreate(&stream);
    void* d_inputBuffer = nullptr;
    cudaMalloc(&d_inputBuffer, inputSize * sizeof(float));
    void* d_outputBuffer = nullptr;
    cudaMalloc(&d_outputBuffer, outputSize * sizeof(int32_t));
    // Create an execution context from the engine
    nvinfer1::IExecutionContext* context = engine_->createExecutionContext();
    if (!context)
    {
        std::cerr << "Failed to create execution context" << std::endl;
        return 1;
    }
    // Fill the input with dummy data and copy it to the device
    for (int i = 0; i < inputSize; ++i) {
        inputBuffer[i] = i % 255;
    }
    cudaMemcpyAsync(d_inputBuffer, inputBuffer.data(), inputSize * sizeof(float), cudaMemcpyHostToDevice, stream);
    // Run inference; the binding order {input, output} assumes binding 0 is the input tensor
    void* buffers[] = {d_inputBuffer, d_outputBuffer};
    context->enqueueV2(buffers, stream, nullptr);
    // Copy the result back and wait for the stream to finish before reading it
    cudaMemcpyAsync(outputBuffer.data(), d_outputBuffer, outputSize * sizeof(int32_t), cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);
    // Print the positive values among the first 30 outputs as a sanity check
    for (size_t i = 0; i < 30; i++)
    {
        if (outputBuffer[i] > 0)
        {
            std::cerr << outputBuffer[i] << std::endl;
        }
    }
    // Clean up
    context->destroy();
    engine_->destroy();
    cudaFree(d_inputBuffer);
    cudaFree(d_outputBuffer);
    cudaStreamDestroy(stream);
    dlclose(pluginLibrary);
    return 0;
}
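Assuming the TensorRT and CUDA headers and libraries are on the default search paths, the example can be built with a command along these lines (the source file name and paths are placeholders):

g++ -std=c++14 onnx_with_plugin.cpp -o onnx_with_plugin -lnvinfer -lnvonnxparser -lcudart -ldl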