One Implementation of a C++ Thread Pool

    Thread pools are a basic technique for improving software performance and stability in real-world development. Without multithreading a program runs inefficiently, but with too many threads the operating system is overloaded and both performance and stability suffer. Thread pool technology was created to solve this: it improves execution efficiency while keeping resource usage within a controllable range. Below is one way to implement a thread pool in C++.

    Approach: create a bounded, controllable number of threads between a minimum and a maximum. By default the minimum equals the number of CPU cores and the maximum is twice the minimum. A task queue stores the work the threads execute. The main pieces are the ThreadPool class (with its inner Task and ThreadContext structures), the SysInfo helper, and the Uncopyable utility, all shown below.

Contents

1. Priority queue

2. System information

3. Helper class

4. Thread pool class

5. Test code

6. Analysis of the results


1. Priority queue

    Because the thread pool supports priority-ordered execution, pending tasks are stored in a std::priority_queue. std::priority_queue is a commonly used STL container adaptor (std::stack and std::queue are the other common ones); it wraps a std::vector or std::deque in a heap structure.
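To make the "heap over a container" idea concrete, here is a small sketch of my own (not from the original post) that reproduces a max-heap priority queue with the <algorithm> heap functions on a plain std::vector:

#include <algorithm>
#include <iostream>
#include <vector>

int main() {
    // std::priority_queue<int> is roughly a std::vector<int>
    // kept in max-heap order by the <algorithm> heap functions
    std::vector<int> v{11, 88, 55};
    std::make_heap(v.begin(), v.end());     // heapify: v.front() is the maximum

    v.push_back(99);                        // like priority_queue::push(99)
    std::push_heap(v.begin(), v.end());     // restore the heap property

    while (!v.empty()) {
        std::cout << v.front() << " ";      // like priority_queue::top()
        std::pop_heap(v.begin(), v.end());  // move the maximum to the back
        v.pop_back();                       // like priority_queue::pop()
    }
    std::cout << std::endl;                 // prints: 99 88 55 11
    return 0;
}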

Header: <queue>

template<
    class T,
    class Container = std::vector<T>,
    class Compare = std::less<typename Container::value_type>
> class priority_queue;

Note: the third parameter is the element comparator. It defines the heap order: the element that compares greatest under Compare sits at the top. The default std::less therefore makes the priority queue a max-heap.

The priority queue relies on an underlying container, which must be accessible by random access ([i]) and by iterators, and must support the following operations:
empty()
size()
front()
push_back()
pop_back()
Both deque and vector support these operations, so by default, when no underlying container is specified, priority_queue uses vector, as the sketch below illustrates.
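As a quick illustration (my own example, not the original author's), std::deque can be swapped in for the default std::vector:

#include <deque>
#include <iostream>
#include <queue>

int main() {
    // same max-heap semantics, backed by std::deque instead of std::vector
    std::priority_queue<int, std::deque<int> > q;
    for (int n : {3, 1, 2}) {
        q.push(n);
    }
    while (!q.empty()) {
        std::cout << q.top() << " ";  // prints: 3 2 1
        q.pop();
    }
    std::cout << std::endl;
    return 0;
}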

The following test program demonstrates how std::priority_queue is used:

#include <functional>
#include <queue>
#include <vector>
#include <iostream>
 
template<typename T> 
void print_queue(T& q) {
    while(!q.empty()) {
        std::cout << q.top() << " ";
        q.pop();
    }
    std::cout << '\n';
}

// max-heap comparator
struct MyCmp {
    bool operator()(const int& a, const int& b)
    {
        return a < b;
    }
};

// min-heap comparator
struct MyCmp1 {
    bool operator()(const int& a, const int& b)
    {
        return a > b;
    }
};

int main() {
    std::priority_queue<int> queue1;
 
    for(int n : {11,88,55,66,33,44,0,99,77,22}) {
        queue1.push(n);
    }
    print_queue(queue1);
 
    std::priority_queue<int, std::vector<int>, std::greater<int> > queue2;
 
    for(int n : {11,88,55,66,33,44,0,99,77,22}) {
        queue2.push(n);
    }
    print_queue(queue2);
 
    auto cmp = [](const int &left, const int &right) {return left < right;};
    // decltype is a C++11 keyword; like auto, it performs type deduction at compile time
    std::priority_queue<int, std::vector<int>, decltype(cmp)> queue3(cmp);
    for(int n : {11,88,55,66,33,44,0,99,77,22}) {
        queue3.push(n);
    }
    print_queue(queue3);

    std::priority_queue<int, std::vector<int>, MyCmp> queue4;
    for(int n : {11,88,55,66,33,44,0,99,77,22}) {
        queue4.push(n);
    }
    print_queue(queue4);
    
    std::priority_queue<int, std::vector<int>, MyCmp1> queue5;
    for(int n : {11,88,55,66,33,44,0,99,77,22}) {
        queue5.push(n);
    }
    print_queue(queue5);
}

Running it produces:

99 88 77 66 55 44 33 22 11 0
0 11 22 33 44 55 66 77 88 99
99 88 77 66 55 44 33 22 11 0
99 88 77 66 55 44 33 22 11 0
0 11 22 33 44 55 66 77 88 99

2. System information

The following class retrieves system information: the number of CPU cores, the physical memory size, and the CPU architecture.

sysinfo.h

#pragma once

#include <cstdint>
#include <string>
#include "uncopyable.h"

class SysInfo {

    DECLARE_STATIC(SysInfo);

public:
    /// get the number of logical CPUs
    static uint32_t get_logical_cpu_number();

    static uint64_t get_physical_memory_size();

    static bool get_kernel_info(std::string& kernel_info);

    static bool get_machine_architecture(std::string& arch_info);
};

sysinfo.cpp

#include <unistd.h>
#include <sys/utsname.h>

#include "sysinfo.h"

/// internal
static const struct utsname* get_uname_info()
{
    static struct utsname uts;
    if (uname(&uts) == -1)
        return NULL;
    return &uts;
}

/// SysInfo interfaces
uint32_t SysInfo::get_logical_cpu_number()
{
    return sysconf(_SC_NPROCESSORS_ONLN);
}

uint64_t SysInfo::get_physical_memory_size()
{
    return sysconf(_SC_PHYS_PAGES) * (uint64_t)sysconf(_SC_PAGESIZE);
}

bool SysInfo::get_kernel_info(std::string& kernel_info)
{
    const struct utsname* uts = get_uname_info();
    if (uts != NULL) {
        kernel_info.assign(uts->release);
        return true;
    }
    return false;
}

bool SysInfo::get_machine_architecture(std::string& arch_info)
{
    const struct utsname* uts = get_uname_info();
    if (uts != NULL) {
        arch_info.assign(uts->machine);
        return true;
    }
    return false;
}
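As a quick check, here is a minimal test of my own for these interfaces (not part of the original code):

#include <iostream>
#include <string>

#include "sysinfo.h"

int main() {
    std::cout << "logical CPUs: " << SysInfo::get_logical_cpu_number() << std::endl;
    std::cout << "physical memory: "
              << SysInfo::get_physical_memory_size() / (1024 * 1024) << " MB" << std::endl;

    std::string kernel, arch;
    if (SysInfo::get_kernel_info(kernel))
        std::cout << "kernel: " << kernel << std::endl;
    if (SysInfo::get_machine_architecture(arch))
        std::cout << "arch: " << arch << std::endl;
    return 0;
}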

3. Helper class

This helper is a utility class that prevents other classes from being copied.

uncopyable.h

#pragma once

namespace uncopyable
{

class Uncopyable
{
protected:
    Uncopyable() {}
    ~Uncopyable() {}

private:
    Uncopyable(const Uncopyable&);
    const Uncopyable& operator=(const Uncopyable&);
};

} 

typedef uncopyable::Uncopyable Uncopyable;

#define DECLARE_UNCOPYABLE(Class) \
private: \
    Class(const Class&); \
    Class& operator=(const Class&)

#define DECLARE_STATIC(Class) \
private: \
	Class() = delete; \
	~Class() = delete
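For illustration, here is a small example of my own (not from the original post) showing both styles; copying is rejected at compile time:

#include "uncopyable.h"

// inheritance style: the private copy operations of the base are inaccessible
class Connection : private Uncopyable {
};

// macro style: declares, but never defines, the copy operations
class Session {
    DECLARE_UNCOPYABLE(Session);
public:
    Session() {}
};

int main() {
    Connection a;
    // Connection b = a;   // error: copy constructor is private
    Session s;
    // Session t = s;      // error: copy constructor is private
    return 0;
}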

4. Thread pool class

The thread pool implementation:

threadpool.h

#pragma once

#include <cstdint>
#include <map>
#include <memory>
#include <queue>
#include <vector>
#include <list>
#include <string>
#include <functional>
#include <mutex>
#include <thread>
#include <condition_variable>
#include <atomic>

#include "uncopyable.h"

class ThreadPool
{
    DECLARE_UNCOPYABLE(ThreadPool);

public:
    typedef std::function<void ()> TaskHandler;

public:
    explicit ThreadPool(
            int min_num_threads = -1,
            int max_num_threads = -1,
            const std::string& name = "");
    ~ThreadPool();

public:
    // set the minimum number of threads
    void set_min_thread_number(int num_threads);

    // set the maximum number of threads
    void set_max_thread_number(int num_threads);

    // add a task to the pool
    uint64_t add_task(const TaskHandler& callback);

    uint64_t add_task(const TaskHandler& callback, int priority);

    // cancel a pending task by id
    bool cancel_task(uint64_t task_id);

    // terminate the pool
    void terminate();

    void clear_tasks();

    void wait_for_idle();

    void get_stats() const;

private:
    struct Task
    {
        Task(const TaskHandler& entry,
                uint64_t id,
                int priority) :
            on_schedule(entry),
            id(id),
            priority(priority),
            is_canceled(false) {}
        
        void set_schedule_timeout_flag(bool flag)
        {
            is_canceled = flag;
        }

        void set_cancel_flag(bool flag)
        {
            is_canceled = flag;
        }

        bool check_cancel_flag() const
        {
            return is_canceled;
        }

        TaskHandler on_schedule;
        uint64_t id; // task id
        int priority; // task priority, lower is better
        bool is_canceled; // task flag, schedule timeout or canceled
        std::mutex task_lock; // task inner lock
    };

    struct ThreadContext
    {
        ThreadContext() :
            waiting_timer_id(0),
            is_waiting_timeout(false) {}

        void set_waiting_timeout_flag(bool flag)
        {
            is_waiting_timeout = flag;
        }

        bool check_waiting_timeout_flag() const
        {
            return is_waiting_timeout;
        }

        std::shared_ptr<std::thread> thread;
        uint64_t waiting_timer_id;
        bool is_waiting_timeout; 
    };

    typedef std::map<uint64_t, std::shared_ptr<Task> > TaskMap;
    typedef std::list<ThreadContext*> ThreadList;

private:
    void work(ThreadContext* thread);

private:
    /// @brief auto generate a new task id
    uint64_t new_task_id();

    uint64_t add_task_internal(const TaskHandler& callback,
            int priority);

    bool dequeue_task_in_lock(std::shared_ptr<Task>& task);

    bool need_new_thread() const;

    bool need_shrink_thread() const;

    // expand a thread into pool
    void expand_thread();

private:
    struct TaskCompare
    {
        bool operator()(const std::shared_ptr<Task>& a,
                const std::shared_ptr<Task>& b)
        {
            return a->priority > b->priority;
        }
    };

private:
    int m_min_num_threads;
    int m_max_num_threads;

    // current threads number
    std::atomic<int>    m_num_threads;

    // current on-busy threads number
    std::atomic<int>  m_num_busy_threads;

    // tasks container
    TaskMap m_tasks;

    // exit flags
    volatile bool m_exit;

    mutable std::mutex m_lock;

    // all threads are free now
    std::condition_variable m_exit_cond;

    // tasks been requested
    std::condition_variable m_task_cond;

    // not suitable for std::vector since uncopyable structures
    // list with m_num_threads elements running with each work thread routine
    ThreadList m_threads;

    // task queue; a min-heap on priority, so tasks with lower priority values run first
    std::priority_queue<
        std::shared_ptr<Task>,
        std::vector<std::shared_ptr<Task> >,
        TaskCompare> m_task_queue;
};
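Before the implementation, a minimal usage sketch of my own, based on the interface above (the full test program in section 5 exercises the same calls):

#include <cstdint>
#include <iostream>

#include "threadpool.h"

int main() {
    ThreadPool pool;  // defaults: min = CPU cores, max = 2 * min

    // lower priority value runs first (see TaskCompare)
    pool.add_task([] { std::cout << "urgent task" << std::endl; }, 0);
    uint64_t id = pool.add_task([] { std::cout << "normal task" << std::endl; });  // default priority 10

    pool.cancel_task(id);   // best effort: only marks the task if it is still queued
    pool.wait_for_idle();   // poll until the queue is empty and no thread is busy
    pool.terminate();       // join all worker threads
    return 0;
}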

threadpool.cpp

#include <cstddef>
#include <iostream>
#include <string>
#include <atomic>
#include <chrono>

#include "threadpool.h"
#include "sysinfo.h"

using namespace std;

static uint64_t s_thread_name_index = 0;

ThreadPool::ThreadPool(
        int min_num_threads, int max_num_threads, const std::string& name) :
    m_min_num_threads(0),
    m_max_num_threads(0),
    m_num_threads(0),
    m_num_busy_threads(0),
    m_exit(false)
{
    if (min_num_threads <= 0) {
        m_min_num_threads = SysInfo::get_logical_cpu_number();
    } else {
        m_min_num_threads = min_num_threads;
    }
    if (max_num_threads < m_min_num_threads) {
        m_max_num_threads = 2 * m_min_num_threads;
    } else {
        m_max_num_threads = max_num_threads;
    }

    std::unique_lock<std::mutex> lock(m_lock);
    for (int i = 0; i < m_min_num_threads; i++) {
        ThreadContext* thread = new ThreadContext();
        thread->thread.reset(
                new std::thread(std::bind(
                        &ThreadPool::work, this, thread)));
        m_threads.push_back(thread);
        m_num_threads++;
    }
    s_thread_name_index += m_min_num_threads;

}

ThreadPool::~ThreadPool()
{
    terminate();
}

void ThreadPool::terminate()
{
    std::cout << __LINE__ << " " << __FUNCTION__ << std::endl;
    if (m_exit) {
        return;
    }

    {
        std::unique_lock<std::mutex> lock(m_lock);
        m_exit = true;

        // wake every waiting thread so it can observe m_exit and quit
        m_task_cond.notify_all();

        // wait until all busy threads exited
        while (m_num_busy_threads > 0) {
            m_exit_cond.wait(lock);
        }
    }

    for (auto it = m_threads.begin(); it != m_threads.end(); it++) {
        if ((*it)->thread->joinable()) {
            (*it)->thread->join();
        }
    }


    // threads clear
    while (!m_threads.empty()) {
        ThreadContext* thread = m_threads.front();
        m_threads.pop_front();
        thread->thread.reset();
        delete thread; 
    }

    m_num_threads = 0;
    m_num_busy_threads = 0;

    // tasks clear
    clear_tasks();
}

void ThreadPool::clear_tasks()
{
    std::unique_lock<std::mutex> lock(m_lock);
    while (!m_task_queue.empty()) {
        m_task_queue.pop();
    }
}

void ThreadPool::wait_for_idle()
{
    if (m_exit) {
        return;
    }

    for (;;) {
        {
            std::unique_lock<std::mutex> lock(m_lock);
            if (m_task_queue.empty() && m_num_busy_threads == 0) {
                return;
            }
        }
        // poll once per second
        this_thread::sleep_for(chrono::seconds(1));
    }
}

void ThreadPool::set_min_thread_number(int num_threads)
{
    if (m_exit) {
        return;
    }
    if (num_threads <= 0) {
        m_min_num_threads = SysInfo::get_logical_cpu_number();
    } else {
        m_min_num_threads = num_threads;
    }
}

void ThreadPool::set_max_thread_number(int num_threads)
{
    if (m_exit) {
        return;
    }
    if (num_threads < m_min_num_threads) {
        m_max_num_threads = 2 * m_min_num_threads;
    } else {
        m_max_num_threads = num_threads;
    }
}

uint64_t ThreadPool::add_task(const TaskHandler& callback)
{
    return add_task_internal(callback, 10);
}

uint64_t ThreadPool::add_task(const TaskHandler& callback, int priority)
{
    return add_task_internal(callback, priority);
}

bool ThreadPool::cancel_task(uint64_t task_id)
{
    std::unique_lock<std::mutex> lock(m_lock);
    TaskMap::iterator it = m_tasks.find(task_id);
    if (it != m_tasks.end()) {
        it->second->set_cancel_flag(true);
        return true;
    } else {
        return false;
    }
}

void ThreadPool::get_stats() const
{
    std::unique_lock<std::mutex> lock(m_lock);
    std::cout << "######## ThreadPool Stats ################" << std::endl;
    std::cout << "m_min_num_threads:" << m_min_num_threads << std::endl;
    std::cout << "m_max_num_threads:" << m_max_num_threads << std::endl;
    std::cout << "m_num_threads:" << m_num_threads << std::endl;
    std::cout << "m_num_busy_threads:" << m_num_busy_threads << std::endl;
    std::cout << "m_threads size:" << m_threads.size() << std::endl;
    std::cout << "##########################################" << std::endl;
}


// working threads logic
void ThreadPool::work(ThreadContext* thread)
{
    m_num_busy_threads++;
    for (;;) {
        std::shared_ptr<Task> task;
        {
            std::unique_lock<std::mutex> lock(m_lock);
            if (m_exit || thread->check_waiting_timeout_flag()) {
                break;
            }

            if (!dequeue_task_in_lock(task)) {
                m_num_busy_threads--;
                m_task_cond.wait(lock);
                m_num_busy_threads++;
                continue;
            }
        }

        // skip empty or canceled tasks
        if (!task || task->check_cancel_flag()) {
            continue;
        }

        // execute task
        if (task->on_schedule) {
            task->on_schedule();
        }
    }

    // thread is exiting: update counters and wake terminate() if this is the last busy thread
    {
        std::unique_lock<std::mutex> lock(m_lock);
        m_num_threads--;
        m_num_busy_threads--;
        if (m_num_busy_threads == 0) {
            m_exit_cond.notify_all();
        }
    }
}

/// private methods
static std::atomic<size_t> s_task_id(0);

uint64_t ThreadPool::new_task_id()
{
    return static_cast<uint64_t>(++s_task_id);
}

uint64_t ThreadPool::add_task_internal(const TaskHandler& callback,
        int priority)
{
    if (m_exit) {
        return 0;
    }

    uint64_t id = new_task_id();
    std::shared_ptr<Task> task(new Task(callback, id, priority));

    {
        std::unique_lock<std::mutex> lock(m_lock);
        // check whether need expand threads
        if (need_new_thread()) {
            expand_thread(); 
        }
        
        // add to task map
        m_tasks[id] = task;

        // push into priority task queue
        m_task_queue.push(task);
        m_task_cond.notify_all();
    }
    
    return id;
}

bool ThreadPool::need_new_thread() const
{
    if (m_num_threads >= m_max_num_threads) {
        return false;
    }
    if (m_num_threads < m_min_num_threads ||
            m_num_threads == m_num_busy_threads) {
        return true;
    }
    return false;
}

bool ThreadPool::need_shrink_thread() const
{
    if (m_num_threads > m_min_num_threads) {
        return true;
    }
    return false;
}

void ThreadPool::expand_thread()
{
    ThreadContext* thread = new ThreadContext();
    // owned by ThreadContext::thread through a shared_ptr
    thread->thread.reset(
            new std::thread(std::bind(
                    &ThreadPool::work, this, thread)));
    // add the new thread to the thread list
    m_threads.push_back(thread);
    m_num_threads++;
}

bool ThreadPool::dequeue_task_in_lock(std::shared_ptr<Task>& task)
{
    if (m_task_queue.empty()) {
        return false;
    }

    task = m_task_queue.top();
    // remove from task queue
    m_task_queue.pop();

    // remove task map
    m_tasks.erase(task->id);

    return true;
}

5. Test code

#include <iostream>
#include <ctime>
#include <cstdlib>
#include <chrono>
#include "threadpool.h"

using namespace std;

class HttpClient {
public:
    static HttpClient* getInstance() {
        static HttpClient s_instance;
        return &s_instance;
    }
    ~HttpClient() {
    }

    bool ConnServer() {
        int iReqId = 1;
        int timeout_in_ms = 100;
        int call_timeout_in_ms = 100;

        auto execute_fun = [this, iReqId, timeout_in_ms]() {
            std::cout << "HttpClient:: begin " << iReqId << " " << timeout_in_ms << std::endl;
            this->UseTimeFun();
            std::cout << "HttpClient:: end" << std::endl;
            m_thread_pool.get_stats();
        };

        iReqId++;
        auto execute_fun1 = [this, iReqId, timeout_in_ms]() {
            std::cout << "HttpClient:: begin " << iReqId << " " << timeout_in_ms << std::endl;
            this->UseTimeFun();
            std::cout << "HttpClient:: end" << std::endl;
            m_thread_pool.get_stats();
        };

        m_thread_pool.add_task(execute_fun, 1);
        m_thread_pool.add_task(execute_fun1, 0);

        return true;
    }

private:
    void UseTimeFun() {
        srand((unsigned)time(NULL));
        this_thread::sleep_for(chrono::seconds(rand()%10));
    }

private:
    ThreadPool m_thread_pool;
};

class DBClient {
public:
    static DBClient* getInstance() {
        static DBClient s_instance;
        return &s_instance;
    }
    ~DBClient() {
    }

    bool ConnServer() {
        std::string uname = "mysql";
        std::string passwd = "mysql";
        int timeout_in_ms = 100;
        int call_timeout_in_ms = 100;

        auto execute_fun = [this, uname, passwd]() {
            std::cout << "DBClient:: begin " << uname << " " << passwd << std::endl;
            this->OperFun();
            std::cout << "DBClient:: end " << std::endl;
            m_thread_pool.get_stats();
        };

        m_thread_pool.add_task(execute_fun, 1);
        return true;
    }

private:
    void OperFun() {
        srand((unsigned)time(NULL));
        this_thread::sleep_for(chrono::seconds(rand()%10));
    }
    
private:
    ThreadPool  m_thread_pool;
};


int main(int argc, char* argv[]) {
    HttpClient::getInstance()->ConnServer();
    DBClient::getInstance()->ConnServer();
    getchar();
    return 0;
}

Makefile

app: useThreadPool

# Note: $^ expands to all the prerequisites
useThreadPool: main.cpp threadpool.cpp sysinfo.cpp
	g++ -g -std=c++11 $^ -o useThreadPool -lpthread

clean:
	-rm useThreadPool -f

6. Analysis of the results

    The functions handed to the thread pool are the time-consuming ones: the test program uses sleep to simulate an HTTP request and a database operation, and submits both to thread pools for execution. The run shows that on my 12-core machine the default minimum and maximum thread counts are 12 and 24. The pool starts the minimum number of worker threads, and the busy-thread count is initially 0. After HttpClient submits the time-consuming tasks execute_fun and execute_fun1, the busy count rises to 2; DBClient then submits its own time-consuming task to its pool in the same way.

Reprinted from blog.csdn.net/hsy12342611/article/details/128509200