Skip to content

Commit

Permalink
[coro_io] add support for cancellation (#887)
Browse files Browse the repository at this point in the history
* [coro_io] add support for cancellation

* fix

* fix ssl compile error

* remove useless dispatch

* fix timer user-after-free

* f

* f

* 1

* fix mem order

* fix

* fix mem order

* fix

* fix format
  • Loading branch information
poor-circle authored Jan 26, 2025
1 parent 56fbbba commit 9d4dfb5
Show file tree
Hide file tree
Showing 12 changed files with 615 additions and 259 deletions.
479 changes: 306 additions & 173 deletions include/ylt/coro_io/coro_io.hpp

Large diffs are not rendered by default.

36 changes: 31 additions & 5 deletions include/ylt/coro_io/io_context_pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <vector>

#include "asio/dispatch.hpp"
#include "async_simple/Signal.h"
#ifdef __linux__
#include <pthread.h>
#include <sched.h>
Expand Down Expand Up @@ -104,11 +105,36 @@ class ExecutorWrapper : public async_simple::Executor {
}
void schedule(Func func, Duration dur, uint64_t hint,
async_simple::Slot *slot = nullptr) override {
auto timer = std::make_unique<asio::steady_timer>(executor_, dur);
auto tm = timer.get();
tm->async_wait([fn = std::move(func), timer = std::move(timer)](auto ec) {
fn();
});
auto timer =
std::make_shared<std::pair<asio::steady_timer, std::atomic<bool>>>(
asio::steady_timer{executor_, dur}, false);
if (!slot) {
timer->first.async_wait([fn = std::move(func), timer](const auto &ec) {
fn();
});
}
else {
if (!async_simple::signalHelper{async_simple::SignalType::Terminate}
.tryEmplace(
slot, [timer](auto signalType, auto *signal) mutable {
if (bool expected = false;
!timer->second.compare_exchange_strong(
expected, true, std::memory_order_acq_rel)) {
timer->first.cancel();
}
})) {
asio::dispatch(timer->first.get_executor(), func);
}
else {
timer->first.async_wait([fn = std::move(func), timer](const auto &ec) {
fn();
});
if (bool expected = false; !timer->second.compare_exchange_strong(
expected, true, std::memory_order_acq_rel)) {
timer->first.cancel();
}
}
}
}
};

Expand Down
2 changes: 2 additions & 0 deletions include/ylt/metric/summary_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ class summary_impl {
if (piece) {
if constexpr (inc_order) {
for (int j = 0; j < piece->size(); ++j) {
// tsan check data race here is expected. stat dont need to be very
// strict. we allow old value.
auto value = (*piece)[j].load(std::memory_order_relaxed);
if (value) {
result.emplace_back(get_ordered_index(i * piece_size + j), value);
Expand Down
11 changes: 6 additions & 5 deletions include/ylt/thirdparty/async_simple/Signal.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#define ASYNC_SIMPLE_SIGNAL_H

#ifndef ASYNC_SIMPLE_USE_MODULES

#include <assert.h>
#include <any>
#include <atomic>
Expand Down Expand Up @@ -205,14 +204,16 @@ class Slot {
"we dont allow emplace an empty signal handler");
logicAssert(std::popcount(static_cast<uint64_t>(type)) == 1,
"It's not allow to emplace for multiple signals");
// trigger-once signal has already been triggered
auto handler = std::make_unique<detail::SignalSlotSharedState::Handler>(
std::forward<Args>(args)...);
auto oldHandlerPtr = loadHandler<true>(type);
// check trigger-once signal has already been triggered
// if signal has already been triggered, return false
if (!detail::SignalSlotSharedState::isMultiTriggerSignal(type) &&
(signal()->state() & type)) {
return false;
}
auto handler = std::make_unique<detail::SignalSlotSharedState::Handler>(
std::forward<Args>(args)...);
auto oldHandlerPtr = loadHandler<true>(type);
// if signal triggered later, we will found it by cas failed.
auto oldHandler = oldHandlerPtr->load(std::memory_order_acquire);
if (oldHandler ==
&detail::SignalSlotSharedState::HandlerManager::emittedTag) {
Expand Down
87 changes: 41 additions & 46 deletions include/ylt/thirdparty/async_simple/coro/Collect.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ struct CollectAnyAwaiter {
_slot, [c = continuation, e = event, size = input.size()](
SignalType type, Signal*) mutable {
auto count = e->downCount();
if (count > size + 1) {
c.resume();
if (count == size + 1) {
c.resume();
}
})) { // has canceled
return false;
Expand All @@ -186,14 +186,14 @@ struct CollectAnyAwaiter {
assert(e != nullptr);
auto count = e->downCount();
// n+1: n coro + 1 cancel handler
if (count > size + 1) {
_result = std::make_unique<ResultType>();
_result->_idx = i;
_result->_value = std::move(result);
if (auto ptr = local->getSlot(); ptr) {
ptr->signal()->emit(_SignalType);
}
c.resume();
if (count == size + 1) {
_result = std::make_unique<ResultType>();
_result->_idx = i;
_result->_value = std::move(result);
if (auto ptr = local->getSlot(); ptr) {
ptr->signal()->emit(_SignalType);
}
c.resume();
}
});
} // end for
Expand Down Expand Up @@ -268,8 +268,8 @@ struct CollectAnyVariadicAwaiter {
_slot, [c = continuation, e = event](SignalType type,
Signal*) mutable {
auto count = e->downCount();
if (count > std::tuple_size<InputType>() + 1) {
c.resume();
if (count == std::tuple_size<InputType>() + 1) {
c.resume();
}
})) { // has canceled
return false;
Expand All @@ -290,13 +290,13 @@ struct CollectAnyVariadicAwaiter {
res) mutable {
auto count = e->downCount();
// n+1: n coro + 1 cancel handler
if (count > std::tuple_size<InputType>() + 1) {
_result = std::make_unique<ResultType>(
std::in_place_index_t<index>(), std::move(res));
if (auto ptr = local->getSlot(); ptr) {
ptr->signal()->emit(_SignalType);
}
c.resume();
if (count == std::tuple_size<InputType>() + 1) {
_result = std::make_unique<ResultType>(
std::in_place_index_t<index>(), std::move(res));
if (auto ptr = local->getSlot(); ptr) {
ptr->signal()->emit(_SignalType);
}
c.resume();
}
});
}(),
Expand Down Expand Up @@ -388,15 +388,19 @@ struct CollectAllAwaiter {
_slot->chainedSignal(_signal.get());

auto executor = promise_type._executor;
for (size_t i = 0; i < _input.size(); ++i) {
auto& exec = _input[i]._coro.promise()._executor;
if (exec == nullptr) {
exec = executor;
}
std::unique_ptr<LazyLocalBase> local;
local = std::make_unique<LazyLocalBase>(_signal.get());
_input[i]._coro.promise()._lazy_local = local.get();
auto&& func = [this, i, local = std::move(local)]() mutable {

_event.setAwaitingCoro(continuation);
auto size = _input.size();
for (size_t i = 0; i < size; ++i) {
auto& exec = _input[i]._coro.promise()._executor;
if (exec == nullptr) {
exec = executor;
}
std::unique_ptr<LazyLocalBase> local;
local = std::make_unique<LazyLocalBase>(_signal.get());
_input[i]._coro.promise()._lazy_local = local.get();
auto&& func =
[this, i, local = std::move(local)]() mutable {
_input[i].start([this, i, local = std::move(local)](
Try<ValueType>&& result) {
_output[i] = std::move(result);
Expand All @@ -412,20 +416,15 @@ struct CollectAllAwaiter {
awaitingCoro.resume();
}
});
};
if (Para == true && _input.size() > 1) {
if (exec != nullptr)
AS_LIKELY {
exec->schedule_move_only(std::move(func));
continue;
}
}
func();
}
_event.setAwaitingCoro(continuation);
auto awaitingCoro = _event.down();
if (awaitingCoro) {
awaitingCoro.resume();
};
if (Para == true && _input.size() > 1) {
if (exec != nullptr)
AS_LIKELY {
exec->schedule_move_only(std::move(func));
continue;
}
}
func();
}
}
inline auto await_resume() { return std::move(_output); }
Expand Down Expand Up @@ -602,10 +601,6 @@ struct CollectAllVariadicAwaiter {
}
}(std::get<index>(_inputs), std::get<index>(_results)),
...);

if (auto awaitingCoro = _event.down(); awaitingCoro) {
awaitingCoro.resume();
}
}

void await_suspend(std::coroutine_handle<> continuation) {
Expand Down
16 changes: 8 additions & 8 deletions include/ylt/thirdparty/async_simple/coro/CountEvent.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ namespace detail {
// The last 'down' will resume the awaiting coroutine on this event.
class CountEvent {
public:
CountEvent(size_t count) : _count(count + 1) {}
CountEvent(const CountEvent&) = delete;
CountEvent(CountEvent&& other)
: _count(other._count.exchange(0, std::memory_order_relaxed)),
_awaitingCoro(std::exchange(other._awaitingCoro, nullptr)) {}
CountEvent(size_t count) : _count(count) {}
CountEvent(const CountEvent&) = delete;
CountEvent(CountEvent&& other)
: _count(other._count.exchange(0, std::memory_order_relaxed)),
_awaitingCoro(std::exchange(other._awaitingCoro, nullptr)) {}

[[nodiscard]] CoroHandle<> down(size_t n = 1) {
std::size_t oldCount;
return down(oldCount, n);
[[nodiscard]] CoroHandle<> down(size_t n = 1) {
std::size_t oldCount;
return down(oldCount, n);
}
[[nodiscard]] CoroHandle<> down(size_t& oldCount, std::size_t n) {
// read acquire and write release, _awaitingCoro store can not be
Expand Down
1 change: 1 addition & 0 deletions src/coro_io/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ add_executable(coro_io_test
test_client_pool.cpp
test_rate_limiter.cpp
test_coro_channel.cpp
test_cancel.cpp
main.cpp
)
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_SYSTEM_NAME MATCHES "Windows") # mingw-w64
Expand Down
Loading

0 comments on commit 9d4dfb5

Please sign in to comment.