推荐一个C++线程池实现

Table of Contents

    A simple C++11 Thread Pool implementation

    使用起来非常方便:

    // create thread pool with 4 worker threads
    ThreadPool pool(4);
    
    // enqueue and store future
    auto result = pool.enqueue([](int answer) { return answer; }, 42);
    
    // get result from future
    std::cout << result.get() << std::endl;
    
    

    如果中间要等待任务完全执行完成使用下面两个函数:

    pool.wait_until_empty();  
    pool.wait_until_nothing_in_flight ();
    

    源码:

    // -*- C++ -*-
    // Copyright (c) 2012-2015 Jakob Progsch
    //
    // This software is provided 'as-is', without any express or implied
    // warranty. In no event will the authors be held liable for any damages
    // arising from the use of this software.
    //
    // Permission is granted to anyone to use this software for any purpose,
    // including commercial applications, and to alter it and redistribute it
    // freely, subject to the following restrictions:
    //
    //    1. The origin of this software must not be misrepresented; you must not
    //    claim that you wrote the original software. If you use this software
    //    in a product, an acknowledgment in the product documentation would be
    //    appreciated but is not required.
    //
    //    2. Altered source versions must be plainly marked as such, and must not be
    //    misrepresented as being the original software.
    //
    //    3. This notice may not be removed or altered from any source
    //    distribution.
    //
    // Modified for log4cplus, copyright (c) 2014-2015 Václav Zeman.
    
    #ifndef THREAD_POOL_H_7ea1ee6b_4f17_4c09_b76b_3d44e102400c
    #define THREAD_POOL_H_7ea1ee6b_4f17_4c09_b76b_3d44e102400c
    
    #include <vector>
    #include <queue>
    #include <memory>
    #include <thread>
    #include <mutex>
    #include <condition_variable>
    #include <future>
    #include <atomic>
    #include <functional>
    #include <stdexcept>
    #include <algorithm>
    #include <cassert>
    
    
    namespace progschj {
    
    class would_block
        : public std::runtime_error
    {
        using std::runtime_error::runtime_error;
    };
    
    
    class ThreadPool {
    public:
        template <typename F, typename... Args>
        using return_type =
    #if defined(__cpp_lib_is_invocable) && __cpp_lib_is_invocable >= 201703L
                typename std::invoke_result<F&&, Args&&...>::type;
    #else
                typename std::result_of<F&& (Args&&...)>::type;
    #endif
    
        explicit ThreadPool(std::size_t threads
            = (std::max)(2u, std::thread::hardware_concurrency()));
        template <typename F, typename... Args>
        auto enqueue_block(F&& f, Args&&... args) -> std::future<return_type<F, Args...>>;
        template <typename F, typename... Args>
        auto enqueue(F&& f, Args&&... args) -> std::future<return_type<F, Args...>>;
        void wait_until_empty();
        void wait_until_nothing_in_flight();
        void set_queue_size_limit(std::size_t limit);
        void set_pool_size(std::size_t limit);
        ~ThreadPool();
    
    private:
        void start_worker(std::size_t worker_number,
            std::unique_lock<std::mutex> const &lock);
    
        template <typename F, typename... Args>
        auto enqueue_worker(bool, F&& f, Args&&... args) -> std::future<return_type<F, Args...>>;
    
        template <typename T>
        static std::future<T> make_exception_future (std::exception_ptr ex_ptr);
    
        // need to keep track of threads so we can join them
        std::vector< std::thread > workers;
        // target pool size
        std::size_t pool_size;
        // the task queue
        std::queue< std::function<void()> > tasks;
        // queue length limit
        std::size_t max_queue_size = 100000;
        // stop signal
        bool stop = false;
    
        // synchronization
        std::mutex queue_mutex;
        std::condition_variable condition_producers;
        std::condition_variable condition_consumers;
    
        std::mutex in_flight_mutex;
        std::condition_variable in_flight_condition;
        std::atomic<std::size_t> in_flight;
    
        struct handle_in_flight_decrement
        {
            ThreadPool & tp;
    
            handle_in_flight_decrement(ThreadPool & tp_)
                : tp(tp_)
            { }
    
            ~handle_in_flight_decrement()
            {
                std::size_t prev
                    = std::atomic_fetch_sub_explicit(&tp.in_flight,
                        std::size_t(1),
                        std::memory_order_acq_rel);
                if (prev == 1)
                {
                    std::unique_lock<std::mutex> guard(tp.in_flight_mutex);
                    tp.in_flight_condition.notify_all();
                }
            }
        };
    };
    
    // the constructor just launches some amount of workers
    inline ThreadPool::ThreadPool(std::size_t threads)
        : pool_size(threads)
        , in_flight(0)
    {
        std::unique_lock<std::mutex> lock(this->queue_mutex);
        for (std::size_t i = 0; i != threads; ++i)
            start_worker(i, lock);
    }
    
    // add new work item to the pool and block if the queue is full
    template<class F, class... Args>
    auto ThreadPool::enqueue_block(F&& f, Args&&... args) -> std::future<return_type<F, Args...>>
    {
        return enqueue_worker (true, std::forward<F> (f), std::forward<Args> (args)...);
    }
    
    // add new work item to the pool and return future with would_block exception if it is full
    template<class F, class... Args>
    auto ThreadPool::enqueue(F&& f, Args&&... args) -> std::future<return_type<F, Args...>>
    {
        return enqueue_worker (false, std::forward<F> (f), std::forward<Args> (args)...);
    }
    
    template <typename F, typename... Args>
    auto ThreadPool::enqueue_worker(bool block, F&& f, Args&&... args) -> std::future<return_type<F, Args...>>
    {
        auto task = std::make_shared< std::packaged_task<return_type<F, Args...>()> >(
                std::bind(std::forward<F>(f), std::forward<Args>(args)...)
            );
    
        std::future<return_type<F, Args...>> res = task->get_future();
    
        std::unique_lock<std::mutex> lock(queue_mutex);
    
        if (tasks.size () >= max_queue_size)
        {
            if (block)
            {
                // wait for the queue to empty or be stopped
                condition_producers.wait(lock,
                    [this]
                    {
                        return tasks.size () < max_queue_size
                            || stop;
                    });
            }
            else
            {
                return ThreadPool::make_exception_future<return_type<F, Args...>> (
                    std::make_exception_ptr (would_block("queue full")));
            }
        }
    
    
        // don't allow enqueueing after stopping the pool
        if (stop)
            throw std::runtime_error("enqueue on stopped ThreadPool");
    
        tasks.emplace([task](){ (*task)(); });
        std::atomic_fetch_add_explicit(&in_flight,
            std::size_t(1),
            std::memory_order_relaxed);
        condition_consumers.notify_one();
    
        return res;
    }
    
    // the destructor joins all threads
    inline ThreadPool::~ThreadPool()
    {
        std::unique_lock<std::mutex> lock(queue_mutex);
        stop = true;
        pool_size = 0;
        condition_consumers.notify_all();
        condition_producers.notify_all();
        condition_consumers.wait(lock, [this]{ return this->workers.empty(); });
        assert(in_flight == 0);
    }
    
    inline void ThreadPool::wait_until_empty()
    {
        std::unique_lock<std::mutex> lock(this->queue_mutex);
        this->condition_producers.wait(lock,
            [this]{ return this->tasks.empty(); });
    }
    
    inline void ThreadPool::wait_until_nothing_in_flight()
    {
        std::unique_lock<std::mutex> lock(this->in_flight_mutex);
        this->in_flight_condition.wait(lock,
            [this]{ return this->in_flight == 0; });
    }
    
    inline void ThreadPool::set_queue_size_limit(std::size_t limit)
    {
        std::unique_lock<std::mutex> lock(this->queue_mutex);
    
        if (stop)
            return;
    
        std::size_t const old_limit = max_queue_size;
        max_queue_size = (std::max)(limit, std::size_t(1));
        if (old_limit < max_queue_size)
            condition_producers.notify_all();
    }
    
    inline void ThreadPool::set_pool_size(std::size_t limit)
    {
        if (limit < 1)
            limit = 1;
    
        std::unique_lock<std::mutex> lock(this->queue_mutex);
    
        if (stop)
            return;
    
        std::size_t const old_size = pool_size;
        assert(this->workers.size() >= old_size);
    
        pool_size = limit;
        if (pool_size > old_size)
        {
            // create new worker threads
            // it is possible that some of these are still running because
            // they have not stopped yet after a pool size reduction, such
            // workers will just keep running
            for (std::size_t i = old_size; i != pool_size; ++i)
                start_worker(i, lock);
        }
        else if (pool_size < old_size)
            // notify all worker threads to start downsizing
            this->condition_consumers.notify_all();
    }
    
    inline void ThreadPool::start_worker(
        std::size_t worker_number, std::unique_lock<std::mutex> const &lock)
    {
        assert(lock.owns_lock() && lock.mutex() == &this->queue_mutex);
        assert(worker_number <= this->workers.size());
    
        auto worker_func =
            [this, worker_number]
            {
                for(;;)
                {
                    std::function<void()> task;
                    bool notify;
    
                    {
                        std::unique_lock<std::mutex> lock(this->queue_mutex);
                        this->condition_consumers.wait(lock,
                            [this, worker_number]{
                                return this->stop || !this->tasks.empty()
                                    || pool_size < worker_number + 1; });
    
                        // deal with downsizing of thread pool or shutdown
                        if ((this->stop && this->tasks.empty())
                            || (!this->stop && pool_size < worker_number + 1))
                        {
                            // detach this worker, effectively marking it stopped
                            this->workers[worker_number].detach();
                            // downsize the workers vector as much as possible
                            while (this->workers.size() > pool_size
                                 && !this->workers.back().joinable())
                                this->workers.pop_back();
                            // if this is was last worker, notify the destructor
                            if (this->workers.empty())
                                this->condition_consumers.notify_all();
                            return;
                        }
                        else if (!this->tasks.empty())
                        {
                            task = std::move(this->tasks.front());
                            this->tasks.pop();
                            notify = this->tasks.size() + 1 ==  max_queue_size
                                || this->tasks.empty();
                        }
                        else
                            continue;
                    }
    
                    handle_in_flight_decrement guard(*this);
    
                    if (notify)
                    {
                        std::unique_lock<std::mutex> lock(this->queue_mutex);
                        condition_producers.notify_all();
                    }
    
                    task();
                }
            };
    
        if (worker_number < this->workers.size()) {
            std::thread & worker = this->workers[worker_number];
            // start only if not already running
            if (!worker.joinable()) {
                worker = std::thread(worker_func);
            }
        } else
            this->workers.push_back(std::thread(worker_func));
    }
    
    template <typename T>
    inline std::future<T> ThreadPool::make_exception_future (std::exception_ptr ex_ptr)
    {
        std::promise<T> p;
        p.set_exception (ex_ptr);
        return p.get_future ();
    }
    
    } // namespace progschj
    
    #endif // THREAD_POOL_H_7ea1ee6b_4f17_4c09_b76b_3d44e102400c