179 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			179 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			C++
		
	
	
	
| //===----------------------- Queue.h - RPC Queue ------------------*-c++-*-===//
 | |
| //
 | |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 | |
| // See https://llvm.org/LICENSE.txt for license information.
 | |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 | |
| //
 | |
| //===----------------------------------------------------------------------===//
 | |
| 
 | |
| #ifndef LLVM_UNITTESTS_EXECUTIONENGINE_ORC_QUEUECHANNEL_H
 | |
| #define LLVM_UNITTESTS_EXECUTIONENGINE_ORC_QUEUECHANNEL_H
 | |
| 
 | |
#include "llvm/ExecutionEngine/Orc/RPC/RawByteChannel.h"
#include "llvm/Support/Error.h"

#include <atomic>
#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>
#include <queue>
#include <utility>
 | |
| 
 | |
| namespace llvm {
 | |
| 
 | |
| class QueueChannelError : public ErrorInfo<QueueChannelError> {
 | |
| public:
 | |
|   static char ID;
 | |
| };
 | |
| 
 | |
| class QueueChannelClosedError
 | |
|     : public ErrorInfo<QueueChannelClosedError, QueueChannelError> {
 | |
| public:
 | |
|   static char ID;
 | |
|   std::error_code convertToErrorCode() const override {
 | |
|     return inconvertibleErrorCode();
 | |
|   }
 | |
| 
 | |
|   void log(raw_ostream &OS) const override {
 | |
|     OS << "Queue closed";
 | |
|   }
 | |
| };
 | |
| 
 | |
| class Queue : public std::queue<char> {
 | |
| public:
 | |
|   using ErrorInjector = std::function<Error()>;
 | |
| 
 | |
|   Queue()
 | |
|     : ReadError([]() { return Error::success(); }),
 | |
|       WriteError([]() { return Error::success(); }) {}
 | |
| 
 | |
|   Queue(const Queue&) = delete;
 | |
|   Queue& operator=(const Queue&) = delete;
 | |
|   Queue(Queue&&) = delete;
 | |
|   Queue& operator=(Queue&&) = delete;
 | |
| 
 | |
|   std::mutex &getMutex() { return M; }
 | |
|   std::condition_variable &getCondVar() { return CV; }
 | |
|   Error checkReadError() { return ReadError(); }
 | |
|   Error checkWriteError() { return WriteError(); }
 | |
|   void setReadError(ErrorInjector NewReadError) {
 | |
|     {
 | |
|       std::lock_guard<std::mutex> Lock(M);
 | |
|       ReadError = std::move(NewReadError);
 | |
|     }
 | |
|     CV.notify_one();
 | |
|   }
 | |
|   void setWriteError(ErrorInjector NewWriteError) {
 | |
|     std::lock_guard<std::mutex> Lock(M);
 | |
|     WriteError = std::move(NewWriteError);
 | |
|   }
 | |
| private:
 | |
|   std::mutex M;
 | |
|   std::condition_variable CV;
 | |
|   std::function<Error()> ReadError, WriteError;
 | |
| };
 | |
| 
 | |
| class QueueChannel : public orc::rpc::RawByteChannel {
 | |
| public:
 | |
|   QueueChannel(std::shared_ptr<Queue> InQueue,
 | |
|                std::shared_ptr<Queue> OutQueue)
 | |
|       : InQueue(InQueue), OutQueue(OutQueue) {}
 | |
| 
 | |
|   QueueChannel(const QueueChannel&) = delete;
 | |
|   QueueChannel& operator=(const QueueChannel&) = delete;
 | |
|   QueueChannel(QueueChannel&&) = delete;
 | |
|   QueueChannel& operator=(QueueChannel&&) = delete;
 | |
| 
 | |
|   template <typename FunctionIdT, typename SequenceIdT>
 | |
|   Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) {
 | |
|     ++InFlightOutgoingMessages;
 | |
|     return orc::rpc::RawByteChannel::startSendMessage(FnId, SeqNo);
 | |
|   }
 | |
| 
 | |
|   Error endSendMessage() {
 | |
|     --InFlightOutgoingMessages;
 | |
|     ++CompletedOutgoingMessages;
 | |
|     return orc::rpc::RawByteChannel::endSendMessage();
 | |
|   }
 | |
| 
 | |
|   template <typename FunctionIdT, typename SequenceNumberT>
 | |
|   Error startReceiveMessage(FunctionIdT &FnId, SequenceNumberT &SeqNo) {
 | |
|     ++InFlightIncomingMessages;
 | |
|     return orc::rpc::RawByteChannel::startReceiveMessage(FnId, SeqNo);
 | |
|   }
 | |
| 
 | |
|   Error endReceiveMessage() {
 | |
|     --InFlightIncomingMessages;
 | |
|     ++CompletedIncomingMessages;
 | |
|     return orc::rpc::RawByteChannel::endReceiveMessage();
 | |
|   }
 | |
| 
 | |
|   Error readBytes(char *Dst, unsigned Size) override {
 | |
|     std::unique_lock<std::mutex> Lock(InQueue->getMutex());
 | |
|     while (Size) {
 | |
|       {
 | |
|         Error Err = InQueue->checkReadError();
 | |
|         while (!Err && InQueue->empty()) {
 | |
|           InQueue->getCondVar().wait(Lock);
 | |
|           Err = InQueue->checkReadError();
 | |
|         }
 | |
|         if (Err)
 | |
|           return Err;
 | |
|       }
 | |
|       *Dst++ = InQueue->front();
 | |
|       --Size;
 | |
|       ++NumRead;
 | |
|       InQueue->pop();
 | |
|     }
 | |
|     return Error::success();
 | |
|   }
 | |
| 
 | |
|   Error appendBytes(const char *Src, unsigned Size) override {
 | |
|     std::unique_lock<std::mutex> Lock(OutQueue->getMutex());
 | |
|     while (Size--) {
 | |
|       if (Error Err = OutQueue->checkWriteError())
 | |
|         return Err;
 | |
|       OutQueue->push(*Src++);
 | |
|       ++NumWritten;
 | |
|     }
 | |
|     OutQueue->getCondVar().notify_one();
 | |
|     return Error::success();
 | |
|   }
 | |
| 
 | |
|   Error send() override {
 | |
|     ++SendCalls;
 | |
|     return Error::success();
 | |
|   }
 | |
| 
 | |
|   void close() {
 | |
|     auto ChannelClosed = []() { return make_error<QueueChannelClosedError>(); };
 | |
|     InQueue->setReadError(ChannelClosed);
 | |
|     InQueue->setWriteError(ChannelClosed);
 | |
|     OutQueue->setReadError(ChannelClosed);
 | |
|     OutQueue->setWriteError(ChannelClosed);
 | |
|   }
 | |
| 
 | |
|   uint64_t NumWritten = 0;
 | |
|   uint64_t NumRead = 0;
 | |
|   std::atomic<size_t> InFlightIncomingMessages{0};
 | |
|   std::atomic<size_t> CompletedIncomingMessages{0};
 | |
|   std::atomic<size_t> InFlightOutgoingMessages{0};
 | |
|   std::atomic<size_t> CompletedOutgoingMessages{0};
 | |
|   std::atomic<size_t> SendCalls{0};
 | |
| 
 | |
| private:
 | |
| 
 | |
|   std::shared_ptr<Queue> InQueue;
 | |
|   std::shared_ptr<Queue> OutQueue;
 | |
| };
 | |
| 
 | |
| inline std::pair<std::unique_ptr<QueueChannel>, std::unique_ptr<QueueChannel>>
 | |
| createPairedQueueChannels() {
 | |
|   auto Q1 = std::make_shared<Queue>();
 | |
|   auto Q2 = std::make_shared<Queue>();
 | |
|   auto C1 = std::make_unique<QueueChannel>(Q1, Q2);
 | |
|   auto C2 = std::make_unique<QueueChannel>(Q2, Q1);
 | |
|   return std::make_pair(std::move(C1), std::move(C2));
 | |
| }
 | |
| 
 | |
| }
 | |
| 
 | |
| #endif
 |