1. 前言
随着深度学习模型在计算机视觉任务中的应用日益广泛,深度估计任务作为其中的重要一环,广泛应用于自动驾驶、机器人和增强现实等领域。Depth Anything V2 是一种高效的深度估计模型,其推理速度和准确性在众多应用场景中表现出色。然而,为了在实际应用中更好地发挥其性能,我们可以通过将模型部署在 TensorRT 上,从而进一步加速推理过程。
本文将介绍如何将 Depth Anything V2 模型转换为 TensorRT 格式,并在实际应用中进行推理。
2. 为什么选择 TensorRT
TensorRT 是由 NVIDIA 开发的一款高性能推理引擎,专门为深度学习模型的高效部署而设计。其主要优点包括:
- 加速推理:通过对模型进行量化(如 FP16、INT8)和内存优化,TensorRT 能够显著加快推理速度。
- GPU 加速:TensorRT 充分利用 NVIDIA GPU 的计算能力,尤其是在边缘设备或实时应用中,能够在保持高精度的前提下最大化推理效率。
- 跨平台支持:TensorRT 支持多种平台,适合在不同的硬件环境中进行推理。
因此,通过将 Depth Anything V2 模型部署到 TensorRT 中,我们可以获得极大的推理性能提升,满足实时深度估计的需求。
3. 环境准备
在进行 TensorRT 推理之前,需要确保环境配置正确,以下是本次实验的具体配置要求和说明:
环境配置
- 操作系统:本文在 Windows 11 上进行测试。
- CUDA 版本:使用 CUDA 11.7,请确保安装的 CUDA 版本与 TensorRT 兼容。
- TensorRT:安装适合当前 CUDA 版本的 TensorRT。可以通过 NVIDIA 官方网站 下载适配的版本。安装过程中如果遇到 pycuda 问题,可以参考相关教程解决,或者联系我获取帮助。
- GPU:本文使用 RTX 3060(8GB 显存)进行推理。
此外,您需要准备 ONNX 格式的 Depth Anything V2 模型文件,本文将演示如何将其转换为 TensorRT 格式进行推理。
要验证 TensorRT 是否正确安装并检查版本,可以运行以下 Python 代码:
import tensorrt as trt
print(trt.__version__)
我的程序输出结果为:
8.6.0
确保 TensorRT 版本正确安装后,就可以开始准备推理过程。
4. 模型转换
4.1 ONNX 转换 TensorRT 格式
您可以通过以下 百度网盘链接 下载已经导出的 Depth Anything V2 模型的 ONNX 文件。
共有三个深度估计模型可供选择:
- depth_anything_v2_vitb.onnx:模型较大,精度较高,但推理时间较长。
- depth_anything_v2_vitl.onnx:平衡精度和推理速度的中等模型。
- depth_anything_v2_vits.onnx:模型较小,推理速度最快,但精度相对较低。
import tensorrt as trt
import warnings
import os
warnings.simplefilter("ignore", category=DeprecationWarning)
class EngineBuilder:
def __init__(
self,
onnx_file_path,
save_path,
mode,
log_level="ERROR",
max_workspace_size=1,
strict_type_constraints=False,
int8_calibrator=None,
**kwargs,
):
"""build TensorRT model from onnx model.
Args:
onnx_file_path (string or io object): onnx model name
save_path (string): tensortRT serialization save path
mode (string): Whether or not FP16 or Int8 kernels are permitted during engine build.
log_level (string, default is ERROR): tensorrt logger level, now
INTERNAL_ERROR, ERROR, WARNING, INFO, VERBOSE are support.
max_workspace_size (int, default is 1):
The maximum GPU temporary memory which the ICudaEngine can use at
execution time. default is 1GB.
strict_type_constraints (bool, default is False):
When strict type constraints is set, TensorRT will choose
the type constraints that conforms to type constraints.
If the flag is not enabled higher precision
implementation may be chosen if it results in higher performance.
int8_calibrator (volksdep.calibrators.base.BaseCalibrator, default is None):
calibrator for int8 mode,
if None, default calibrator will be used as calibration data."""
self.onnx_file_path = onnx_file_path
self.save_path = save_path
self.mode = mode.lower()
assert self.mode in [
"fp32",
"fp16",
"int8",
], f"mode should be in ['fp32', 'fp16', 'int8'], but got {
mode}"
self.trt_logger = trt.Logger(getattr(trt.Logger, log_level))
self.builder = trt.Builder(self.trt_logger)
self.network = None
self.max_workspace_size = max_workspace_size
self.strict_type_constraints = strict_type_constraints
self.int8_calibrator = int8_calibrator
def create_network(self, **kwargs):
EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
self.network = self.builder.create_network(EXPLICIT_BATCH)
parser = trt.OnnxParser(self.network, self.trt_logger)
if isinstance(self.onnx_file_path, str):
with open(self.onnx_file_path, "rb") as f:
print("Beginning ONNX file parsing")
flag = parser.parse(f.read())
else:
flag = parser.parse(self.onnx_file_path.read())
if not flag:
for error in range(parser.num_errors):
print(parser.get_error(error))
print("Completed parsing of ONNX file.")
# re-order output tensor
output_tensors = [
self.network.get_output(i) for i in range(self.network.num_outputs)
]
[self.network.unmark_output(tensor) for tensor in output_tensors]
for tensor in output_tensors:
identity_out_tensor = self.network.add_identity(tensor).get_output(0)
identity_out_tensor.name = "identity_{}".format(tensor.name)
self.network.mark_output(tensor=identity_out_tensor)
def create_engine(self):
config = self.builder.create_builder_config()
config.max_workspace_size = self.max_workspace_size * (1 << 25)
if self.mode == "fp16":
assert self.builder.platform_has_fast_fp16, "not support fp16"
config.set_flag(trt.BuilderFlag.FP16)
# builder.fp16_mode = True
if self.mode == "int8":
assert self.builder.platform_has_fast_int8, "not support int8"
config.set_flag(trt.BuilderFlag.INT8)
config.int8_calibrator = self.int8_calibrator
# builder.int8_mode = True
# builder.int8_calibrator = int8_calibrator
if self.strict_type_constraints:
config.set_flag(trt.BuilderFlag.STRICT_TYPES)
config.set_preview_feature(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805, True)
print(
f"Building an engine from file {
self.onnx_file_path}; this may take a while..."
)
profile = self.builder.create_optimization_profile()
config.add_optimization_profile(profile)
engine = self.builder.build_engine(self.network, config)
print("Create engine successfully!")
print(f"Saving TRT engine file to path {
self.save_path}")
with open(self.save_path, "wb") as f: