diff --git a/src/zk/client_tests.cpp b/src/zk/client_tests.cpp new file mode 100644 index 0000000..6d75804 --- /dev/null +++ b/src/zk/client_tests.cpp @@ -0,0 +1,40 @@ +#include + +#include "client.hpp" + +namespace zk +{ + +class client_tests : + public server::single_server_fixture +{ }; + +GTEST_TEST_F(client_tests, watch_close) +{ + client c = get_connected_client(); + auto watch = c.watch("/").get(); + + c.close(); + + // watch should be triggered with session closed + auto ev = watch.next().get(); + CHECK_EQ(ev.type(), event_type::session); + CHECK_EQ(ev.state(), state::closed); +} + +class stopping_client_tests : + public server::server_fixture +{ }; + +GTEST_TEST_F(stopping_client_tests, watch_server_stop) +{ + client c = get_connected_client(); + auto watch = c.watch("/").get(); + + this->stop_server(true); + + auto ev = watch.next().get(); + CHECK_EQ(ev.type(), event_type::session); +} + +} diff --git a/src/zk/connection_zk.cpp b/src/zk/connection_zk.cpp index e361623..380ac08 100644 --- a/src/zk/connection_zk.cpp +++ b/src/zk/connection_zk.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include @@ -162,6 +163,104 @@ connection_zk::~connection_zk() noexcept close(); } +class connection_zk::watcher +{ +public: + watcher() : + _event_delivered(false) + { } + + virtual ~watcher() noexcept {} + + virtual void deliver_event(event ev) + { + if (!_event_delivered.exchange(true, std::memory_order_relaxed)) + { + _event_promise.set_value(std::move(ev)); + } + } + + future get_event_future() + { + return _event_promise.get_future(); + } + +protected: + std::atomic _event_delivered; + promise _event_promise; +}; + +template +class connection_zk::basic_watcher : + public connection_zk::watcher +{ +public: + basic_watcher() : + _data_delivered(false) + { } + + future get_data_future() + { + return _data_promise.get_future(); + } + + virtual void deliver_event(event ev) override + { + if (!_data_delivered.load(std::memory_order_relaxed)) + { + deliver_data(nullopt, get_exception_ptr_of(error_code::closing)); + } + + watcher::deliver_event(std::move(ev)); + } + + void deliver_data(optional data, std::exception_ptr ex_ptr) + { + if (!_data_delivered.exchange(true, std::memory_order_relaxed)) + { + if (ex_ptr) + { + _data_promise.set_exception(std::move(ex_ptr)); + } + else + { + _data_promise.set_value(std::move(*data)); + } + } + } + +private: + std::atomic _data_delivered; + promise _data_promise; +}; + +std::shared_ptr connection_zk::try_extract_watch(ptr addr) +{ + std::unique_lock ax(_watches_protect); + auto iter = _watches.find(addr); + if (iter != _watches.end()) + return _watches.extract(iter).mapped(); + else + return nullptr; +} + +static ptr connection_from_context(ptr zh) +{ + return (ptr) zoo_get_context(zh); +} + +void connection_zk::deliver_watch(ptr zh, + int type_in, + int state_in, + ptr path [[gnu::unused]], + ptr proms_in + ) +{ + auto& self = *connection_from_context(zh); + if (auto watcher = self.try_extract_watch(proms_in)) + watcher->deliver_event(event(event_from_raw(type_in), state_from_raw(state_in))); +} + void connection_zk::close() { if (_handle) @@ -171,6 +270,13 @@ void connection_zk::close() throw_error(err); _handle = nullptr; + + // Deliver a session event as if there was a close. + std::unique_lock ax(_watches_protect); + auto l_watches = std::move(_watches); + ax.unlock(); + for (const auto& pair : l_watches) + pair.second->deliver_event(event(event_type::session, zk::state::closed)); } } @@ -213,59 +319,55 @@ future connection_zk::get(string_view path) }); } -future connection_zk::watch(string_view path) +class connection_zk::data_watcher : + public connection_zk::basic_watcher { - using watch_promises = std::pair, std::promise>; +public: + static void deliver_raw(int rc_in, + ptr data, + int data_sz, + ptr pstat, + ptr self_in + ) noexcept + { + auto& self = *static_cast>(const_cast>(self_in)); + auto rc = error_code_from_raw(rc_in); - ::data_completion_t data_callback = - [] (int rc_in, ptr data, int data_sz, ptr pstat, ptr prom_in) noexcept + if (rc == error_code::ok) { - std::unique_ptr prom((ptr) prom_in); - auto rc = error_code_from_raw(rc_in); - if (rc == error_code::ok) - { - prom->first.set_value(watch_result(get_result(buffer(data, data + data_sz), stat_from_raw(*pstat)), - prom->second.get_future() - ) - ); - // Since there was no error, we know the watch will be triggered - prom.release(); - } - else - { - prom->first.set_exception(get_exception_ptr_of(rc)); - } - }; - - ::watcher_fn watch_callback = - [] (ptr, int type_in, int state_in, ptr, ptr proms_in) + self.deliver_data(watch_result(get_result(buffer(data, data + data_sz), stat_from_raw(*pstat)), + self.get_event_future() + ), + std::exception_ptr() + ); + } + else { - std::unique_ptr prom(static_cast>(proms_in)); - prom->second.set_value(event(event_from_raw(type_in), state_from_raw(state_in))); - }; + self.deliver_data(nullopt, get_exception_ptr_of(rc)); + } + } +}; +future connection_zk::watch(string_view path) +{ return with_str(path, [&] (ptr path) { - auto ppromises = std::make_unique(); - auto rc = error_code_from_raw(::zoo_awget(_handle, - path, - watch_callback, - ppromises.get(), - data_callback, - ppromises.get() - ) - ); + std::unique_lock ax(_watches_protect); + auto watcher = std::make_shared(); + auto rc = error_code_from_raw(::zoo_awget(_handle, + path, + deliver_watch, + watcher.get(), + data_watcher::deliver_raw, + watcher.get() + ) + ); if (rc == error_code::ok) - { - auto f = ppromises->first.get_future(); - ppromises.release(); - return f; - } + _watches.emplace(watcher.get(), watcher); else - { - ppromises->first.set_exception(get_exception_ptr_of(rc)); - return ppromises->first.get_future(); - } + watcher->deliver_data(nullopt, get_exception_ptr_of(rc)); + + return watcher->get_data_future(); }); } @@ -317,67 +419,60 @@ future connection_zk::get_children(string_view path) }); } -future connection_zk::watch_children(string_view path) +class connection_zk::child_watcher : + public connection_zk::basic_watcher { - using watch_promises = std::pair, std::promise>; +public: + static void deliver_raw(int rc_in, + ptr strings_in, + ptr stat_in, + ptr prom_in + ) noexcept + { + auto& self = *static_cast>(const_cast>(prom_in)); + auto rc = error_code_from_raw(rc_in); - ::strings_stat_completion_t data_callback = - [] (int rc_in, - ptr strings_in, - ptr stat_in, - ptr prom_in - ) + try { - std::unique_ptr prom((ptr) prom_in); - auto rc = error_code_from_raw(rc_in); - try - { - if (rc != error_code::ok) - throw_error(rc); - - prom->first.set_value(watch_children_result(get_children_result(string_vector_from_raw(*strings_in), - stat_from_raw(*stat_in) - ), - prom->second.get_future() - ) - ); - prom.release(); - } - catch (...) - { - prom->first.set_exception(std::current_exception()); - } - }; + if (rc != error_code::ok) + throw_error(rc); - ::watcher_fn watch_callback = - [] (ptr, int type_in, int state_in, ptr, ptr proms_in) + self.deliver_data(watch_children_result(get_children_result(string_vector_from_raw(*strings_in), + stat_from_raw(*stat_in) + ), + self.get_event_future() + ), + std::exception_ptr() + ); + } + catch (...) { - std::unique_ptr proms(static_cast>(proms_in)); - proms->second.set_value(event(event_from_raw(type_in), state_from_raw(state_in))); - }; + self.deliver_data(nullopt, std::current_exception()); + } + } +}; + +future connection_zk::watch_children(string_view path) +{ return with_str(path, [&] (ptr path) { - auto ppromises = std::make_unique(); - auto rc = error_code_from_raw(::zoo_awget_children2(_handle, - path, - watch_callback, - ppromises.get(), - data_callback, - ppromises.get() - ) - ); + std::unique_lock ax(_watches_protect); + auto watcher = std::make_shared(); + auto rc = error_code_from_raw(::zoo_awget_children2(_handle, + path, + deliver_watch, + watcher.get(), + child_watcher::deliver_raw, + watcher.get() + ) + ); if (rc == error_code::ok) - { - auto f = ppromises->first.get_future(); - ppromises.release(); - return f; - } + _watches.emplace(watcher.get(), watcher); else - { - ppromises->first.set_exception(get_exception_ptr_of(rc)); - return ppromises->first.get_future(); - } + watcher->deliver_data(nullopt, get_exception_ptr_of(rc)); + + return watcher->get_data_future(); }); } @@ -414,66 +509,54 @@ future connection_zk::exists(string_view path) }); } -future connection_zk::watch_exists(string_view path) +class connection_zk::exists_watcher : + public connection_zk::basic_watcher { - using watch_promises = std::pair, std::promise>; +public: + static void deliver_raw(int rc_in, ptr stat_in, ptr self_in) + { + auto& self = *static_cast>(const_cast>(self_in)); + auto rc = error_code_from_raw(rc_in); - ::stat_completion_t data_callback = - [] (int rc_in, ptr stat_in, ptr proms_in) + if (rc == error_code::ok) { - std::unique_ptr proms((ptr) proms_in); - auto rc = error_code_from_raw(rc_in); - if (rc == error_code::ok) - { - proms->first.set_value(watch_exists_result(exists_result(stat_from_raw(*stat_in)), - proms->second.get_future() - ) - ); - proms.release(); - } - else if (rc == error_code::no_node) - { - proms->first.set_value(watch_exists_result(exists_result(nullopt), - proms->second.get_future() - ) - ); - proms.release(); - } - else - { - proms->first.set_exception(get_exception_ptr_of(rc)); - } - }; - - ::watcher_fn watch_callback = - [] (ptr, int type_in, int state_in, ptr, ptr proms_in) + self.deliver_data(watch_exists_result(exists_result(stat_from_raw(*stat_in)), self.get_event_future()), + std::exception_ptr() + ); + } + else if (rc == error_code::no_node) { - std::unique_ptr proms(static_cast>(proms_in)); - proms->second.set_value(event(event_from_raw(type_in), state_from_raw(state_in))); - }; + self.deliver_data(watch_exists_result(exists_result(nullopt), self.get_event_future()), + std::exception_ptr() + ); + } + else + { + self.deliver_data(nullopt, get_exception_ptr_of(rc)); + } + } +}; +future connection_zk::watch_exists(string_view path) +{ return with_str(path, [&] (ptr path) { - auto ppromises = std::make_unique(); - auto rc = error_code_from_raw(::zoo_awexists(_handle, - path, - watch_callback, - ppromises.get(), - data_callback, - ppromises.get() - ) - ); + std::unique_lock ax(_watches_protect); + auto watcher = std::make_shared(); + auto rc = error_code_from_raw(::zoo_awexists(_handle, + path, + deliver_watch, + watcher.get(), + exists_watcher::deliver_raw, + watcher.get() + ) + ); if (rc == error_code::ok) - { - auto f = ppromises->first.get_future(); - ppromises.release(); - return f; - } + _watches.emplace(watcher.get(), watcher); else - { - ppromises->first.set_exception(get_exception_ptr_of(rc)); - return ppromises->first.get_future(); - } + watcher->deliver_data(nullopt, get_exception_ptr_of(rc)); + + return watcher->get_data_future(); }); } @@ -899,6 +982,7 @@ void connection_zk::on_session_event_raw(ptr handle [[gnu::unus auto ev = event_from_raw(ev_type); auto st = state_from_raw(state); auto path = string_view(path_ptr); + if (ev != event_type::session) { // TODO: Remove this usage of std::cerr diff --git a/src/zk/connection_zk.hpp b/src/zk/connection_zk.hpp index d3d923f..0f8ff6b 100644 --- a/src/zk/connection_zk.hpp +++ b/src/zk/connection_zk.hpp @@ -3,6 +3,9 @@ #include #include +#include +#include +#include #include "connection.hpp" #include "string_view.hpp" @@ -66,8 +69,31 @@ class connection_zk final : ptr watcher_ctx ) noexcept; + using watch_function = void (*)(ptr, int type_in, int state_in, ptr, ptr); + + class watcher; + + template + class basic_watcher; + + class data_watcher; + + class child_watcher; + + class exists_watcher; + + /** Erase the watch tracker for the watch with the value \a p. + * + * \returns \c true if it was deleted (the watch should be delivered); \c false if \a p was not in the list. + **/ + std::shared_ptr try_extract_watch(ptr p); + + static void deliver_watch(ptr zh, int type_in, int state_in, ptr, ptr proms_in); + private: - ptr _handle; + ptr _handle; + std::unordered_map, std::shared_ptr> _watches; + mutable std::mutex _watches_protect; }; /** \} **/ diff --git a/src/zk/server/server_tests.cpp b/src/zk/server/server_tests.cpp index 5077632..1a5cb7c 100644 --- a/src/zk/server/server_tests.cpp +++ b/src/zk/server/server_tests.cpp @@ -38,6 +38,11 @@ client server_fixture::get_connected_client() const return client(client::connect(get_connection_string()).get()); } +void server_fixture::stop_server(bool wait_for_stop) +{ + _server->shutdown(wait_for_stop); +} + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // single_server_fixture // //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/zk/server/server_tests.hpp b/src/zk/server/server_tests.hpp index b528f7e..75f0dda 100644 --- a/src/zk/server/server_tests.hpp +++ b/src/zk/server/server_tests.hpp @@ -24,6 +24,8 @@ class server_fixture : client get_connected_client() const; + void stop_server(bool wait_for_stop = true); + private: std::shared_ptr _server; std::string _conn_string;