[CV学习笔记] yolo&tensorrt多线程推理-第三部分

1、前言

在前两篇博客中学习了yolo&tensorrt推理的代码,学习了模型的加载、预处理、后处理等等,本文将继续学习其中的消费者、生产者推理代码
yolo&tensorrt项目:https://github.com/shouxieai/infer
第一部分学习记录:https://blog.csdn.net/weixin_42108183/article/details/129411759
第二部分学习记录:https://blog.csdn.net/weixin_42108183/article/details/129455120

2、多线程代码学习

cpm.hpp

// 结果类型、输入类型、模型类型
template <typename Result, typename Input, typename Model>
class Instance{
    
    
protected:
    // 任务形式
    struct Item
    {
    
    
        Input input; // 输入
        std::shared_ptr<std::promise<Result>> pro; // 输入对应的输出
    };
    
    std::condition_variable cond_;  // 条件变量
    std::queue<Item> input_queue_;  // 任务队列
    std::shared_ptr<std::thread> worker_;  // 工作线程
    volatile bool run_ = false;     // 是否运行
public:
    // 析构函数中使用stop回收子线程,即推理线程
    virtual ~Instance() {
    
     stop(); }

    void stop(){
    
    
        run_ = false; // 将运行状态设置为false
        cond_.notify_one();  // 唤醒工作线程、如果工作线程检测到run_=false,则会退出。
        {
    
    
            std::unique_lock<std::mutex> l(queue_lock_);
            // 如果工作队列非空,继续推理完毕
            while (!input_queue_.empty()){
    
    
                auto &item = input_queue_.front();
                if (item.pro)
                    item.pro->set_value(Result());
                input_queue_.pop();
            }

        }
        // 回收工作线程
        if (worker_)
        {
    
    
            worker_->join();
            worker_.reset();
        }
    }
    
    // 推理单张图像
    virtual std::shared_future<Result> commit(const Input &input){
    
    
        // 构建任务    
        Item item;
        item.input = input;
        item.pro.reset(new std::promise<Result>());
        {
    
    
            // 给任务队列添加任务
            std::unique_lock<std::mutex> __lock_(queue_lock_);
            input_queue_.push(item);
        }
        // 唤醒线程
        cond_.notify_one();
        return item.pro->get_future();
    }
    
    // 推理多张图片
    commits(const std::vector<Input> &inputs){
    
    }

    template <typename LoadMethod>
    bool start(...){
    
    
        // 先stop,保证没有子线程
        stop();
        std::promise<bool> status;  // 模型是否被正确加载的标识
        // 创建工作线程
        worker_ = std::make_shared<std::thread>(&Instance::worker<LoadMethod>, this,
                                              std::ref(loadmethod), std::ref(status));
        return status.get_future().get();  // 阻塞等待 status被设置某个值
    }
private:
    template <typename LoadMethod>
    void worker(const LoadMethod &loadmethod, std::promise<bool> &status){
    
    
        // 加载模型
        std::shared_ptr<Model> model = loadmethod();
        if (model == nullptr)
        {
    
       
            // 加载失败返回false
            status.set_value(false);
            return;
        }    
        // 等待解锁
        while (get_items_and_wait(fetch_items, max_items_processed_)){
    
    
            inputs.resize(fetch_items.size());
            std::transform(fetch_items.begin(), fetch_items.end(), inputs.begin(),
                               [](Item &item)
                               {
    
     return item.input; });
            // 推理
            auto ret = model->forwards(inputs, stream_);
        }
        model.reset();  // 智能指针,释放模型
        run_ = false;   
    }
    virtual bool get_items_and_wait(std::vector<Item> &fetch_items,int max_size){
    
    
        // 是否解锁
        std::unique_lock<std::mutex> l(queue_lock_);
        // run_=false 或者 任务队列不为空,则解锁
        cond_.wait(l, [&](){
    
     return !run_ || !input_queue_.empty(); });           
        for (int i = 0; i < max_size && !input_queue_.empty(); ++i)
        {
    
       
            // 添加任务
            fetch_items.emplace_back(std::move(input_queue_.front()));
            input_queue_.pop();
        }
        return true;
    }
    // 推理单张图片时使用
    virtual bool get_item_and_wait(Item &fetch_item){
    
    
        
    }
}

3、测试代码

main.cpp

void perf()
{
    
    
    // 将图片添加到batchsize
    for (int i = images.size(); i < batch; ++i)
        images.push_back(images[i % 3]);
    // 初始化消费者、生产者实例
    cpm::Instance<yolo::BoxArray, yolo::Image, yolo::Infer> cpmi;
    // start中会开辟子线程用于加载engine,此时任务队列空,子线程会阻塞
    bool ok = cpmi.start(
        []
        {
    
     return yolo::load("yolov8n.transd.engine", yolo::Type::V8); },
        max_infer_batch);
    // 将images进行cvimg 的操作,然后储存到yoloimages中
    td::transform(images.begin(), images.end(), yoloimages.begin(), cvimg);
    
    for (int i = 0; i < 5; ++i)
    {
    
    
        timer.start();
        // 提交图片,.get()会阻塞等待子线程返回结果
        cpmi.commits(yoloimages).back().get();
        timer.stop("BATCH16");
    }
    ...
}

4、总结

本次学习了基于生产者、消费者的多线程推理代码,并且很容易的将代码融入到项目中。

猜你喜欢

转载自blog.csdn.net/weixin_42108183/article/details/129456369