Skip to content

多线程锁key处理

Published: at 02:38 AM | 11 min read

目的是要保证多个线程处理同一个key的消息时是线程安全的,同时是顺序的。

  1. 收到网络过来的消息,将消息缓存到total msg cache里
  2. 解析业务消息找到每条数据的key(可能一条或多条)存储在set中
  3. 将所有key进行hash,对最大线程数进行取余,获取所有线程索引std::set indexList
  4. 如果消息没有index即没有key,就轮询分发给otherThread处理
  5. 有index的消息,根据index分发给线程,如果一个消息有多个index,对应的一个消息会分发给多个线程(keyThread)
  6. 每个keyThread不断的从线程队列里取出消息ID来处理(线程队列只保留消息ID)
  7. 根据消息ID到total msg cache里去查找消息
  8. 如果没有找到,说明已经被其他线程消费掉了(因为一个消息有多个key的话会分配给多个thread)
  9. 如果找到了,递减消息的索引数,索引数大于0就等待,索引数为0就处理消息然后从total msg cache中删除消息(一个消息只能被处理一次),最后通知其他线程
  10. 其他线程收到通知后再次去cache查找消息重复步骤7,直到消息被处理掉
  11. 最后每个线程不断重复从线程队列取消息

关键点: 如果一个消息被分配给了n个线程那它的索引数就是n,为了保证有相同key的消息的处理顺序,当找到了索引数大于0的消息时当前线程必须等待,要保证其他线程中当前消息之前的消息都被处理完也就是说索引递减为0了才能处理。

如果一个消息只有一个key,对key进行hash将消息分配到指定的线程就能保证相同的key不会被多个线程处理,同时是顺序的

如果一个消息有多个key,递减索引数以及等待索引数为0,能保证消息的顺序执行

keysync 代码:

// .h
typedef std::shared_ptr<Message> MessagePtr;

class Service;
// 消息简单包装
class MessagePack {
public:
    MessagePack();
    MessagePack(Service *service, const MessagePtr &msg, bool ignoreLog, const std::set<int> &indexs);
    MessagePack(const MessagePack&);
    MessagePack& operator=(const MessagePack&);

    // 是否打印日志
    bool ignoreLog() const;
    // 索引列表
    const std::set<int>& indexList() const;
    // 是否是多索引
    bool isMultiIndex() const;
    // 递减索引数并返回递减后的索引数
    int indexCountSub();
    // 当前索引数
    int indexCount() const;
    // 原始消息
    MessagePtr msg() const;
    // 消息ID
    const std::string& id() const;
    // 处理消息
    void handle();
    // 收到构建时的时间
    time_t msStart() const;

private:
    bool ignoreLog_;
    MessagePtr msg_;
    std::set<int> indexs_;
    std::string id_;
    boost::atomic_int  indexCount_;
    Service *service_;
    time_t msStart_;
};

//////////////////////////////////////////////////////////////////////////
// 网络消息缓存
class NetMessageCache {
public:
    NetMessageCache();
    // 将消息新增到缓存
    void add(const MessagePack &msg);
    // 通过消息ID从缓存获取消息,并且可能会等待
    bool getAndWait(const std::string &id, MessagePack &msg);
    // 从缓存中删除消息并通知等待方
    void delAndNotify(const std::string &id);
    // 打印缓存日志
    void printLog() const;

private:
    boost::mutex mutex_;
    boost::condition_variable cond_;
    std::map<std::string, MessagePack> allMsg_;

private:
    NetMessageCache(const NetMessageCache&);
    NetMessageCache& operator=(const NetMessageCache&);
};

//////////////////////////////////////////////////////////////////////////
// 线程工作者
// 两种模式:第一种是通过msgCache处理消息ID(保证相同key消息顺序执行),第二种直接处理消息(效率高)
// key线程使用第一种模式,other线程使用第二种模式
// 一个实例只能使用一种模式,不能混用
class Worker {
public:
    Worker(int i, NetMessageCache *msgCache);
    int index() const;
    void addMsgId(const std::string &id);
    void addMsgPack(const MessagePack &msg);
    void join();
    boost::thread::id threadId() const;

    void onHandle();
    void handleMsgIdQueue();
    void handleMsgPackQueue();

private:
    Worker(const Worker&);
    Worker& operator=(const Worker&);

    int index_;
    bool running_;
    std::queue<std::string> msgIdQueue_;
    std::queue<MessagePack> msgPackQueue_;
    boost::mutex mutex_;
    boost::condition_variable cond_;
    NetMessageCache* msgCache_;
    boost::thread thread_;
};

//////////////////////////////////////////////////////////////////////////
// 工作调度中心
// 不同的消息分发给不同的worker
class WorkerCenter
{
public:
    static WorkerCenter* GetInstance();
    // 初始化线程池
    void init(int keyThreadCount, int otherThreadCount);
    // 将消息分发到不同的线程去处理
    void dispatchMessage(const MessagePack &msg);
    // key线程数
    int keyThreadCount() const { return keyWorkerList_.size(); }
    // other线程数
    int otherThreadCount() const { return otherWorkerList_.size(); }

private:
    WorkerCenter();
    ~WorkerCenter();
    WorkerCenter(const WorkerCenter&);
    WorkerCenter& operator=(const WorkerCenter&);

    std::vector<std::shared_ptr<Worker> > keyWorkerList_;
    std::vector<std::shared_ptr<Worker> > otherWorkerList_;
    std::unique_ptr<NetMessageCache> msgCache_;
    boost::atomic_int otherWorkerIndex_;
};
// .cpp
using namespace boost::gregorian;
using namespace boost::posix_time;
boost::hash<std::string> str_hash;

MessagePack::MessagePack()
    : ignoreLog_(false), service_(NULL)
{
    indexCount_.store(0);
    msStart_ = GetCurrentTimeMilliSec();
}

MessagePack::MessagePack(Service *service, const MessagePtr &msg, bool ignoreLog, const std::set<int> &indexs)
    : ignoreLog_(ignoreLog), service_(service), msg_(msg), indexs_(indexs)
{
    id_ = msg_->GetMessageID();
    indexCount_.store(indexs_.size());
    msStart_ = GetCurrentTimeMilliSec();
}

MessagePack::MessagePack(const MessagePack& other)
{
    ignoreLog_ = other.ignoreLog_;
    msg_ = other.msg_;
    indexs_ = other.indexs_;
    id_ = other.id_;
    indexCount_.store(other.indexCount_);
    service_ = other.service_;
    msStart_ = other.msStart_;
}

MessagePack& MessagePack::operator = (const MessagePack& other)
{
    if (this == &other) {
        return *this;
    }
    ignoreLog_ = other.ignoreLog_;
    msg_ = other.msg_;
    indexs_ = other.indexs_;
    id_ = other.id_;
    indexCount_.store(other.indexCount_);
    service_ = other.service_;
    msStart_ = other.msStart_;
    return *this;
}

bool MessagePack::ignoreLog() const
{
    return ignoreLog_;
}

const std::set<int>& MessagePack::indexList() const
{
    return indexs_;
}

bool MessagePack::isMultiIndex() const
{
    return indexs_.size() > 1;
}

int MessagePack::indexCountSub()
{
    return --indexCount_;
}

int MessagePack::indexCount() const
{
    return indexCount_.load();
}

MessagePtr MessagePack::msg() const
{
    return msg_;
}

const std::string& MessagePack::id() const
{
    return id_;
}

void MessagePack::handle()
{
    if (service_) {
        service_->processReqMsg(msg_.get());
        if (!ignoreLog_) {
            LOGGER_INFO("message pack handle finished, type:" << msg_->GetType().c_str() << ", id:" << id_ <<", cost:" << (GetCurrentTimeMilliSec() - msStart_) << "ms");
        }
    } else {
        LOGGER_ERROR("SERVICE ERROR");
    }
}

time_t MessagePack::msStart() const
{
    return msStart_;
}
//////////////////////////////////////////////////////////////////////////

NetMessageCache::NetMessageCache()
{
}

void NetMessageCache::add(const MessagePack &msg)
{
    boost::unique_lock<boost::mutex> lock(mutex_);
    allMsg_[msg.id()] = msg;
    if (allMsg_.size() > 1000) {
        LOGGER_WARN("NetMessageCache::add service busy, msgCacheSize:" << allMsg_.size());
    }
}

bool NetMessageCache::getAndWait(const std::string &id, MessagePack &msg)
{
    const int WAIT_TIMEOUT_SECOND = 60;
    boost::unique_lock<boost::mutex> lock(mutex_);
    bool isFirst = true;
    bool isExist = true;
    while (isExist) {
        std::map<std::string, MessagePack>::iterator iter = allMsg_.find(id);
        if (iter == allMsg_.end()) {
            isExist = false;
            continue;
        }

        // 不是多index,直接返回然后删除
        if (!iter->second.isMultiIndex()) {
            msg = iter->second;
            allMsg_.erase(iter);
            return true;
        }

        int count = iter->second.indexCount();
        if (isFirst) {
            // 只考虑第一次进入时,因为这里会执行多次,会造成index count多次递减
            // 如果index count递减后为0说明其他线程前面的消息已经执行完成了
            // 保证有相同key的消息是顺序执行的
            isFirst = false;
            count = iter->second.indexCountSub();
            if (count == 0) {
                msg = iter->second;
                return true;
            }
        }
        // count大于0时要等待其他线程执行完成
        time_t waitStartTime = GetCurrentTimeMilliSec();
        cond_.timed_wait(lock, boost::get_system_time() + boost::posix_time::seconds(WAIT_TIMEOUT_SECOND));
        if (GetCurrentTimeMilliSec() - waitStartTime > (WAIT_TIMEOUT_SECOND - 1) * 1000) {
            LOGGER_ERROR("NetMessageCache wait timeout " << WAIT_TIMEOUT_SECOND << ", id:" << id << ", count:" << count);
            printLog();
        }
    }
    return false;
}

void NetMessageCache::delAndNotify(const std::string &id)
{
    boost::unique_lock<boost::mutex> lock(mutex_);
    std::map<std::string, MessagePack>::iterator iter = allMsg_.find(id);
    if (iter != allMsg_.end()) {
        // 能够删除的消息key count应该都是0
        if (iter->second.indexCount() != 0) {
            LOGGER_ERROR("NetMessageCache ERROR ref count:" << iter->second.indexCount());
        }
        allMsg_.erase(iter);
        cond_.notify_all();
    }
}

void NetMessageCache::printLog() const
{
    std::map<std::string, MessagePack>::const_iterator iter = allMsg_.cbegin();
    while (iter != allMsg_.cend()) {
        const MessagePack &msg = iter->second;
        LOGGER_INFO("NetMessageCache msgid:" << iter->first << ", indexCount:" << msg.indexList().size()
            << ", curIndexCount:" << msg.indexCount() << ", cost:" << (GetCurrentTimeMilliSec() - msg.msStart()) << "ms");
        ++iter;
    }
    LOGGER_INFO("-----------------------------------------");
}
//////////////////////////////////////////////////////////////////////////

Worker::Worker(int i, NetMessageCache *msgCache)
    : index_(i), running_(true), msgCache_(msgCache)
    , thread_(boost::bind(&Worker::onHandle, this))
{
}

int Worker::index() const
{
    return index_;
}

void Worker::addMsgId(const std::string &id)
{
    boost::unique_lock<boost::mutex> lock(mutex_);
    msgIdQueue_.push(id);
    int count = msgIdQueue_.size();
    cond_.notify_one();
    if (count > 100) {
        LOGGER_WARN("Worker::addMsgId thread busy index:" << index_ << ", threadId:" << threadId() << ", queueSize:" << msgIdQueue_.size());
    }
}

void Worker::addMsgPack(const MessagePack &msg)
{
    boost::unique_lock<boost::mutex> lock(mutex_);
    msgPackQueue_.push(msg);
    int count = msgPackQueue_.size();
    cond_.notify_one();
    if (count > 100) {
        LOGGER_WARN("Worker::addMsgPack thread busy index:" << index_ << ", threadId:" << threadId() << ", queueSize:" << msgPackQueue_.size());
    }
}

void Worker::join() {
    running_ = false;
    if (thread_.joinable()) {
        // 线程可能被阻塞,join会一直等待
        thread_.join();
    }
}

boost::thread::id Worker::threadId() const
{
    return thread_.get_id();
}

void Worker::onHandle()
{
    while (running_) {
        if (msgCache_) {
            handleMsgIdQueue();
        } else {
            handleMsgPackQueue();
        }
    }
}

void Worker::handleMsgIdQueue()
{
    std::string msgId;
    {
        boost::unique_lock<boost::mutex> lock(mutex_);
        while (msgIdQueue_.empty()) {
            cond_.wait(lock);
        }
        msgId = msgIdQueue_.front();
        msgIdQueue_.pop();
    }

    if (!msgId.empty()) {
        MessagePack msg;
        // 消息没有找到被其他线程消费掉了
        if (msgCache_->getAndWait(msgId, msg)) {
            // 消息处理
            msg.handle();
            // 从缓存中删掉消息,通知其他线程消息被消费掉了
            if (msg.isMultiIndex()) {
                msgCache_->delAndNotify(msgId);
            }
        }
    }
}

void Worker::handleMsgPackQueue()
{
    MessagePack msgPack;
    {
        boost::unique_lock<boost::mutex> lock(mutex_);
        while (msgPackQueue_.empty()) {
            cond_.wait(lock);
        }
        msgPack = msgPackQueue_.front();
        msgPackQueue_.pop();
    }

    if (msgPack.msg()) {
        msgPack.handle();
    }
}
//////////////////////////////////////////////////////////////////////////

WorkerCenter::WorkerCenter()
    : msgCache_(new NetMessageCache()), otherWorkerIndex_(0)
{
}

WorkerCenter::~WorkerCenter()
{
    for (std::vector<std::shared_ptr<Worker> >::iterator iter = keyWorkerList_.begin(); iter != keyWorkerList_.end(); ++iter) {
        (*iter)->join();
    }
    for (std::vector<std::shared_ptr<Worker> >::iterator iter = otherWorkerList_.begin(); iter != otherWorkerList_.end(); ++iter) {
        (*iter)->join();
    }
}

WorkerCenter* WorkerCenter::GetInstance()
{
    static WorkerCenter s_inst;
    return &s_inst;
}

void WorkerCenter::init(int keyThreadCount, int otherThreadCount)
{
    LOGGER_INFO("WorkerCenter init, keyThreadCount:" << keyThreadCount << ", otherThreadCount:" << otherThreadCount);
    for (int i = 0; i < keyThreadCount; i++) {
        std::shared_ptr<Worker> worker = std::make_shared<Worker>(i, msgCache_.get());
        keyWorkerList_.push_back(worker);
        LOGGER_INFO("WorkerCenter init, key thread index:" << worker->index() << ", id:" << worker->threadId());
    }

    NetMessageCache *emptyMsgCache = NULL;
    for (int i = 0; i < otherThreadCount; i++) {
        std::shared_ptr<Worker> worker = std::make_shared<Worker>(i, emptyMsgCache);
        otherWorkerList_.push_back(worker);
        LOGGER_INFO("WorkerCenter init, other thread index:" << worker->index() << ", id:" << worker->threadId());
    }
}

void WorkerCenter::dispatchMessage(const MessagePack &msg)
{
    const std::set<int>& indexList = msg.indexList();
    if (indexList.empty()) {
        // 无index的消息
        // 线程轮询处理
        int index = (otherWorkerIndex_++) % otherWorkerList_.size();
        otherWorkerList_[index]->addMsgPack(msg);
        if (!msg.ignoreLog()) {
            LOGGER_INFO("WorkerCenter::dispatchMessage other thread id:" << msg.id() << ", index:" << index);
        }
    } else {
        // 有index的消息
        // 将消息分发到index对应的线程
        msgCache_->add(msg);
        for (std::set<int>::iterator iter = indexList.begin(); iter != indexList.end(); ++iter) {
            keyWorkerList_[*iter]->addMsgId(msg.id());
        }
        if (!msg.ignoreLog()) {
            LOGGER_INFO("WorkerCenter::dispatchMessage key thread id:" << msg.id() << ", index size:" << indexList.size());
        }
    }
}
//////////////////////////////////////////////////////////////////////////

外部调用:WorkerCenter::init, WorkerCenter::dispatchMessage