diff --git a/CMakeLists.txt b/CMakeLists.txt index 9f53b4c..3998dc2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,5 +3,9 @@ cmake_minimum_required (VERSION 3.8) project ("vsock-bridge") set(CMAKE_CXX_FLAGS_DEBUG "-ggdb") +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED True) + +enable_testing () add_subdirectory ("vsock-bridge") diff --git a/README.md b/README.md index 6ce0230..5b51375 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Vsock Proxy -Vsock Proxy to proxy TCP connection to vsock and vise versa. +Vsock Proxy proxies TCP connections to vsock and vice versa. This is intended for UID2 traffic forwarding between host and AWS Nitro Enclaves. @@ -11,6 +11,7 @@ mkdir uid2-aws-enclave-vsockproxy/build cd uid2-aws-enclave-vsockproxy/build cmake .. -DCMAKE_BUILD_TYPE=RelWithDebInfo make +make test ``` ## How to use @@ -24,12 +25,33 @@ http-service: service: direct listen: tcp://0.0.0.0:80 connect: vsock://42:8080 + +sockx-proxy: + service: direct + listen: vsock://3:3305 + connect: tcp://127.0.0.1:3305 + +tcp-to-tcp: + service: direct + listen: tcp://127.0.0.1:4000 + connect: tcp://10.10.10.10:4001 ``` -Start vsockpx +This configuration file instructs the proxy to: + - listen on all IPv4 addresses on TCP port 80 and forward connections to vsock address 42:8080; + - listen on vsock address 3:3305 and forward connections to localhost (IPv4) TCP port 3305; + - listen on localhost (IPv4) TCP port 4000 and forward connections to 10.10.10.10 TCP port 4001. + +Start vsock-bridge: ``` -./vsockpx --config config.notyaml +./vsock-bridge --config config.notyaml ``` -Traffic hitting host:80 port will be forwarded to vsock address 42:8080. +Run `./vsock-bridge -h` for details on the other supported command line options. + +## Logging + +In daemon mode the proxy logs to syslog (with ident `vsockpx`). In foreground mode logs go to stdout. + +The log level can be configured via the `--log-level` command line option.
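+
+For example (using the options listed by `./vsock-bridge -h`), to run in daemon mode with two IO worker threads and warning-level logging:
+
+```
+./vsock-bridge --config config.notyaml --daemon --log-level 2 --workers 2
+```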
diff --git a/vsock-bridge/CMakeLists.txt b/vsock-bridge/CMakeLists.txt index e606ae6..45b24ef 100644 --- a/vsock-bridge/CMakeLists.txt +++ b/vsock-bridge/CMakeLists.txt @@ -4,6 +4,4 @@ cmake_minimum_required (VERSION 3.8) add_subdirectory (src) - -enable_testing () add_subdirectory (test) \ No newline at end of file diff --git a/vsock-bridge/include/buffer.h b/vsock-bridge/include/buffer.h index 02d1aa0..d4151d8 100644 --- a/vsock-bridge/include/buffer.h +++ b/vsock-bridge/include/buffer.h @@ -1,222 +1,64 @@ #pragma once -#include "logger.h" - -#include +#include #include #include -#include -#include -#include -#include - -#include namespace vsockio { - struct MemoryBlock - { - MemoryBlock(int size, class MemoryArena* region) - : _startPtr(std::make_unique(size)), _region(region) {} - - uint8_t* offset(int x) const - { - return _startPtr.get() + x; - } - - std::unique_ptr _startPtr; - class MemoryArena* _region; - }; - - struct MemoryArena - { - std::vector _blocks; - std::list _handles; - uint32_t _blockSizeInBytes = 0; - bool _initialized = false; - - MemoryArena() = default; - - void init(int blockSize, int numBlocks) - { - if (_initialized) throw; - - Logger::instance->Log(Logger::INFO, "Thread-local memory arena init: blockSize=", blockSize, ", numBlocks=", numBlocks); - - _blockSizeInBytes = blockSize; - - for (int i = 0; i < numBlocks; i++) - { - _blocks.emplace_back(blockSize, this); - } - - for (int i = 0; i < numBlocks; i++) - { - _handles.push_back(&_blocks[i]); - } - - _initialized = true; - } - - MemoryBlock* get() - { - if (!_handles.empty()) - { - auto mb = _handles.front(); - _handles.pop_front(); - return mb; - } - else - { - return new MemoryBlock(_blockSizeInBytes, nullptr); - } - } - - void put(MemoryBlock* mb) - { - if (mb->_region == this) - { - _handles.push_front(mb); - } - else if (mb->_region == nullptr) - { - delete mb; - } - else - { - throw; - } - } - - int blockSize() const { return _blockSizeInBytes; } - }; - struct Buffer { - constexpr static int MAX_PAGES = 20; - int _pageCount; - int _cursor; - int _size; - int _pageSize; - MemoryBlock* _pages[MAX_PAGES]; - MemoryArena* _arena; - - explicit Buffer(MemoryArena* arena) : _arena(arena), _pageCount{ 0 }, _cursor{ 0 }, _size{ 0 }, _pageSize(arena->blockSize()) {} - - Buffer(Buffer&& b) : _arena(b._arena), _pageCount(b._pageCount), _cursor(b._cursor), _size(b._size), _pageSize(b._arena->blockSize()) - { - for (int i = 0; i < _pageCount; i++) - { - _pages[i] = b._pages[i]; - } - b._pageCount = 0; // prevent _pages being destructed by old object - } + // Use the default minimum socket send buffer size on Linux. 
+ static constexpr int BUFFER_SIZE = 4096; + std::array _data; + std::uint8_t* _head = _data.data(); + std::uint8_t* _tail = _data.data(); - Buffer(const Buffer&) = delete; - Buffer& operator=(const Buffer&) = delete; + std::uint8_t* head() const + { + return _head; + } - ~Buffer() + std::uint8_t* tail() const { - for (int i = 0; i < _pageCount; i++) - { - _arena->put(_pages[i]); - } + return _tail; } - uint8_t* tail() const - { - return offset(_size); - } + bool hasRemainingCapacity() const + { + return _tail < _data.end(); + } int remainingCapacity() const { - return capacity() - _size; + return _data.end() - _tail; } - void produce(int size) - { - _size += size; - } + int remainingDataSize() const + { + return _tail - _head; + } - bool ensureCapacity() - { - return remainingCapacity() > 0 || tryNewPage(); - } - - uint8_t* head() const - { - return offset(_cursor); - } - - int headLimit() const + void produce(int size) { - return std::min(pageLimit(_cursor), _size - _cursor); + assert(remainingCapacity() >= size); + _tail += size; } void consume(int size) { - _cursor += size; - } - - bool tryNewPage() - { - if (_pageCount >= MAX_PAGES) return false; - _pages[_pageCount++] = _arena->get(); - return true; - } - - uint8_t* offset(int x) const - { - return _pages[x / _pageSize]->offset(x % _pageSize); - } - - int capacity() const - { - return _pageCount * _pageSize; - } - - int pageLimit(int x) const - { - return _pageSize - (x % _pageSize); - } - - int cursor() const - { - return _cursor; + assert(remainingDataSize() >= size); + _head += size; } - int size() const - { - return _size; - } - - bool empty() const - { - return _size <= 0; - } + void reset() + { + _head = _tail = _data.data(); + } bool consumed() const { - return _cursor >= _size; - } - }; - - struct BufferManager - { - thread_local static MemoryArena* arena; - - static std::unique_ptr getBuffer() - { - auto b = std::make_unique(arena); - b->tryNewPage(); - return b; - } - - static std::unique_ptr getEmptyBuffer() - { - return std::make_unique(arena); + return _head >= _tail; } }; - - -} \ No newline at end of file +} diff --git a/vsock-bridge/include/channel.h b/vsock-bridge/include/channel.h index 62a17ca..884ff17 100644 --- a/vsock-bridge/include/channel.h +++ b/vsock-bridge/include/channel.h @@ -5,17 +5,22 @@ #include "socket.h" #include "threading.h" +#include #include namespace vsockio { + struct DirectChannel; + class IOThread; + struct ChannelHandle { - int channelId; - int fd; + DirectChannel* _channel; + int _id; + int _fd; - ChannelHandle(int channelId, int fd) - : channelId(channelId), fd(fd) {} + ChannelHandle(DirectChannel* channel, int id, int fd) + : _channel(channel), _id(id), _fd(fd) {} }; struct DirectChannel @@ -23,63 +28,41 @@ namespace vsockio using TAction = std::function; int _id; - BlockingQueue* _taskQueue; std::unique_ptr _a; std::unique_ptr _b; ChannelHandle _ha; ChannelHandle _hb; - DirectChannel(int id, std::unique_ptr a, std::unique_ptr b, BlockingQueue* taskQueue) + DirectChannel(int id, std::unique_ptr a, std::unique_ptr b) : _id(id) , _a(std::move(a)) , _b(std::move(b)) - , _ha(id, _a->fd()) - , _hb(id, _b->fd()) - , _taskQueue(taskQueue) + , _ha(this, _id, _a->fd()) + , _hb(this, _id, _b->fd()) { _a->setPeer(_b.get()); _b->setPeer(_a.get()); } - void handle(int fd, int evt) - { - Socket* s = _a->fd() == fd ? _a.get() : (_b->fd() == fd ? 
_b.get() : nullptr); - if (s == nullptr) - { - Logger::instance->Log(Logger::WARNING, "error in channel.handle: `id=", _id,"`, `fd=", fd, "` does not belong to this channel"); - return; - } - - if (evt & IOEvent::Error) - { - Logger::instance->Log(Logger::DEBUG, "poll error for fd=", fd); - evt |= IOEvent::InputReady; - evt |= IOEvent::OutputReady; - } + void performIO(); - if (evt & IOEvent::InputReady) - { - s->onIoEvent(); - _taskQueue->enqueue([=] { s->onInputReady(); }); - } - - if (evt & IOEvent::OutputReady) - { - s->onIoEvent(); - _taskQueue->enqueue([=] { s->onOutputReady(); }); - } - } - - void terminate() - { - _taskQueue->enqueue([this] { delete this; }); - } + bool canReadWriteMore() const + { + return _a->canReadWriteMore() || _b->canReadWriteMore(); + } bool canBeTerminated() const { - return _a->closed() && _b->closed() && _a->ioEventCount() == 0 && _b->ioEventCount() == 0; + return _a->closed() && _b->closed(); } + + Socket& getSocket(int fd) const + { + if (fd == _a->fd()) return *_a; + if (fd == _b->fd()) return *_b; + throw std::runtime_error("unexpected fd for channel"); + } }; } \ No newline at end of file diff --git a/vsock-bridge/include/config.h b/vsock-bridge/include/config.h index 337cf50..a8cd0e3 100644 --- a/vsock-bridge/include/config.h +++ b/vsock-bridge/include/config.h @@ -6,40 +6,35 @@ namespace vsockproxy { - enum class service_type : uint8_t + enum class ServiceType : uint8_t { UNKNOWN = 0, - SOCKS_PROXY, - FILE, DIRECT_PROXY, }; - enum class endpoint_scheme : uint8_t + enum class EndpointScheme : uint8_t { UNKNOWN = 0, VSOCK, TCP4, }; - struct endpoint + struct EndpointConfig { - endpoint_scheme scheme; - std::string address; - uint16_t port; + EndpointScheme _scheme = EndpointScheme::UNKNOWN; + std::string _address; + uint16_t _port = 0; }; - struct service_description + struct ServiceDescription { - std::string name; - service_type type; - endpoint listen_ep; - endpoint connect_ep; - std::vector> mapping; + std::string _name; + ServiceType _type = ServiceType::UNKNOWN; + EndpointConfig _listenEndpoint; + EndpointConfig _connectEndpoint; }; - uint16_t try_str2short(std::string s, uint16_t default_value); + std::vector loadConfig(const std::string& filepath); - std::vector load_config(std::string filepath); - - std::string describe(service_description& sd); + std::string describe(const ServiceDescription& sd); } \ No newline at end of file diff --git a/vsock-bridge/include/dispatcher.h b/vsock-bridge/include/dispatcher.h index 7ef0868..0bf361b 100644 --- a/vsock-bridge/include/dispatcher.h +++ b/vsock-bridge/include/dispatcher.h @@ -1,236 +1,22 @@ #pragma once -#include "channel.h" -#include "logger.h" -#include "poller.h" - -#include -#include -#include -#include +#include "iothread.h" namespace vsockio { - struct ChannelNode - { - int _id; - std::unique_ptr _channel; - - explicit ChannelNode(int id) : _id(id) {} - - void reset() - { - _channel.reset(); - } - - bool inUse() const - { - return !!_channel; - } - }; - - class ChannelNodePool - { - public: - ChannelNodePool() = default; - - ChannelNodePool(const ChannelNodePool&) = delete; - ChannelNodePool& operator=(const ChannelNodePool&) = delete; - - ~ChannelNodePool() - { - for (auto* node : _freeList) - { - delete node; - } - } - - struct ChannelNodeDeleter - { - ChannelNodePool* _pool; - - void operator()(ChannelNode* node) - { - _pool->releaseNode(node); - } - }; - - using ChannelNodePtr = std::unique_ptr; - - ChannelNodePtr getFreeNode() { - const ChannelNodeDeleter deleter{this}; - - if 
(_freeList.empty()) - { - return ChannelNodePtr(new ChannelNode(_nextNodeId++), deleter); - } - - auto* node = _freeList.front(); - _freeList.pop_front(); - return ChannelNodePtr(node, deleter); - } - - void releaseNode(ChannelNode* node) - { - if (node == nullptr) return; - node->reset(); - _freeList.push_front(node); - } - - private: - int _nextNodeId = 0; - std::forward_list _freeList; - }; - - struct Dispatcher - { - Poller* _poller; - std::vector _events; - ChannelNodePool _idman; - std::unordered_map _channels; - BlockingQueue> _tasksToRun; - - int maxNewConnectionPerLoop = 20; - int scanAndCleanInterval = 20; - int _currentGen = 0; - - int _name; - - Dispatcher(Poller* poller) : Dispatcher(0, poller) {} - - Dispatcher(int name, Poller* poller) : _name(name), _poller(poller), _events(poller->maxEventsPerPoll()) {} - - int name() const - { - return _name; - } - - void postAddChannel(std::unique_ptr&& ap, std::unique_ptr(bp)) - { - // Dispatcher::taskloop manages the channel map attached to the dispatcher - // connectToPeer modifies the map so we request taskloop thread to run it - runOnTaskLoop([this, ap = std::move(ap), bp = std::move(bp)]() mutable { addChannel(std::move(ap), std::move(bp)); }); - } - - ChannelNode* addChannel(std::unique_ptr ap, std::unique_ptr bp) - { - ChannelNodePool::ChannelNodePtr node = _idman.getFreeNode(); - BlockingQueue* taskQueue = ThreadPool::getTaskQueue(node->_id); - - Logger::instance->Log(Logger::DEBUG, "creating channel id=", node->_id, ", a.fd=", ap->fd(), ", b.fd=", bp->fd()); - node->_channel = std::make_unique(node->_id, std::move(ap), std::move(bp), taskQueue); - - const auto& c = *node->_channel; - c._a->setPoller(_poller); - c._b->setPoller(_poller); - if (!_poller->add(c._a->fd(), (void*)&c._ha, IOEvent::InputReady | IOEvent::OutputReady) || - !_poller->add(c._b->fd(), (void*)&c._hb, IOEvent::InputReady | IOEvent::OutputReady)) - { - return nullptr; - } - - auto* const n = node.get(); - _channels[n->_id] = std::move(node); - return n; - } - - template - void runOnTaskLoop(T&& action) - { - auto wrapper = std::make_shared(std::forward(action)); - _tasksToRun.enqueue([wrapper] { (*wrapper)(); }); - } - - void run() - { - Logger::instance->Log(Logger::DEBUG, "dispatcher ", name(), " started"); - for (;;) - { - taskloop(); - } - } - - void taskloop() - { - // handle events on existing channels - poll(); - - // complete new channels - processQueuedTasks(); - - // collect terminated channels - cleanup(); - } - - void poll() - { - const int eventCount = _poller->poll(_events.data(), getTimeout()); - if (eventCount == -1) { - Logger::instance->Log(Logger::CRITICAL, "Poller returns error."); - return; - } - - for (int i = 0; i < eventCount; i++) { - auto *handle = static_cast(_events[i].data); - auto it = _channels.find(handle->channelId); - if (it == _channels.end() || !it->second->inUse()) { - Logger::instance->Log(Logger::WARNING, "Channel ID ", handle->channelId, " does not exist."); - continue; - } - auto &channel = *it->second->_channel; - channel.handle(handle->fd, _events[i].ioFlags); - } - } - - void processQueuedTasks() - { - for (int i = 0; i < maxNewConnectionPerLoop; i++) { - // must check task count first, since we don't wanna block here - if (!_tasksToRun.empty()) { - auto action = _tasksToRun.dequeue(); - action(); - } else { - break; - } - } - } - - void cleanup() - { - if (_currentGen >= scanAndCleanInterval) - { - for (auto it = _channels.begin(); it != _channels.end(); ) - { - auto* node = it->second.get(); - if (!node->inUse() 
|| node->_channel->canBeTerminated()) - { - Logger::instance->Log(Logger::DEBUG, "destroying channel id=", it->first); - // any resources allocated on channel thread must be freed there - if (node->inUse()) - { - node->_channel.release()->terminate(); - } + class Dispatcher + { + public: + explicit Dispatcher(const IOThreadPool& threadPool) : _threadPool(threadPool) {} - it = _channels.erase(it); - } - else - { - ++it; - } - } - _currentGen = 0; - } - _currentGen++; - } + void addChannel(std::unique_ptr&& ap, std::unique_ptr&& bp) + { + _threadPool.addChannel(std::move(ap), std::move(bp)); + } - int getTimeout() const - { - const bool hasPendingTask = - (_currentGen >= scanAndCleanInterval) || - (!_tasksToRun.empty()); + private: + const IOThreadPool& _threadPool; + }; - return hasPendingTask ? 0 : 16; - } - }; } \ No newline at end of file diff --git a/vsock-bridge/include/endpoint.h b/vsock-bridge/include/endpoint.h index d4ea831..5b6ac52 100644 --- a/vsock-bridge/include/endpoint.h +++ b/vsock-bridge/include/endpoint.h @@ -12,17 +12,17 @@ namespace vsockio { struct Endpoint { - virtual int getSocket() = 0; - virtual std::pair getAddress() = 0; + virtual ~Endpoint() = default; + virtual int getSocket() const = 0; + virtual std::pair getAddress() const = 0; virtual std::pair getWritableAddress() = 0; - virtual std::string describe() = 0; - virtual std::unique_ptr clone() = 0; - virtual ~Endpoint() {} + virtual std::string describe() const = 0; + virtual std::unique_ptr clone() const = 0; }; struct TCP4Endpoint : public Endpoint { - TCP4Endpoint(std::string ip, int port) : _ipAddress(ip), _port(port) + TCP4Endpoint(const std::string& ip, int port) : _ipAddress(ip), _port(port) { memset(&_saddr, 0, sizeof(_saddr)); _saddr.sin_family = AF_INET; @@ -30,12 +30,12 @@ namespace vsockio _saddr.sin_addr.s_addr = address(); } - int getSocket() override + int getSocket() const override { return socket(AF_INET, SOCK_STREAM, 0); } - std::pair getAddress() override + std::pair getAddress() const override { return std::make_pair((sockaddr*)&_saddr, sizeof(_saddr)); } @@ -46,14 +46,14 @@ namespace vsockio return std::make_pair((sockaddr*)&_saddr, sizeof(_saddr)); } - std::string describe() override + std::string describe() const override { char buf[20]; inet_ntop(AF_INET, &_saddr.sin_addr.s_addr, buf, sizeof(buf)); return "tcp4://" + std::string(buf) + ":" + std::to_string(ntohs(_saddr.sin_port)); } - std::unique_ptr clone() override + std::unique_ptr clone() const override { return std::unique_ptr(new TCP4Endpoint(_ipAddress, _port)); } @@ -80,12 +80,12 @@ namespace vsockio _saddr.svm_port = _port; // in host byte order } - int getSocket() override + int getSocket() const override { return socket(AF_VSOCK, SOCK_STREAM, 0); } - std::pair getAddress() override + std::pair getAddress() const override { return std::make_pair((sockaddr*)&_saddr, sizeof(_saddr)); } @@ -96,12 +96,12 @@ namespace vsockio return std::make_pair((sockaddr*)&_saddr, sizeof(_saddr)); } - std::string describe() override + std::string describe() const override { return "vsock://" + std::to_string(_cid) + ":" + std::to_string(_port); } - std::unique_ptr clone() override + std::unique_ptr clone() const override { return std::unique_ptr(new VSockEndpoint(_cid, 0)); } diff --git a/vsock-bridge/include/epoll_poller.h b/vsock-bridge/include/epoll_poller.h index e51ec5c..e75eb7c 100644 --- a/vsock-bridge/include/epoll_poller.h +++ b/vsock-bridge/include/epoll_poller.h @@ -27,12 +27,12 @@ namespace vsockio } } - bool add(int fd, void* 
handler, uint32_t events) override + bool add(int fd, void* handler) override { epoll_event ev; memset(&ev, 0, sizeof(epoll_event)); ev.data.ptr = handler; - ev.events = vsb2epoll(events); + ev.events = EPOLLET | EPOLLIN | EPOLLOUT | EPOLLRDHUP; if (epoll_ctl(_epollFd, EPOLL_CTL_ADD, fd, &ev) != 0) { const int err = errno; @@ -43,22 +43,6 @@ namespace vsockio return true; } - bool update(int fd, void* handler, uint32_t events) override - { - epoll_event ev; - memset(&ev, 0, sizeof(epoll_event)); - ev.data.ptr = handler; - ev.events = vsb2epoll(events); - if (epoll_ctl(_epollFd, EPOLL_CTL_MOD, fd, &ev) != 0) - { - const int err = errno; - Logger::instance->Log(Logger::ERROR, "epoll_ctl failed to update fd=", fd, ": ", strerror(err)); - return false; - } - - return true; - } - void remove(int fd) override { epoll_event ev; @@ -71,6 +55,7 @@ namespace vsockio int poll(VsbEvent* outEvents, int timeout) override { + PERF_LOG("poll"); int eventCount = epoll_wait(_epollFd, _epollEvents.get(), _maxEvents, timeout); if (eventCount == -1) @@ -87,7 +72,7 @@ namespace vsockio // and leave the list of events to main processing thread outEvents[i].ioFlags = IOEvent::None; - if ((_epollEvents[i].events & EPOLLERR) || (_epollEvents[i].events & EPOLLHUP)) + if ((_epollEvents[i].events & EPOLLERR) || (_epollEvents[i].events & EPOLLHUP) || (_epollEvents[i].events & EPOLLRDHUP)) { outEvents[i].ioFlags = static_cast(outEvents[i].ioFlags | IOEvent::Error); } @@ -105,22 +90,17 @@ namespace vsockio return eventCount; } + }; - inline uint32_t vsb2epoll(uint32_t vsbEvent) const - { - uint32_t evts = EPOLLET; - - if (vsbEvent & IOEvent::InputReady) - { - evts = evts | EPOLLIN; - } + struct EpollPollerFactory : PollerFactory + { + int _maxEvents; - if (vsbEvent & IOEvent::OutputReady) - { - evts = evts | EPOLLOUT; - } + explicit EpollPollerFactory(int maxEvents) : _maxEvents(maxEvents) {} - return evts; - } - }; + std::unique_ptr createPoller() override + { + return std::make_unique(_maxEvents); + } + }; } \ No newline at end of file diff --git a/vsock-bridge/include/eventdef.h b/vsock-bridge/include/eventdef.h index a0e7735..a15afb3 100644 --- a/vsock-bridge/include/eventdef.h +++ b/vsock-bridge/include/eventdef.h @@ -15,7 +15,6 @@ namespace vsockio struct VsbEvent { IOEvent ioFlags; - int fd; void* data; }; } \ No newline at end of file diff --git a/vsock-bridge/include/iothread.h b/vsock-bridge/include/iothread.h new file mode 100644 index 0000000..c7b8412 --- /dev/null +++ b/vsock-bridge/include/iothread.h @@ -0,0 +1,89 @@ +#pragma once + +#include "channel.h" +#include "poller.h" +#include "socket.h" +#include "threading.h" + +#include +#include +#include +#include +#include +#include + +namespace vsockio +{ + class IOThread + { + public: + explicit IOThread(size_t threadId, PollerFactory& pollerFactory) + : _id(threadId) + , _poller(pollerFactory.createPoller()) + , _events(_poller->maxEventsPerPoll()) + , _thr([this] { run(); }) + { + } + + ~IOThread() + { + _terminateFlag = true; + + if (_thr.joinable()) + { + _thr.join(); + } + } + + size_t id() const { return _id; } + + void addChannel(std::unique_ptr&& ap, std::unique_ptr&& bp); + + private: + struct PendingChannel + { + std::unique_ptr _ap; + std::unique_ptr _bp; + }; + + void run(); + void addPendingChannels(); + void addPendingChannel(PendingChannel&& pendingChannel); + void poll(); + int getPollTimeout() const; + void performIO(); + void cleanup(); + + const size_t _id; + std::atomic _terminateFlag = false; + std::unique_ptr _poller; + ThreadSafeQueue 
_pendingChannels; + std::unordered_set _channels; + std::unordered_set _readyChannels; + std::unordered_set _terminatedChannels; + std::vector _events; + std::thread _thr; + }; + + class IOThreadPool + { + public: + explicit IOThreadPool(size_t size, PollerFactory& pollerFactory) + { + for (size_t i = 0; i < size; ++i) { + _threads.push_back(std::make_unique(i, pollerFactory)); + } + } + + void addChannel(std::unique_ptr&& ap, std::unique_ptr&& bp) const + { + thread_local static size_t channelCount = 0; + _threads[channelCount % _threads.size()]->addChannel(std::move(ap), std::move(bp)); + ++channelCount; + } + + private: + std::vector> _threads; + }; + +} diff --git a/vsock-bridge/include/listener.h b/vsock-bridge/include/listener.h index 3fc81b5..52383df 100644 --- a/vsock-bridge/include/listener.h +++ b/vsock-bridge/include/listener.h @@ -70,14 +70,13 @@ namespace vsockio const int MAX_POLLER_EVENTS = 256; const int SO_BACKLOG = 64; - Listener(std::unique_ptr&& listenEndpoint, std::unique_ptr&& connectEndpoint, std::vector& dispatchers) + Listener(std::unique_ptr&& listenEndpoint, std::unique_ptr&& connectEndpoint, Dispatcher& dispatcher) : _fd(-1) , _listenEp(std::move(listenEndpoint)) , _connectEp(std::move(connectEndpoint)) , _events(new VsbEvent[MAX_POLLER_EVENTS]) , _listenEpClone(_listenEp->clone()) - , _dispatchers(dispatchers) - , _dispatcherIdRr(0) + , _dispatcher(dispatcher) { const int fd = _listenEp->getSocket(); if (fd < 0) @@ -181,12 +180,10 @@ namespace vsockio return; } + inPeer->onConnected(); - const int dpId = (_dispatcherIdRr++) % _dispatchers.size(); - auto* const dp = _dispatchers[dpId]; - - Logger::instance->Log(Logger::DEBUG, "Dispatcher ", dpId, " will handle channel for accepted connection fd=", inPeer->fd(), ", peer fd=", outPeer->fd()); - dp->postAddChannel(std::move(inPeer), std::move(outPeer)); + Logger::instance->Log(Logger::DEBUG, "Dispatcher will handle channel for accepted connection fd=", inPeer->fd(), ", peer fd=", outPeer->fd()); + _dispatcher.addChannel(std::move(inPeer), std::move(outPeer)); } std::unique_ptr connectToPeer() @@ -214,11 +211,17 @@ namespace vsockio auto addrAndLen = _connectEp->getAddress(); int status = connect(fd, addrAndLen.first, addrAndLen.second); - if (status == 0 || (status = errno) == EINPROGRESS) + if (status == 0) { + peer->onConnected(); Logger::instance->Log(Logger::DEBUG, "connected to remote endpoint (fd=", fd, ") with status=", status); return peer; } + if ((status = errno) == EINPROGRESS) + { + Logger::instance->Log(Logger::DEBUG, "connection to remote endpoint (fd=", fd, ") in progress"); + return peer; + } else { Logger::instance->Log(Logger::WARNING, "failed to connect to remote endpoint (fd=", fd, "): ", strerror(status)); @@ -233,7 +236,6 @@ namespace vsockio std::unique_ptr _listenEpClone; std::unique_ptr _connectEp; std::unique_ptr _events; - std::vector& _dispatchers; - uint32_t _dispatcherIdRr; + Dispatcher& _dispatcher; }; } \ No newline at end of file diff --git a/vsock-bridge/include/logger.h b/vsock-bridge/include/logger.h index 7eb44fe..73ce785 100644 --- a/vsock-bridge/include/logger.h +++ b/vsock-bridge/include/logger.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -54,52 +55,14 @@ struct Logger { _streamProvider = streamProvider; } - template - void Log(int level, const T0& m0) + template + void Log(int level, const Ts&... 
args) { if (level < _minLevel || _streamProvider == nullptr) return; std::lock_guard lk(_lock); - _streamProvider->startLog(level) << m0 << std::endl; - } - - template - void Log(int level, const T0& m0, const T1& m1) - { - if (level < _minLevel || _streamProvider == nullptr) return; - std::lock_guard lk(_lock); - _streamProvider->startLog(level) << m0 << m1 << std::endl; - } - - template - void Log(int level, const T0& m0, const T1& m1, const T2& m2) - { - if (level < _minLevel || _streamProvider == nullptr) return; - std::lock_guard lk(_lock); - _streamProvider->startLog(level) << m0 << m1 << m2 << std::endl; - } - - template - void Log(int level, const T0& m0, const T1& m1, const T2& m2, const T3& m3) - { - if (level < _minLevel || _streamProvider == nullptr) return; - std::lock_guard lk(_lock); - _streamProvider->startLog(level) << m0 << m1 << m2 << m3 << std::endl; - } - - template - void Log(int level, const T0& m0, const T1& m1, const T2& m2, const T3& m3, const T4& m4) - { - if (level < _minLevel || _streamProvider == nullptr) return; - std::lock_guard lk(_lock); - _streamProvider->startLog(level) << m0 << m1 << m2 << m3 << m4 << std::endl; - } - - template - void Log(int level, const T0& m0, const T1& m1, const T2& m2, const T3& m3, const T4& m4, const T5& m5) - { - if (level < _minLevel || _streamProvider == nullptr) return; - std::lock_guard lk(_lock); - _streamProvider->startLog(level) << m0 << m1 << m2 << m3 << m4 << m5 << std::endl; + auto& s = _streamProvider->startLog(level); + (s << ... << args); + s << std::endl; } }; @@ -179,4 +142,26 @@ struct RSyslogLogger : public LoggingStream std::ostream _error; std::ostream _critical; NullStream _nullStream; -}; \ No newline at end of file +}; + +struct PerfLogger +{ + const char* const _name; + const std::chrono::time_point _start; + + explicit PerfLogger(const char* name) : _name(name), _start(std::chrono::steady_clock::now()) {} + ~PerfLogger() + { + const auto end = std::chrono::steady_clock::now(); + const std::chrono::duration diff = end - _start; + Logger::instance->Log(Logger::DEBUG, "Latency ", _name, " ", diff.count(), "s"); + } +}; + +#ifdef ENABLE_VSOCKIO_PERF +#define VSOCKIO_COMBINE1(X,Y) X##Y +#define VSOCKIO_COMBINE(X,Y) VSOCKIO_COMBINE1(X,Y) +#define PERF_LOG(name) PerfLogger VSOCKIO_COMBINE(__perfLog, __LINE__){name} +#else +#define PERF_LOG(name) do {} while(0) +#endif diff --git a/vsock-bridge/include/peer.h b/vsock-bridge/include/peer.h deleted file mode 100644 index 6ea5ce6..0000000 --- a/vsock-bridge/include/peer.h +++ /dev/null @@ -1,127 +0,0 @@ -#pragma once - -#include "buffer.h" - -#include -#include -#include -#include -#include - -#include - -namespace vsockio -{ - template - struct UniquePtrQueue - { - using TPtr = std::unique_ptr; - - std::list _list; - - ssize_t _count; - - UniquePtrQueue() : _count(0) {} - - ssize_t count() const - { - return _count; - } - - TPtr& front() - { - return _list.front(); - } - - void enqueue(TPtr&& value) - { - _count++; - _list.push_back(std::move(value)); - } - - TPtr dequeue() - { - _count--; - TPtr p = std::move(_list.front()); - _list.pop_front(); - return p; - } - - bool empty() const - { - return _list.empty(); - } - }; - - template - class Peer - { - public: - Peer() = default; - - Peer(const Peer&) = delete; - Peer& operator=(const Peer&) = delete; - - virtual ~Peer() {} - - void onInputReady() - { - assert(_peer != nullptr); - - _inputReady = true; - while (readFromInput() && _peer->writeToOutput()) - ; - --_ioEventCount; - } - - void onOutputReady() - { - 
assert(_peer != nullptr); - - _outputReady = true; - while (writeToOutput() && _peer->readFromInput()) - ; - --_ioEventCount; - } - - inline void setPeer(Peer* p) - { - _peer = p; - } - - inline void onIoEvent() { ++_ioEventCount; } - - inline int ioEventCount() const { return _ioEventCount.load(); } - - virtual void close() = 0; - - bool closed() const { return _inputClosed && _outputClosed; } - - bool inputClosed() const { return _inputClosed; } - - bool outputClosed() const { return _outputClosed; } - - virtual void onPeerClosed() = 0; - - virtual void queue(TBuf&& buffer) = 0; - - bool queueFull() const { return _queueFull; } - virtual bool queueEmpty() const = 0; - - protected: - virtual bool readFromInput() = 0; - - virtual bool writeToOutput() = 0; - - protected: - bool _inputReady = false; - bool _outputReady = false; - bool _inputClosed = false; - bool _outputClosed = false; - bool _queueFull = false; - Peer* _peer = nullptr; - - private: - std::atomic_int _ioEventCount{0}; - }; -} \ No newline at end of file diff --git a/vsock-bridge/include/poller.h b/vsock-bridge/include/poller.h index e8af5a4..409dfc5 100644 --- a/vsock-bridge/include/poller.h +++ b/vsock-bridge/include/poller.h @@ -2,13 +2,15 @@ #include "eventdef.h" +#include + namespace vsockio { struct Poller { - virtual bool add(int fd, void* handler, uint32_t events) = 0; + virtual ~Poller() = default; - virtual bool update(int fd, void* handler, uint32_t events) = 0; + virtual bool add(int fd, void* handler) = 0; virtual void remove(int fd) = 0; @@ -16,6 +18,14 @@ namespace vsockio int maxEventsPerPoll() const { return _maxEvents; } + protected: int _maxEvents; }; + + struct PollerFactory + { + virtual ~PollerFactory() = default; + + virtual std::unique_ptr createPoller() = 0; + }; } \ No newline at end of file diff --git a/vsock-bridge/include/socket.h b/vsock-bridge/include/socket.h index fd0a03b..2b7c46f 100644 --- a/vsock-bridge/include/socket.h +++ b/vsock-bridge/include/socket.h @@ -1,8 +1,9 @@ #pragma once -#include "peer.h" +#include "buffer.h" #include "poller.h" +#include #include #include @@ -22,13 +23,13 @@ namespace vsockio std::function readImpl, std::function writeImpl, std::function closeImpl - ) : + ) : read(readImpl), write(writeImpl), close(closeImpl) {} }; - class Socket : public Peer> + class Socket { public: Socket(int fd, SocketImpl& impl); @@ -38,37 +39,66 @@ namespace vsockio ~Socket(); - inline int fd() const { return _fd; } + void readInput() + { + assert(_peer != nullptr); + _canReadMore = readFromInput(); + } + + void writeOutput() + { + assert(_peer != nullptr); + _canWriteMore = writeToOutput(); + } - void close() override; + inline void setPeer(Socket* p) + { + _peer = p; + } - bool queueEmpty() const override { return _sendQueue.empty(); } + inline int fd() const { return _fd; } void setPoller(Poller* poller) { _poller = poller; } - protected: - bool readFromInput() override; + bool connected() const { return _connected; } + void onConnected() { _connected = true; } + void checkConnected(); - bool writeToOutput() override; + bool closed() const { return _inputClosed && _outputClosed; } - void onPeerClosed() override; + bool canReadWriteMore() const { return (_canReadMore || _canWriteMore) && !closed(); } - void queue(std::unique_ptr&& buffer) override; + private: + bool readFromInput(); + bool writeToOutput(); - private: - std::unique_ptr read(); + void onPeerClosed(); - void send(Buffer& buffer); + bool read(Buffer& buffer); + bool send(Buffer& buffer); + void close(); void 
closeInput(); - private: + bool inputClosed() const { return _inputClosed; } + bool outputClosed() const { return _outputClosed; } + bool hasQueuedData() const { return !_buffer.consumed(); } + + Buffer& buffer() { return _buffer; } + + private: SocketImpl& _impl; - UniquePtrQueue _sendQueue; + bool _canReadMore = false; + bool _canWriteMore = false; + bool _inputClosed = false; + bool _outputClosed = false; + Socket* _peer; int _fd; + bool _connected = false; Poller* _poller = nullptr; + Buffer _buffer; }; } \ No newline at end of file diff --git a/vsock-bridge/include/threading.h b/vsock-bridge/include/threading.h index ffe45d5..cdf3346 100644 --- a/vsock-bridge/include/threading.h +++ b/vsock-bridge/include/threading.h @@ -6,112 +6,37 @@ #include #include #include +#include +#include #include #include namespace vsockio { - template - struct BlockingQueue - { - std::list _list; - std::mutex _queueLock; - std::condition_variable _signal; - int _count; - - BlockingQueue() : _count(0) {} - - void enqueue(T value) - { - { - std::lock_guard lk(_queueLock); - _list.push_back(std::move(value)); - ++_count; - } - _signal.notify_one(); - } - - T dequeue() - { - std::unique_lock lk(_queueLock); - if (_count == 0) - { - _signal.wait(lk, [this]() { return _count > 0; }); - } - - T p = _list.front(); - _list.pop_front(); - --_count; - - lk.unlock(); - return p; - } - - int count() const - { - return _count; - } - - bool empty() const - { - return _count <= 0; - } - }; - - struct WorkerThread - { - std::function _initCallback; - BlockingQueue> _taskQueue; - bool _retired = false; - uint64_t _eventsProcessed = 0; - std::thread t; - - uint64_t eventsProcessed() const { return _eventsProcessed; } - - WorkerThread(std::function initCallback) - : _initCallback(initCallback), t([this] { run(); }) - { - } - - ~WorkerThread() - { - if (t.joinable()) - { - t.join(); - } - } - - void run() - { - _initCallback(); - - while (!_retired) - { - auto action = _taskQueue.dequeue(); - action(); - _eventsProcessed++; - } - } - - void stop() - { - _retired = true; - _taskQueue.enqueue([](){}); - } - - BlockingQueue>* getQueue() - { - return &_taskQueue; - } - }; - - struct ThreadPool - { - static std::vector> threads; - static BlockingQueue>* getTaskQueue(int taskId) - { - return threads[taskId % ThreadPool::threads.size()]->getQueue(); - } - }; + template + struct ThreadSafeQueue + { + std::queue _queue; + std::mutex _queueLock; + + void enqueue(T&& value) + { + std::lock_guard lk(_queueLock); + _queue.push(std::move(value)); + } + + std::optional dequeue() + { + std::lock_guard lk(_queueLock); + if (_queue.empty()) + { + return std::nullopt; + } + + T result = std::move(_queue.front()); + _queue.pop(); + return result; + } + }; } \ No newline at end of file diff --git a/vsock-bridge/include/vsock-bridge.h b/vsock-bridge/include/vsock-bridge.h index a6ecbb7..6ca6af5 100644 --- a/vsock-bridge/include/vsock-bridge.h +++ b/vsock-bridge/include/vsock-bridge.h @@ -1,7 +1,8 @@ #pragma once #include "config.h" -#include "peer.h" +#include "dispatcher.h" +#include "iothread.h" #include "listener.h" #include "logger.h" #include "socket.h" diff --git a/vsock-bridge/src/CMakeLists.txt b/vsock-bridge/src/CMakeLists.txt index 61a7f53..380270c 100644 --- a/vsock-bridge/src/CMakeLists.txt +++ b/vsock-bridge/src/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required (VERSION 3.8) -add_library (vsock-io "socket.cpp" "logger.cpp" "epoll_poller.cpp") +add_library (vsock-io "socket.cpp" "channel.cpp" "iothread.cpp" "logger.cpp" 
"epoll_poller.cpp") add_executable (vsock-bridge "vsock-bridge.cpp" "config.cpp" "global.cpp") target_link_libraries(vsock-bridge vsock-io pthread -static-libgcc -static-libstdc++) diff --git a/vsock-bridge/src/channel.cpp b/vsock-bridge/src/channel.cpp new file mode 100644 index 0000000..345df50 --- /dev/null +++ b/vsock-bridge/src/channel.cpp @@ -0,0 +1,16 @@ +#include + +namespace vsockio +{ + void DirectChannel::performIO() + { + // Try reading from and writing to both sockets. + // This is less efficient, but keeps the logic simple. + + _a->readInput(); + _b->readInput(); + _a->writeOutput(); + _b->writeOutput(); + } + +} \ No newline at end of file diff --git a/vsock-bridge/src/config.cpp b/vsock-bridge/src/config.cpp index ab83780..eaa4693 100644 --- a/vsock-bridge/src/config.cpp +++ b/vsock-bridge/src/config.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include namespace vsockproxy @@ -16,13 +17,7 @@ namespace vsockproxy socks-proxy: service: socks listen: vsock://-1:3305 - - file-server: - service: file - listen: vsock://-1:3306 - mapping: - - config:/etc/uidoperator/config.json - - secrets:/etc/uidoperator/secrets.json + connect: tcp://127.0.0.1:3306 operator-service: service: direct @@ -36,70 +31,63 @@ namespace vsockproxy */ - struct yaml_line + struct YamlLine { - bool is_empty; - bool is_comment; - bool is_list_element; - int level; - std::string key; - std::string value; + bool _isEmpty; + bool _isListElement; + int _level; + std::string _key; + std::string _value; }; - std::string name_service_type(service_type t) + static std::string nameServiceType(ServiceType t) { switch (t) { - case service_type::DIRECT_PROXY: return "direct"; - case service_type::SOCKS_PROXY: return "socks"; - case service_type::FILE: return "file"; + case ServiceType::DIRECT_PROXY: return "direct"; default: return "unknown"; } } - std::string name_scheme(endpoint_scheme t) + static std::string nameEndpointScheme(EndpointScheme t) { switch (t) { - case endpoint_scheme::TCP4: return "tcp"; - case endpoint_scheme::VSOCK: return "vsock"; + case EndpointScheme::TCP4: return "tcp"; + case EndpointScheme::VSOCK: return "vsock"; default: return "unknown"; } } - uint16_t try_str2short(std::string s, uint16_t default_value) + static std::optional trystrtous(const std::string& s) { - if (s.size() == 0) return default_value; - - uint16_t value = 0; - for (int i = 0; i < s.size(); i++) - { - if (s[i] >= '0' && s[i] <= '9') - { - value *= 10; - value += s[i] - '0'; - } - else - { - return default_value; - } - } - - return value; + if (s.empty()) return std::nullopt; + + try + { + const auto result = std::stoul(s); + if (result > std::numeric_limits::max()) + { + return std::nullopt; + } + return static_cast(result); + } + catch (...) 
+ { + return std::nullopt; + } } - yaml_line nextline(std::ifstream& s) + static YamlLine nextLine(std::ifstream& s) { - yaml_line y; - y.is_empty = true; - y.is_comment = false; - y.is_list_element = false; + YamlLine y; + y._isEmpty = true; + y._isListElement = false; for (std::string line; std::getline(s, line); ) { - y.is_empty = true; - y.is_comment = false; - y.is_list_element = false; + y._isEmpty = true; + y._isListElement = false; if (line == "---") continue; @@ -110,23 +98,22 @@ namespace vsockproxy { if (line[i] != ' ' && line[i] != '\t' && line[i] != '\r' && line[i] != '\n') { - if (y.is_empty) + if (y._isEmpty) { // first character - y.is_empty = false; - y.level = i; + y._isEmpty = false; + y._level = i; if (line[i] == '#') { - y.is_comment = true; break; } else if (line[i] == '-') { - y.is_list_element = true; + y._isListElement = true; } state = 1; - if (y.is_list_element) continue; // skip '-' + if (y._isListElement) continue; // skip '-' } if (state == 1) @@ -148,44 +135,53 @@ namespace vsockproxy } } - if (key.size() > 0) + if (!key.empty()) { - y.key = key; - y.value = value; + y._key = key; + y._value = value; break; } } return y; } - void tryparse_endpoint(std::string& value, endpoint& out_ep) + static std::optional tryParseEndpoint(const std::string& value) { + EndpointConfig endpointConfig; size_t p = value.find(':'); if (p != value.npos) { - std::string scheme = value.substr(0, p); + const std::string scheme = value.substr(0, p); if (scheme == "vsock") { - out_ep.scheme = endpoint_scheme::VSOCK; + endpointConfig._scheme = EndpointScheme::VSOCK; } else if (scheme == "tcp") { - out_ep.scheme = endpoint_scheme::TCP4; + endpointConfig._scheme = EndpointScheme::TCP4; } } p += 3; // skip '://' - size_t p2 = value.find(':', p); + const size_t p2 = value.find(':', p); if (p2 != value.npos) { - out_ep.address = value.substr(p, p2 - p); - out_ep.port = try_str2short(value.substr(p2 + 1), 0); + endpointConfig._address = value.substr(p, p2 - p); + const auto port = trystrtous(value.substr(p2 + 1)); + if (!port) + { + Logger::instance->Log(Logger::CRITICAL, "invalid port number: ", value.substr(p2 + 1)); + return std::nullopt; + } + endpointConfig._port = *port; } + + return endpointConfig; } - std::vector load_config(std::string filepath) + std::vector loadConfig(const std::string& filepath) { - std::vector services; + std::vector services; std::ifstream f; f.open(filepath); @@ -196,70 +192,70 @@ namespace vsockproxy return services; } - int level_indent = -1; + int levelIndent = -1; - service_description cs; - cs.type = service_type::UNKNOWN; - cs.listen_ep.scheme = endpoint_scheme::UNKNOWN; + ServiceDescription cs; while (true) { - yaml_line line = nextline(f); - if (line.is_empty) break; + YamlLine line = nextLine(f); + if (line._isEmpty) break; - if (line.level == 0) + if (line._level == 0) { - if (cs.type != service_type::UNKNOWN) + if (cs._type != ServiceType::UNKNOWN) { services.push_back(cs); } - cs = service_description(); - cs.type = service_type::UNKNOWN; - cs.listen_ep.scheme = endpoint_scheme::UNKNOWN; - cs.name = line.key; + cs = ServiceDescription(); + cs._name = line._key; } else { - if (level_indent == -1) + if (levelIndent == -1) { // first time we find non-zero indentation, // use this to determine level - level_indent = line.level; + levelIndent = line._level; } - int level = line.level / level_indent; + const int level = line._level / levelIndent; if (level == 1) { - if (line.key == "service") + if (line._key == "service") { - if (line.value == "socks") - 
cs.type = service_type::SOCKS_PROXY; - else if (line.value == "file") - cs.type = service_type::FILE; - else if (line.value == "direct") - cs.type = service_type::DIRECT_PROXY; + if (line._value == "direct") + cs._type = ServiceType::DIRECT_PROXY; else - cs.type = service_type::UNKNOWN; + { + Logger::instance->Log(Logger::CRITICAL, "unknown service type for service: ", cs._name); + return {}; + } } - else if (line.key == "listen") + else if (line._key == "listen") { - tryparse_endpoint(line.value, cs.listen_ep); + const auto endpoint = tryParseEndpoint(line._value); + if (!endpoint) + { + Logger::instance->Log(Logger::CRITICAL, "failed to parse listen endpoint config: ", line._value, " for service: ", cs._name); + return {}; + } + cs._listenEndpoint = *endpoint; } - else if (line.key == "connect") + else if (line._key == "connect") { - tryparse_endpoint(line.value, cs.connect_ep); - } - } - else if (level == 2) - { - if (line.is_list_element) - { - cs.mapping.push_back(std::make_pair(line.key, line.value)); + const auto endpoint = tryParseEndpoint(line._value); + if (!endpoint) + { + Logger::instance->Log(Logger::CRITICAL, "failed to parse connect endpoint config: ", line._value, " for service: ", cs._name); + return {}; + } + cs._connectEndpoint = *endpoint; } } } } - if (cs.type != service_type::UNKNOWN) + if (cs._type != ServiceType::UNKNOWN) { services.push_back(cs); } @@ -267,19 +263,13 @@ namespace vsockproxy return services; } - std::string describe(service_description& sd) + std::string describe(const ServiceDescription& sd) { std::stringstream ss; - ss << sd.name - << "\n type: " << name_service_type(sd.type) - << "\n listen: " << name_scheme(sd.listen_ep.scheme) << "://" << sd.listen_ep.address << ":" << sd.listen_ep.port - << "\n connect: " << name_scheme(sd.connect_ep.scheme) << "://" << sd.connect_ep.address << ":" << sd.connect_ep.port - << "\n mapping:"; - - for (auto& p : sd.mapping) - { - ss << "\n - " << p.first << ":" << p.second; - } + ss << sd._name + << "\n type: " << nameServiceType(sd._type) + << "\n listen: " << nameEndpointScheme(sd._listenEndpoint._scheme) << "://" << sd._listenEndpoint._address << ":" << sd._listenEndpoint._port + << "\n connect: " << nameEndpointScheme(sd._connectEndpoint._scheme) << "://" << sd._connectEndpoint._address << ":" << sd._connectEndpoint._port; return ss.str(); } diff --git a/vsock-bridge/src/global.cpp b/vsock-bridge/src/global.cpp index f5ef2d5..0be0b9b 100644 --- a/vsock-bridge/src/global.cpp +++ b/vsock-bridge/src/global.cpp @@ -1,14 +1,10 @@ -#include #include -#include + #include +#include using namespace vsockio; -thread_local MemoryArena* BufferManager::arena = new MemoryArena(); - -std::vector> ThreadPool::threads; - SocketImpl* SocketImpl::singleton = new SocketImpl( /*read: */ [](int fd, void* buf, int len) { return ::read(fd, buf, len); }, /*write:*/ [](int fd, void* buf, int len) { return ::write(fd, buf, len); }, diff --git a/vsock-bridge/src/iothread.cpp b/vsock-bridge/src/iothread.cpp new file mode 100644 index 0000000..439662d --- /dev/null +++ b/vsock-bridge/src/iothread.cpp @@ -0,0 +1,118 @@ +#include + +namespace vsockio +{ + void IOThread::addChannel(std::unique_ptr&& ap, std::unique_ptr&& bp) + { + _pendingChannels.enqueue({std::move(ap), std::move(bp)}); + } + + void IOThread::run() + { + while (!_terminateFlag.load(std::memory_order_relaxed)) + { + addPendingChannels(); + poll(); + performIO(); + cleanup(); + } + } + + void IOThread::addPendingChannels() + { + while (true) + { + auto pendingChannel = 
_pendingChannels.dequeue(); + if (!pendingChannel) + { + break; + } + + addPendingChannel(std::move(*pendingChannel)); + } + } + + void IOThread::addPendingChannel(PendingChannel&& pendingChannel) + { + thread_local static int channelId = 0; + + Logger::instance->Log(Logger::DEBUG, "iothread id=", id(), " creating channel id=", channelId, ", a.fd=", pendingChannel._ap->fd(), ", b.fd=", pendingChannel._bp->fd()); + auto channel = std::make_unique(channelId, std::move(pendingChannel._ap), std::move(pendingChannel._bp)); + ++channelId; + + channel->_a->setPoller(_poller.get()); + channel->_b->setPoller(_poller.get()); + if (!_poller->add(channel->_a->fd(), (void*)&channel->_ha) || + !_poller->add(channel->_b->fd(), (void*)&channel->_hb)) + { + return; + } + + _channels.insert(channel.release()); + } + + void IOThread::poll() + { + const int eventCount = _poller->poll(_events.data(), getPollTimeout()); + if (eventCount == -1) { + Logger::instance->Log(Logger::CRITICAL, "Poller returns error."); + return; + } + + for (int i = 0; i < eventCount; i++) { + auto* handle = static_cast(_events[i].data); + auto* channel = handle->_channel; + _readyChannels.insert(channel); + + Socket& s = channel->getSocket(handle->_fd); + if ((_events[i].ioFlags & (IOEvent::OutputReady | IOEvent::Error)) && !s.connected()) + { + s.checkConnected(); + } + } + } + + int IOThread::getPollTimeout() const + { + return _readyChannels.empty() ? 1 : 0; + } + + void IOThread::performIO() + { + for (auto it = _readyChannels.begin(); it != _readyChannels.end(); ) + { + auto* channel = *it; + channel->performIO(); + if (!channel->canReadWriteMore()) + { + it = _readyChannels.erase(it); + } + else + { + ++it; + } + + if (channel->canBeTerminated()) + { + _terminatedChannels.insert(channel); + } + } + } + + void IOThread::cleanup() + { + if (_terminatedChannels.empty()) + { + return; + } + + for (auto* channel : _terminatedChannels) + { + _channels.erase(channel); + _readyChannels.erase(channel); + delete channel; + } + + _terminatedChannels.clear(); + } +} diff --git a/vsock-bridge/src/socket.cpp b/vsock-bridge/src/socket.cpp index cb79d65..25aefb9 100644 --- a/vsock-bridge/src/socket.cpp +++ b/vsock-bridge/src/socket.cpp @@ -8,236 +8,201 @@ namespace vsockio { - Socket::Socket(int fd, SocketImpl& impl) - : _fd(fd) - , _impl(impl) - { - assert(_fd >= 0); - } - - bool Socket::readFromInput() - { - if (_peer->outputClosed() && !inputClosed()) - { - Logger::instance->Log(Logger::DEBUG, "[socket] readToInput detected output peer closed, closing input (fd=", _fd, ")"); - closeInput(); - return false; - } - - if (_inputClosed) return false; - - bool hasInput = false; - while (!_inputClosed && _inputReady && !_peer->queueFull()) - { - std::unique_ptr buffer{ read() }; - if (buffer && !buffer->empty()) - { - _peer->queue(std::move(buffer)); - hasInput = true; - } - } - - if (_inputClosed) - { - Logger::instance->Log(Logger::DEBUG, "[socket] readToInput detected input closed, closing (fd=", _fd, ")"); - close(); - } - - return hasInput; - } - - bool Socket::writeToOutput() - { - if (_outputClosed) return false; - - while (!_outputClosed && _outputReady && !_sendQueue.empty()) - { - std::unique_ptr& buffer = _sendQueue.front(); - - // received termination signal from peer - if (buffer->empty()) - { - Logger::instance->Log(Logger::DEBUG, "[socket] writeToOutput dequeued a termination buffer (fd=", _fd, ")"); - _sendQueue.dequeue(); - close(); - break; - } - else - { - send(*buffer); - if (buffer->consumed()) - { - _sendQueue.dequeue(); - 
_queueFull = false; - } - } - } - - if (_peer->closed()) - { - if (_sendQueue.empty()) - { - Logger::instance->Log(Logger::DEBUG, "[socket] writeToOutput detected input peer is closed, closing (fd=", _fd, ")"); - close(); - } - else if (!_peer->queueEmpty()) - { - // Peer has some queued data they never received - // Assuming this data is critical for the protocol, it should be ok to abort the connection straight away - Logger::instance->Log(Logger::DEBUG, "[socket] writeToOutput detected input peer is closed while having data remaining, closing (fd=", _fd, ")"); - close(); - } - } - - return _sendQueue.empty(); - } - - void Socket::queue(std::unique_ptr&& buffer) - { - _sendQueue.enqueue(std::move(buffer)); - - // to simplify logic we allow only 1 buffer for socket sinks - _queueFull = true; - } - - std::unique_ptr Socket::read() - { - std::unique_ptr buffer{ BufferManager::getBuffer() }; - - while (true) - { - const int bytesRead = _impl.read(_fd, buffer->tail(), buffer->remainingCapacity()); - int err = 0; - if (bytesRead > 0) - { - // New content read - // update byte count and enlarge buffer if needed - - //Logger::instance->Log(Logger::DEBUG, "[socket] read returns ", bytesRead, " (fd=", _fd, ")"); - buffer->produce(bytesRead); - if (!buffer->ensureCapacity()) - { - break; - } - } - else if (bytesRead == 0) - { - // Source closed - - Logger::instance->Log(Logger::DEBUG, "[socket] read returns 0, closing input (fd=", _fd, ")"); - closeInput(); - break; - } - else if ((err = errno) == EAGAIN || err == EWOULDBLOCK) - { - // No new data - - _inputReady = false; - break; - } - else - { - // Error - - Logger::instance->Log(Logger::WARNING, "[socket] error on read, closing input (fd=", _fd, "): ", strerror(err)); - closeInput(); - break; - } - } - - return buffer; - } - - void Socket::send(Buffer& buffer) - { - while (!buffer.consumed()) - { - const int bytesWritten = _impl.write(_fd, buffer.head(), buffer.headLimit()); - - int err = 0; - if (bytesWritten > 0) - { - // Some data written to downstream - // log bytes written and move cursor forward - - //Logger::instance->Log(Logger::DEBUG, "[socket] write returns ", bytesWritten, " (fd=", _fd, ")"); - buffer.consume(bytesWritten); - } - else if((err = errno) == EAGAIN || err == EWOULDBLOCK) - { - // Write blocked - _outputReady = false; - break; - } - else - { - // Error - - Logger::instance->Log(Logger::WARNING, "[socket] error on send, closing (fd=", _fd, "): ", strerror(err)); - close(); - break; - } - } - - } - - void Socket::closeInput() - { - _inputClosed = true; - } - - void Socket::close() - { - _inputReady = false; - _outputReady = false; - - if (!closed()) - { - _inputClosed = true; - _outputClosed = true; - - if (_poller) - { - // epoll is meant to automatically deregister sockets on close, but apparently some systems - // have bugs around this, so do it explicitly - Logger::instance->Log(Logger::DEBUG, "[socket] remove from poller (fd=", _fd, ")"); - _poller->remove(_fd); - } - - Logger::instance->Log(Logger::DEBUG, "[socket] close, fd=", _fd); - _impl.close(_fd); - if (_peer != nullptr) - { - _peer->onPeerClosed(); - } - } - } - - void Socket::onPeerClosed() - { - if (!closed()) - { - Logger::instance->Log(Logger::DEBUG, "[socket] sending termination for (fd=", _fd, ")"); - std::unique_ptr termination{ BufferManager::getEmptyBuffer() }; - queue(std::move(termination)); - - // force process the queue - _outputReady = true; - writeToOutput(); - } - } - - Socket::~Socket() - { - if (!closed()) - { - 
Logger::instance->Log(Logger::WARNING, "[socket] closing on destruction (fd=", _fd, ")"); - close(); - } - - if (_peer != nullptr) - { - _peer->setPeer(nullptr); - } - } + Socket::Socket(int fd, SocketImpl& impl) + : _fd(fd) + , _impl(impl) + { + assert(_fd >= 0); + } + + bool Socket::readFromInput() + { + if (!_connected) return false; + + if (_inputClosed) return false; + + const bool canReadMoreData = read(_peer->buffer()); + return canReadMoreData; + } + + bool Socket::writeToOutput() + { + if (!_connected) return false; + + if (_outputClosed) return false; + + bool canSendModeData = false; + if (!_outputClosed) { + if (!_buffer.consumed()) { + canSendModeData = send(_buffer); + if (_buffer.consumed()) + { + _buffer.reset(); + } + } + } + + if (_peer->closed() && _buffer.consumed()) + { + Logger::instance->Log(Logger::DEBUG, "[socket] writeToOutput finished draining socket, closing (fd=", _fd, ")"); + close(); + } + + return canSendModeData; + } + + bool Socket::read(Buffer& buffer) + { + if (!buffer.hasRemainingCapacity()) return false; + + PERF_LOG("read"); + const int bytesRead = _impl.read(_fd, buffer.tail(), buffer.remainingCapacity()); + int err = 0; + if (bytesRead > 0) + { + // New content read + + //Logger::instance->Log(Logger::DEBUG, "[socket] read returns ", bytesRead, " (fd=", _fd, ")"); + buffer.produce(bytesRead); + return true; + } + else if (bytesRead == 0) + { + // Source closed + + Logger::instance->Log(Logger::DEBUG, "[socket] read returns 0, closing (fd=", _fd, ")"); + close(); + return false; + } + else if ((err = errno) == EAGAIN || err == EWOULDBLOCK) + { + // No new data + + return false; + } + else + { + // Error + + Logger::instance->Log(Logger::WARNING, "[socket] error on read, closing (fd=", _fd, "): ", err, ", ", strerror(err)); + close(); + return false; + } + } + + bool Socket::send(Buffer& buffer) + { + if (buffer.consumed()) return false; + + do + { + PERF_LOG("send"); + const int bytesWritten = _impl.write(_fd, buffer.head(), buffer.remainingDataSize()); + + int err = 0; + if (bytesWritten > 0) + { + // Some data written to downstream + // log bytes written and move cursor forward + + //Logger::instance->Log(Logger::DEBUG, "[socket] write returns ", bytesWritten, " (fd=", _fd, ")"); + buffer.consume(bytesWritten); + } + else if((err = errno) == EAGAIN || err == EWOULDBLOCK) + { + // Write blocked + return false; + } + else + { + // Error + + Logger::instance->Log(Logger::WARNING, "[socket] error on send, closing (fd=", _fd, "): ", strerror(err)); + close(); + return false; + } + } while (!buffer.consumed()); + + return true; + } + + void Socket::checkConnected() + { + char c; + const int bytesWritten = _impl.write(_fd, &c, 0); + int err = errno; + if (bytesWritten == 0) + { + _connected = true; + Logger::instance->Log(Logger::WARNING, "[socket] connected (fd=", _fd, ")"); + } + else if (err != EAGAIN && err != EWOULDBLOCK) + { + Logger::instance->Log(Logger::WARNING, "[socket] connection error, closing (fd=", _fd, "): ", err, ", ", strerror(err)); + close(); + } + } + + void Socket::closeInput() + { + _inputClosed = true; + } + + void Socket::close() + { + if (!closed()) + { + _inputClosed = true; + _outputClosed = true; + + if (_poller) + { + // epoll is meant to automatically deregister sockets on close, but apparently some systems + // have bugs around this, so do it explicitly + Logger::instance->Log(Logger::DEBUG, "[socket] remove from poller (fd=", _fd, ")"); + _poller->remove(_fd); + } + + Logger::instance->Log(Logger::DEBUG, "[socket] close, 
fd=", _fd); + _impl.close(_fd); + if (_peer != nullptr) + { + _peer->onPeerClosed(); + } + } + } + + void Socket::onPeerClosed() + { + if (!closed()) + { + Logger::instance->Log(Logger::DEBUG, "[socket] onPeerClosed draining socket (fd=", _fd, ")"); + closeInput(); + + // force process the output queue + writeToOutput(); + + if (_peer->hasQueuedData()) + { + // Peer has some queued data they never received + // Assuming this data is critical for the protocol, it should be ok to abort the connection straight away + Logger::instance->Log(Logger::DEBUG, "[socket] onPeerClosed detected input peer is closed while having data remaining, closing (fd=", _fd, ")"); + close(); + } + } + } + + Socket::~Socket() + { + if (!closed()) + { + Logger::instance->Log(Logger::WARNING, "[socket] closing on destruction (fd=", _fd, ")"); + close(); + } + + if (_peer != nullptr) + { + _peer->setPeer(nullptr); + } + } } \ No newline at end of file diff --git a/vsock-bridge/src/vsock-bridge.cpp b/vsock-bridge/src/vsock-bridge.cpp index f5bdda2..ed7cd6d 100644 --- a/vsock-bridge/src/vsock-bridge.cpp +++ b/vsock-bridge/src/vsock-bridge.cpp @@ -5,20 +5,18 @@ using namespace vsockproxy; #define VSB_MAX_POLL_EVENTS 256 -void sigpipe_handler(int unused) +static void sigpipe_handler(int unused) { Logger::instance->Log(Logger::DEBUG, "SIGPIPE received"); } -std::vector serviceThreads; - -std::unique_ptr createEndpoint(endpoint_scheme scheme, std::string address, uint16_t port) +static std::unique_ptr createEndpoint(EndpointScheme scheme, const std::string& address, uint16_t port) { - if (scheme == endpoint_scheme::TCP4) + if (scheme == EndpointScheme::TCP4) { return std::move(std::make_unique(address, port)); } - else if (scheme == endpoint_scheme::VSOCK) + else if (scheme == EndpointScheme::VSOCK) { int cid = std::atoi(address.c_str()); return std::move(std::make_unique(cid, port)); @@ -29,7 +27,7 @@ std::unique_ptr createEndpoint(endpoint_scheme scheme, std::string add } } -Listener* create_listener(std::vector& dispatchers, endpoint_scheme inScheme, std::string inAddress, uint16_t inPort, endpoint_scheme outScheme, std::string outAddress, uint16_t outPort) +static std::unique_ptr createListener(Dispatcher& dispatcher, EndpointScheme inScheme, const std::string& inAddress, uint16_t inPort, EndpointScheme outScheme, const std::string& outAddress, uint16_t outPort) { auto listenEp { createEndpoint(inScheme, inAddress, inPort) }; auto connectEp{ createEndpoint(outScheme, outAddress, outPort) }; @@ -46,79 +44,66 @@ Listener* create_listener(std::vector& dispatchers, endpoint_scheme } else { - return new Listener(std::move(listenEp), std::move(connectEp), dispatchers); + return std::make_unique(std::move(listenEp), std::move(connectEp), dispatcher); } } -void start_services(std::vector& services, int numIOThreads, int numWorkers) +static void startServices(const std::vector& services, int numWorkers) { Logger::instance->Log(Logger::INFO, "Starting ", numWorkers, " worker threads..."); - for (int i = 0; i < numWorkers; i++) - { - auto t = std::make_unique( - /*init:*/ []() { - BufferManager::arena->init(512, 2000); - } - ); - ThreadPool::threads.push_back(std::move(t)); - } + EpollPollerFactory pollerFactory{VSB_MAX_POLL_EVENTS}; + IOThreadPool threadPool{(size_t)numWorkers, pollerFactory}; + Dispatcher dispatcher{threadPool}; + std::vector> listeners; + std::vector listenerThreads; - for (auto& sd : services) + for (const auto& sd : services) { - std::vector* dispatchers = new std::vector(); - for (int i = 0; i < 1; i++) - { 
- Dispatcher* d = new Dispatcher(i, new EpollPoller(VSB_MAX_POLL_EVENTS)); - dispatchers->push_back(d); - } - - Logger::instance->Log(Logger::INFO, "Starting service: ", sd.name); - Listener* listener = create_listener( - *dispatchers, - /*inScheme:*/ sd.listen_ep.scheme, - /*inAddress:*/ sd.listen_ep.address, - /*inPort:*/ sd.listen_ep.port, - /*outScheme:*/ sd.connect_ep.scheme, - /*outAddress:*/ sd.connect_ep.address, - /*outPort:*/ sd.connect_ep.port + Logger::instance->Log(Logger::INFO, "Starting service: ", sd._name); + auto listener = createListener( + dispatcher, + /*inScheme:*/ sd._listenEndpoint._scheme, + /*inAddress:*/ sd._listenEndpoint._address, + /*inPort:*/ sd._listenEndpoint._port, + /*outScheme:*/ sd._connectEndpoint._scheme, + /*outAddress:*/ sd._connectEndpoint._address, + /*outPort:*/ sd._connectEndpoint._port ); - if (listener == nullptr) + if (!listener) { - Logger::instance->Log(Logger::CRITICAL, "failed to start listener for ", sd.name); + Logger::instance->Log(Logger::CRITICAL, "failed to start listener for ", sd._name); exit(1); } - std::thread* listenerThread = new std::thread(&Listener::run, listener); - std::thread* dispatcherThread = new std::thread(&Dispatcher::run, listener->_dispatchers[0]); - serviceThreads.push_back(listenerThread); + listenerThreads.emplace_back(&Listener::run, listener.get()); + listeners.emplace_back(std::move(listener)); } - for (auto* t : serviceThreads) + for (auto& t : listenerThreads) { - if (t->joinable()) - t->join(); + if (t.joinable()) + t.join(); } } -void show_help() +static void showHelp() { std::cout - << "usage: vsockpx -c [-d] [--log-level [0-3]] [--num-threads n] [--iothreads n] [...]\n" + << "usage: vsockpx -c [-d] [--log-level n] [--workers n] [...]\n" << " -c/--config: path to configuration file\n" << " -d/--daemon: running in daemon mode\n" - << " --log-level: log level, 0=debug, 1=info, 2=warning, 3=error, 4=critical\n" - << " --iothreads: number of io threads, positive integer\n" - << " --workers: number of worker threads, positive integer\n" + << " --log-level: log level, 0=debug, 1=info, 2=warning, 3=error, 4=critical (default: info)\n" + << " --workers: number of IO worker threads, positive integer (default: 1)\n" << std::flush; } -void quit_bad_args(const char* reason, bool showhelp) +static void quitBadArgs(const char* reason, bool showhelp) { std::cout << reason << std::endl; if (showhelp) - show_help(); + showHelp(); exit(1); } @@ -128,14 +113,13 @@ int main(int argc, char* argv[]) sigaction(SIGPIPE, (struct sigaction*)&sig, NULL); bool daemonize = false; - std::string config_path; - int min_log_level = 1; - int num_worker_threads = 1; - int num_iothreads = 1; + std::string configPath; + int minLogLevel = 1; + int numWorkerThreads = 1; if (argc < 2) { - show_help(); + showHelp(); return 1; } @@ -143,7 +127,7 @@ int main(int argc, char* argv[]) { if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) { - show_help(); + showHelp(); exit(0); } @@ -156,23 +140,23 @@ int main(int argc, char* argv[]) { if (i + 1 == argc) { - quit_bad_args("no filepath followed by --config", false); + quitBadArgs("no filepath followed by --config", false); } - config_path = std::string(argv[++i]); + configPath = std::string(argv[++i]); } else if (strcmp(argv[i], "--workers") == 0) { if (i + 1 == argc) { - quit_bad_args("no number followed by --workers", false); + quitBadArgs("no number followed by --workers", false); } - num_worker_threads = std::stoi(std::string(argv[++i])); + numWorkerThreads = 
std::stoi(std::string(argv[++i])); - if (num_worker_threads == 0) + if (numWorkerThreads == 0) { - quit_bad_args("--workers should be at least 1", false); + quitBadArgs("--workers should be at least 1", false); } } @@ -180,46 +164,26 @@ int main(int argc, char* argv[]) { if (i + 1 == argc) { - quit_bad_args("no log level followed by --log-level", false); - } - try - { - min_log_level = std::stoi(std::string(argv[++i])); - } - catch (std::invalid_argument _) - { - quit_bad_args("invalid log level, must be 0, 1, 2, 3 or 4", false); - } - if (min_log_level < 0 || min_log_level > 4) - { - quit_bad_args("invalid log level, must be 0, 1, 2, 3 or 4", false); - } - } - - else if (strcmp(argv[i], "--iothreads") == 0) - { - if (i + 1 == argc) - { - quit_bad_args("no number followed by --iothreads", false); + quitBadArgs("no log level followed by --log-level", false); } try { - num_iothreads = std::stoi(std::string(argv[++i])); + minLogLevel = std::stoi(std::string(argv[++i])); } catch (std::invalid_argument _) { - quit_bad_args("invalid io thread count, must be number > 0", false); + quitBadArgs("invalid log level, must be 0, 1, 2, 3 or 4", false); } - if (num_iothreads <= 0) + if (minLogLevel < 0 || minLogLevel > 4) { - quit_bad_args("invalid io thread count, must be number > 0", false); + quitBadArgs("invalid log level, must be 0, 1, 2, 3 or 4", false); } } } - if (config_path.empty()) + if (configPath.empty()) { - quit_bad_args("no configuration file, use -c/--config or --help for more info.", false); + quitBadArgs("no configuration file, use -c/--config or --help for more info.", false); } if (daemonize) @@ -232,7 +196,7 @@ int main(int argc, char* argv[]) umask(0); - Logger::instance->setMinLevel(min_log_level); + Logger::instance->setMinLevel(minLogLevel); Logger::instance->setStreamProvider(new RSyslogLogger("vsockpx")); sid = setsid(); @@ -244,19 +208,19 @@ int main(int argc, char* argv[]) } else { - Logger::instance->setMinLevel(min_log_level); + Logger::instance->setMinLevel(minLogLevel); Logger::instance->setStreamProvider(new StdoutLogger()); } - std::vector services = load_config(config_path); + const std::vector services = loadConfig(configPath); if (services.empty()) { Logger::instance->Log(Logger::CRITICAL, "No services are configured, quitting."); - exit(0); + exit(1); } - start_services(services, num_iothreads, num_worker_threads); + startServices(services, numWorkerThreads); return 0; } \ No newline at end of file diff --git a/vsock-bridge/test/CMakeLists.txt b/vsock-bridge/test/CMakeLists.txt index a79dc50..3ac2652 100644 --- a/vsock-bridge/test/CMakeLists.txt +++ b/vsock-bridge/test/CMakeLists.txt @@ -4,7 +4,12 @@ include_directories (tests ${CMAKE_CURRENT_SOURCE_DIR}/../include ) -add_executable (tests testmain.cpp) +add_executable (tests + testmain.cpp + test_buffer.cpp + test_channel.cpp + test_threading.cpp +) target_link_libraries (tests vsock-io pthread) diff --git a/vsock-bridge/test/mocks.h b/vsock-bridge/test/mocks.h deleted file mode 100644 index b45725e..0000000 --- a/vsock-bridge/test/mocks.h +++ /dev/null @@ -1,103 +0,0 @@ -#pragma once - -#include - -using namespace vsockio; - -struct MockPoller : public Poller -{ - struct Fd { - int fd; - void* handler; - uint32_t listeningEvents; - uint32_t events; - bool triggerNextTime; - - Fd() : events(IOEvent::None) {} - }; - - bool _inputReady; - bool _outputReady; - std::unordered_map _fdMap; - bool _triggerNextTime; - - MockPoller(int maxEvents) - { - _maxEvents = maxEvents; - } - - bool add(int fd, void* handler, 
uint32_t events) override - { - Logger::instance->Log(Logger::INFO, "add: ", fd, ",", (uint64_t)handler, ",", events); - _fdMap[fd].fd = fd; - _fdMap[fd].handler = handler; - _fdMap[fd].listeningEvents = events; - return true; - } - - bool update(int fd, void* handler, uint32_t events) override - { - Logger::instance->Log(Logger::INFO, "update: ", fd, ",", (uint64_t)handler, ",", events); - _fdMap[fd].fd = fd; - _fdMap[fd].handler = handler; - _fdMap[fd].listeningEvents = events; - return true; - } - - void remove(int fd) override - { - Logger::instance->Log(Logger::INFO, "remove: ", fd); - _fdMap.erase(fd); - } - - int poll(VsbEvent* outEvents, int timeout) override - { - int numEvents = 0; - for (auto& fd : _fdMap) - { - if (fd.second.triggerNextTime) - { - fd.second.triggerNextTime = false; - outEvents[numEvents].fd = fd.second.fd; - outEvents[numEvents].data = fd.second.handler; - outEvents[numEvents].ioFlags = (IOEvent)(fd.second.events & fd.second.listeningEvents); - numEvents++; - } - } - return numEvents; - } - - void setInputReady(int fd, bool ready) - { - if (ready) - { - auto oldEvents = _fdMap[fd].events; - _fdMap[fd].events |= IOEvent::InputReady; - if (oldEvents != _fdMap[fd].events) - { - _fdMap[fd].triggerNextTime = true; - } - } - else - { - _fdMap[fd].events &= ~IOEvent::InputReady; - } - } - - void setOutputReady(int fd, bool ready) - { - if (ready) - { - auto oldEvents = _fdMap[fd].events; - _fdMap[fd].events |= IOEvent::OutputReady; - if (oldEvents != _fdMap[fd].events) - { - _fdMap[fd].triggerNextTime = true; - } - } - else - { - _fdMap[fd].events &= ~IOEvent::OutputReady; - } - } -}; diff --git a/vsock-bridge/test/test_buffer.cpp b/vsock-bridge/test/test_buffer.cpp new file mode 100644 index 0000000..2bcf7b4 --- /dev/null +++ b/vsock-bridge/test/test_buffer.cpp @@ -0,0 +1,100 @@ +#include + +#include "catch.hpp" + +using namespace vsockio; + +SCENARIO("Buffer") +{ + Buffer buffer; + + GIVEN("Newly created buffer") + { + THEN("Buffer has basic initial state") + { + REQUIRE(buffer.hasRemainingCapacity()); + REQUIRE(buffer.remainingCapacity() == Buffer::BUFFER_SIZE); + REQUIRE(buffer.remainingDataSize() == 0); + REQUIRE(buffer.consumed()); + } + } + + GIVEN("Some data produced into the buffer") + { + buffer.produce(5); + + THEN("Buffer tail shifts, but head stays in place") + { + REQUIRE(buffer.head() == buffer._data.data()); + REQUIRE(buffer.tail() == buffer._data.data() + 5); + REQUIRE(buffer.hasRemainingCapacity()); + REQUIRE(buffer.remainingCapacity() == Buffer::BUFFER_SIZE - 5); + REQUIRE(buffer.remainingDataSize() == 5); + REQUIRE(!buffer.consumed()); + } + } + + GIVEN("Some data produced into the buffer and then partially consumed") + { + buffer.produce(5); + buffer.consume(3); + + THEN("Buffer head and tail shift accordingly") + { + REQUIRE(buffer.head() == buffer._data.data() + 3); + REQUIRE(buffer.tail() == buffer._data.data() + 5); + REQUIRE(buffer.hasRemainingCapacity()); + REQUIRE(buffer.remainingCapacity() == Buffer::BUFFER_SIZE - 5); + REQUIRE(buffer.remainingDataSize() == 2); + REQUIRE(!buffer.consumed()); + } + } + + GIVEN("Some data produced into the buffer and then fully consumed") + { + buffer.produce(5); + buffer.consume(5); + + THEN("Buffer head and tail shift accordingly") + { + REQUIRE(buffer.head() == buffer._data.data() + 5); + REQUIRE(buffer.tail() == buffer._data.data() + 5); + REQUIRE(buffer.hasRemainingCapacity()); + REQUIRE(buffer.remainingCapacity() == Buffer::BUFFER_SIZE - 5); + REQUIRE(buffer.remainingDataSize() == 0); + 
REQUIRE(buffer.consumed());
+        }
+    }
+
+    GIVEN("Buffer is completely filled with data")
+    {
+        buffer.produce(Buffer::BUFFER_SIZE);
+
+        THEN("Buffer does not have remaining capacity")
+        {
+            REQUIRE(buffer.head() == buffer._data.data());
+            REQUIRE(buffer.tail() == buffer._data.data() + Buffer::BUFFER_SIZE);
+            REQUIRE(!buffer.hasRemainingCapacity());
+            REQUIRE(buffer.remainingCapacity() == 0);
+            REQUIRE(buffer.remainingDataSize() == Buffer::BUFFER_SIZE);
+            REQUIRE(!buffer.consumed());
+        }
+    }
+
+    GIVEN("Buffer in non-default state")
+    {
+        buffer.produce(5);
+        buffer.consume(3);
+
+        THEN("Reset restores the default state")
+        {
+            buffer.reset();
+            REQUIRE(buffer.head() == buffer._data.data());
+            REQUIRE(buffer.tail() == buffer._data.data());
+            REQUIRE(buffer.hasRemainingCapacity());
+            REQUIRE(buffer.remainingCapacity() == Buffer::BUFFER_SIZE);
+            REQUIRE(buffer.remainingDataSize() == 0);
+            REQUIRE(buffer.consumed());
+        }
+    }
+}
diff --git a/vsock-bridge/test/test_channel.cpp b/vsock-bridge/test/test_channel.cpp
new file mode 100644
index 0000000..b5f8f46
--- /dev/null
+++ b/vsock-bridge/test/test_channel.cpp
@@ -0,0 +1,420 @@
+#include
+
+#include "catch.hpp"
+
+using namespace vsockio;
+
+static int mockIoAgain(int, void*, int)
+{
+    errno = EAGAIN;
+    return -1;
+}
+
+static std::function mockIoSuccessOnce(int returnValue)
+{
+    bool called = false;
+    return [=] (int fd, void* data, int sz) mutable
+    {
+        if (!called)
+        {
+            called = true;
+            return returnValue;
+        }
+        else
+        {
+            return mockIoAgain(fd, data, sz);
+        }
+    };
+}
+
+static std::function mockIoMustNotCall(const std::string& message)
+{
+    return [=] (int, void*, int) { FAIL(message); return -1; };
+}
+
+static std::function mockIoError(int err)
+{
+    return [=] (int, void*, int) { errno = err; return -1; };
+}
+
+static int mockCloseSuccess(int)
+{
+    return 0;
+}
+
+SCENARIO("DirectChannel between sockets establishing connections")
+{
+    SocketImpl saImpl(mockIoMustNotCall("read on sa"), mockIoMustNotCall("write on sa"), mockCloseSuccess);
+    SocketImpl sbImpl(mockIoMustNotCall("read on sb"), mockIoMustNotCall("write on sb"), mockCloseSuccess);
+    DirectChannel channel(1, std::make_unique(41, saImpl), std::make_unique(42, sbImpl));
+    auto &sa = *channel._a;
+    auto &sb = *channel._b;
+
+    GIVEN("Sockets are not connected")
+    {
+        THEN("No IO is possible")
+        {
+            channel.performIO();
+            REQUIRE(!channel.canReadWriteMore());
+            REQUIRE(!channel.canBeTerminated());
+        }
+    }
+
+    GIVEN("Both sockets connected, but no data is available")
+    {
+        sa.onConnected();
+        sb.onConnected();
+
+        saImpl.read = sbImpl.read = mockIoAgain;
+
+        THEN("No remaining IO after reads")
+        {
+            channel.performIO();
+            REQUIRE(!channel.canReadWriteMore());
+            REQUIRE(!channel.canBeTerminated());
+        }
+    }
+}
+
+SCENARIO("DirectChannel between sockets with one connected socket")
+{
+    SocketImpl saImpl(mockIoMustNotCall("read on sa"), mockIoMustNotCall("write on sa"), mockCloseSuccess);
+    SocketImpl sbImpl(mockIoMustNotCall("read on sb"), mockIoMustNotCall("write on sb"), mockCloseSuccess);
+    DirectChannel channel(1, std::make_unique(41, saImpl), std::make_unique(42, sbImpl));
+    auto &sa = *channel._a;
+    auto &sb = *channel._b;
+    sa.onConnected();
+
+    GIVEN("Some data available on the connected socket")
+    {
+        saImpl.read = mockIoSuccessOnce(5);
+
+        THEN("Read all and queue for writing")
+        {
+            channel.performIO();
+            REQUIRE(channel.canReadWriteMore());
+            REQUIRE(!channel.canBeTerminated());
+
+            AND_THEN("Read socket has pending IO, but writer does not")
+            {
+                REQUIRE(sa.canReadWriteMore());
+ REQUIRE(!sb.canReadWriteMore()); + } + } + } +} + +SCENARIO("DirectChannel between connected sockets") +{ + SocketImpl saImpl(mockIoAgain, mockIoAgain, mockCloseSuccess); + SocketImpl sbImpl(mockIoAgain, mockIoAgain, mockCloseSuccess); + DirectChannel channel(1, std::make_unique(41, saImpl), std::make_unique(42, sbImpl)); + auto &sa = *channel._a; + auto &sb = *channel._b; + sa.onConnected(); + sb.onConnected(); + + GIVEN("Some data available on one of the sockets but write is blocked on the other") + { + saImpl.read = mockIoSuccessOnce(5); + sbImpl.write = [&] (int fd, void* d, int sz) { REQUIRE(sz == 5); return mockIoAgain(fd, d, sz); }; + + THEN("Read all and queue for writing") + { + channel.performIO(); + REQUIRE(channel.canReadWriteMore()); + REQUIRE(!channel.canBeTerminated()); + + AND_THEN("Read socket has pending IO, but writer does not") + { + REQUIRE(sa.canReadWriteMore()); + REQUIRE(!sb.canReadWriteMore()); + } + } + } + + GIVEN("Some data available on one of the sockets and some write succeeds on the other") + { + saImpl.read = mockIoSuccessOnce(5); + sbImpl.write = mockIoSuccessOnce(3); + + THEN("Read all and queue for writing") + { + channel.performIO(); + REQUIRE(channel.canReadWriteMore()); + REQUIRE(!channel.canBeTerminated()); + + AND_THEN("Both sockets have pending IO") + { + REQUIRE(sa.canReadWriteMore()); + REQUIRE(!sb.canReadWriteMore()); + } + + AND_THEN("Remaining data is written") + { + saImpl.read = mockIoAgain; + sbImpl.write = mockIoSuccessOnce(2); + + channel.performIO(); + REQUIRE(channel.canReadWriteMore()); + REQUIRE(!channel.canBeTerminated()); + + REQUIRE(!sa.canReadWriteMore()); + REQUIRE(sb.canReadWriteMore()); + } + } + } + + GIVEN("Write queue is full on one of the sockets") + { + saImpl.read = mockIoSuccessOnce(Buffer::BUFFER_SIZE); + channel.performIO(); + + THEN("No reads are performed afterwards") + { + saImpl.read = mockIoMustNotCall("sa read"); + channel.performIO(); + REQUIRE(!channel.canReadWriteMore()); + + AND_THEN("No reads are performed even after some of the data has been written out") + { + sbImpl.write = mockIoSuccessOnce(2); + channel.performIO(); + REQUIRE(!channel.canReadWriteMore()); + } + } + + THEN("Can resume reads after the buffer has been fully consumed") + { + saImpl.read = mockIoMustNotCall("sa read"); + sbImpl.write = mockIoSuccessOnce(Buffer::BUFFER_SIZE); + channel.performIO(); + REQUIRE(channel.canReadWriteMore()); + REQUIRE(!sa.canReadWriteMore()); + REQUIRE(sb.canReadWriteMore()); + + saImpl.read = mockIoSuccessOnce(3); + channel.performIO(); + REQUIRE(channel.canReadWriteMore()); + REQUIRE(sa.canReadWriteMore()); + REQUIRE(!sb.canReadWriteMore()); + } + } + + GIVEN("Socket writes out full buffer of data") + { + saImpl.read = mockIoSuccessOnce(Buffer::BUFFER_SIZE); + channel.performIO(); + sbImpl.write = mockIoSuccessOnce(Buffer::BUFFER_SIZE); + channel.performIO(); + + THEN("Socket can write more data") + { + saImpl.read = mockIoSuccessOnce(5); + int bytesWritten = 0; + sbImpl.write = [&] (int fd, void* d, int sz) { bytesWritten = sz; return mockIoAgain(fd, d, sz); }; + channel.performIO(); + + REQUIRE(bytesWritten == 5); + } + } +} + +SCENARIO("DirectChannel - async connection") +{ + SocketImpl saImpl(mockIoAgain, mockIoAgain, mockCloseSuccess); + SocketImpl sbImpl(mockIoAgain, mockIoAgain, mockCloseSuccess); + DirectChannel channel(1, std::make_unique(41, saImpl), std::make_unique(42, sbImpl)); + auto &sa = *channel._a; + auto &sb = *channel._b; + sa.onConnected(); + + GIVEN("Second socket reports a successful 
connection") + { + sbImpl.write = mockIoSuccessOnce(0); + sb.checkConnected(); + + THEN("Second socket is connected") + { + REQUIRE(sb.connected()); + } + } +} + +SCENARIO("DirectChannel - orderly disconnects") +{ + SocketImpl saImpl(mockIoAgain, mockIoAgain, mockCloseSuccess); + SocketImpl sbImpl(mockIoAgain, mockIoAgain, mockCloseSuccess); + DirectChannel channel(1, std::make_unique(41, saImpl), std::make_unique(42, sbImpl)); + auto &sa = *channel._a; + auto &sb = *channel._b; + sa.onConnected(); + sb.onConnected(); + + GIVEN("A socket reports it is closed while no outstanding data") + { + sbImpl.read = mockIoSuccessOnce(0); + channel.performIO(); + + THEN("Second socket is disconnected") + { + REQUIRE(sa.closed()); + REQUIRE(sb.closed()); + REQUIRE(channel.canBeTerminated()); + } + } + + GIVEN("A socket reports it is closed while second socket is writing data out") + { + saImpl.read = mockIoSuccessOnce(10); + channel.performIO(); + saImpl.read = mockIoSuccessOnce(0); + channel.performIO(); + + THEN("First socket is closed and second socket is not") + { + REQUIRE(sa.closed()); + REQUIRE(!sb.closed()); + + AND_THEN("Second socket writes out some data and remains open") + { + sbImpl.write = mockIoSuccessOnce(6); + channel.performIO(); + + REQUIRE(sa.closed()); + REQUIRE(!sb.closed()); + + AND_THEN("Second socket writes out remaining data and both sockets are closed") + { + sbImpl.write = mockIoSuccessOnce(4); + channel.performIO(); + + REQUIRE(sa.closed()); + REQUIRE(sb.closed()); + } + } + } + + THEN("Second socket input is closed") + { + sbImpl.read = mockIoMustNotCall("sb read"); + channel.performIO(); + + REQUIRE(sa.closed()); + REQUIRE(!sb.closed()); + } + } + + GIVEN("A socket reports it is closed while it is writing data out") + { + saImpl.read = mockIoSuccessOnce(10); + channel.performIO(); + sbImpl.read = mockIoSuccessOnce(0); + channel.performIO(); + + THEN("Both sockets are closed") + { + REQUIRE(sa.closed()); + REQUIRE(sb.closed()); + } + } +} + +SCENARIO("DirectChannel - error conditions") +{ + SocketImpl saImpl(mockIoAgain, mockIoAgain, mockCloseSuccess); + SocketImpl sbImpl(mockIoAgain, mockIoAgain, mockCloseSuccess); + DirectChannel channel(1, std::make_unique(41, saImpl), std::make_unique(42, sbImpl)); + auto &sa = *channel._a; + auto &sb = *channel._b; + sa.onConnected(); + + GIVEN("Second socket reports a connection error") + { + sbImpl.write = mockIoError(ECONNREFUSED); + sb.checkConnected(); + + THEN("Both sockets are closed") + { + REQUIRE(channel.canBeTerminated()); + REQUIRE(sa.closed()); + REQUIRE(sb.closed()); + } + } + + sb.onConnected(); + + GIVEN("Socket read fails") + { + saImpl.read = mockIoError(ECONNABORTED); + channel.performIO(); + + THEN("Both sockets are closed") + { + REQUIRE(channel.canBeTerminated()); + REQUIRE(sa.closed()); + REQUIRE(sb.closed()); + } + } + + GIVEN("Second socket has data queued for writing") + { + saImpl.read = mockIoSuccessOnce(10); + channel.performIO(); + + AND_GIVEN("Reading the first second fails") + { + saImpl.read = mockIoError(ECONNABORTED); + channel.performIO(); + + THEN("Second socket enters draining mode") + { + REQUIRE(!channel.canBeTerminated()); + REQUIRE(sa.closed()); + REQUIRE(!sb.closed()); + + AND_THEN("Writes out all queued data and is closed") + { + sbImpl.write = mockIoSuccessOnce(10); + channel.performIO(); + + REQUIRE(channel.canBeTerminated()); + REQUIRE(sa.closed()); + REQUIRE(sb.closed()); + } + } + } + } + + GIVEN("Socket write fails") + { + saImpl.read = mockIoSuccessOnce(10); + channel.performIO(); + 
sbImpl.write = mockIoError(ECONNABORTED); + channel.performIO(); + + THEN("Both sockets are closed") + { + REQUIRE(channel.canBeTerminated()); + REQUIRE(sa.closed()); + REQUIRE(sb.closed()); + } + } + + GIVEN("Socket write fails while draining") + { + saImpl.read = mockIoSuccessOnce(10); + channel.performIO(); + saImpl.read = mockIoSuccessOnce(0); + channel.performIO(); + sbImpl.write = mockIoError(ECONNABORTED); + channel.performIO(); + + THEN("Both sockets are closed") + { + REQUIRE(channel.canBeTerminated()); + REQUIRE(sa.closed()); + REQUIRE(sb.closed()); + } + } +} diff --git a/vsock-bridge/test/test_threading.cpp b/vsock-bridge/test/test_threading.cpp new file mode 100644 index 0000000..f55ebbd --- /dev/null +++ b/vsock-bridge/test/test_threading.cpp @@ -0,0 +1,59 @@ +#include + +#include "catch.hpp" + +using namespace vsockio; + +SCENARIO("ThreadSafeQueue") +{ + GIVEN("A newly created queue") + { + ThreadSafeQueue q; + THEN("Dequeue returns empty") + { + REQUIRE(!q.dequeue()); + } + } + + GIVEN("A queue with starting items") + { + ThreadSafeQueue q; + q.enqueue(1); + q.enqueue(2); + + THEN("Can dequeue the first element") + { + REQUIRE(q.dequeue() == 1); + + AND_THEN("Can dequeue the second element") + { + REQUIRE(q.dequeue() == 2); + + AND_THEN("Cannot dequeue more") + { + REQUIRE(!q.dequeue()); + } + } + } + } + + GIVEN("A queue with starting items dequeued") + { + ThreadSafeQueue q; + q.enqueue(1); + q.enqueue(2); + q.dequeue(); + q.dequeue(); + + THEN("Can enqueue and dequeue another item") + { + q.enqueue(3); + REQUIRE(q.dequeue() == 3); + + AND_THEN("Cannot dequeue more") + { + REQUIRE(!q.dequeue()); + } + } + } +} diff --git a/vsock-bridge/test/testmain.cpp b/vsock-bridge/test/testmain.cpp index da8a9e7..79385db 100644 --- a/vsock-bridge/test/testmain.cpp +++ b/vsock-bridge/test/testmain.cpp @@ -1,28 +1,29 @@ -#define CATCH_CONFIG_MAIN +#include + +#define CATCH_CONFIG_RUNNER #include "catch.hpp" -#include -#include -#include -#include -#include -#include -#include +int main( int argc, char* argv[] ) { + + Logger::instance->setMinLevel(Logger::DEBUG); + Logger::instance->setStreamProvider(new StdoutLogger()); -#include -#include -#include -#include -#include "mocks.h" + int result = Catch::Session().run( argc, argv ); -using namespace vsockio; + // your clean-up... + return result; +} +#if 0 std::vector> ThreadPool::threads; thread_local MemoryArena* BufferManager::arena = new MemoryArena(); -TEST_CASE("Queue works", "[queue]") +TEST_CASE("ThreadSafeQueue", "[queue]") { + ThreadSafeQueue q; + REQUIRE_THAT(q.dequeue()) + UniquePtrQueue q; std::unique_ptr pNumbers[]{ @@ -532,4 +533,5 @@ TEST_CASE("Dispatcher", "[dispatcher]") terminateWorkerThreads(); REQUIRE(dest == source); -} \ No newline at end of file +} +#endif
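
Note on the mock-IO pattern used by the new channel tests: each SocketImpl hook is injected as a plain callable that fakes one non-blocking read/write result per call, which is what lets the scenarios step `channel.performIO()` one event at a time and assert on `canReadWriteMore()`/`canBeTerminated()` without real sockets. The standalone sketch below (plain C++17, no project headers; the `MockIo` alias and the `main` driver are illustrative additions, not part of the patch) mirrors how `mockIoSuccessOnce` and `mockIoAgain` behave: one successful transfer, then EAGAIN like a drained non-blocking socket.

```cpp
// Standalone illustration of the stateful mock-IO pattern from test_channel.cpp:
// the first call reports success, every later call reports EAGAIN, the way a
// non-blocking socket with no more data would.
#include <cerrno>
#include <functional>
#include <iostream>

using MockIo = std::function<int(int fd, void* buf, int len)>;

static int mockIoAgain(int, void*, int)
{
    errno = EAGAIN;   // nothing to read/write right now
    return -1;
}

static MockIo mockIoSuccessOnce(int returnValue)
{
    bool called = false;             // captured by value, mutated per call
    return [=](int fd, void* buf, int len) mutable {
        if (!called)
        {
            called = true;
            return returnValue;      // simulate one successful transfer
        }
        return mockIoAgain(fd, buf, len);
    };
}

int main()
{
    auto read = mockIoSuccessOnce(5);
    char buf[16];
    std::cout << read(3, buf, sizeof buf) << '\n';  // prints 5
    std::cout << read(3, buf, sizeof buf) << '\n';  // prints -1, errno == EAGAIN
}
```

Returning -1 with errno set, rather than throwing, keeps the mocks shaped like the read(2)/write(2) calls they stand in for.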