Skip to content

Commit 55929b1

Browse files
committed
Experimentation on cancellable awaiter.
1 parent 1fd5c7e commit 55929b1

File tree

4 files changed

+185
-1
lines changed

4 files changed

+185
-1
lines changed

CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ if (BUILD_BROTLI)
245245
endif (BUILD_BROTLI)
246246

247247
set(DROGON_SOURCES
248+
lib/src/coroutine.cc
248249
lib/src/AOPAdvice.cc
249250
lib/src/AccessLogger.cc
250251
lib/src/CacheFile.cc

lib/inc/drogon/utils/coroutine.h

+55-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,27 @@
2727
#include <type_traits>
2828
#include <optional>
2929

30+
namespace drogon
31+
{
32+
struct CancelHandle;
33+
using CancelHandlePtr = std::shared_ptr<CancelHandle>;
34+
35+
struct CancelHandle
36+
{
37+
static CancelHandlePtr create();
38+
39+
virtual void cancel() = 0;
40+
virtual bool isCancelRequested() = 0;
41+
virtual void registerCancelCallback(std::function<void()> callback) = 0;
42+
};
43+
44+
class TaskCancelledException final : public std::runtime_error
45+
{
46+
public:
47+
using std::runtime_error::runtime_error;
48+
};
49+
} // namespace drogon
50+
3051
namespace drogon
3152
{
3253
namespace internal
@@ -596,6 +617,30 @@ struct [[nodiscard]] TimerAwaiter : CallbackAwaiter<void>
596617
double delay_;
597618
};
598619

620+
struct [[nodiscard]] CancellableTimeAwaiter : CallbackAwaiter<void>
621+
{
622+
CancellableTimeAwaiter(trantor::EventLoop *loop,
623+
const std::chrono::duration<double> &delay,
624+
CancelHandlePtr cancelHandle)
625+
: CancellableTimeAwaiter(loop, delay.count(), std::move(cancelHandle))
626+
{
627+
}
628+
629+
CancellableTimeAwaiter(trantor::EventLoop *loop,
630+
double delay,
631+
CancelHandlePtr cancelHandle)
632+
: loop_(loop), delay_(delay), cancelHandle_(std::move(cancelHandle))
633+
{
634+
}
635+
636+
void await_suspend(std::coroutine_handle<> handle);
637+
638+
private:
639+
trantor::EventLoop *loop_;
640+
double delay_;
641+
CancelHandlePtr cancelHandle_;
642+
};
643+
599644
struct [[nodiscard]] LoopAwaiter : CallbackAwaiter<void>
600645
{
601646
LoopAwaiter(trantor::EventLoop *workLoop,
@@ -684,6 +729,15 @@ inline internal::TimerAwaiter sleepCoro(trantor::EventLoop *loop,
684729
return {loop, delay};
685730
}
686731

732+
inline internal::CancellableTimeAwaiter sleepCoro(
733+
trantor::EventLoop *loop,
734+
double delay,
735+
CancelHandlePtr cancelHandle) noexcept
736+
{
737+
assert(loop);
738+
return {loop, delay, std::move(cancelHandle)};
739+
}
740+
687741
inline internal::LoopAwaiter queueInLoopCoro(
688742
trantor::EventLoop *workLoop,
689743
std::function<void()> taskFunc,
@@ -749,7 +803,7 @@ void async_run(Coro &&coro)
749803

750804
/**
751805
* @brief returns a function that calls a coroutine
752-
* @param coro A coroutine that is awaitable
806+
* @param Coro A coroutine that is awaitable
753807
*/
754808
template <typename Coro>
755809
std::function<void()> async_func(Coro &&coro)

lib/src/coroutine.cc

+95
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
//
2+
// Created by wanchen.he on 2023/12/29.
3+
//
4+
#ifdef __cpp_impl_coroutine
5+
#include <drogon/utils/coroutine.h>
6+
7+
namespace drogon
8+
{
9+
class CancelHandleImpl : public CancelHandle
10+
{
11+
public:
12+
CancelHandleImpl() = default;
13+
14+
void cancel() override
15+
{
16+
std::function<void()> handle;
17+
{
18+
std::lock_guard<std::mutex> lock(mutex_);
19+
cancelRequested_ = true;
20+
handle = std::move(cancelHandle_);
21+
}
22+
if (handle)
23+
handle();
24+
}
25+
26+
bool isCancelRequested() override
27+
{
28+
std::lock_guard<std::mutex> lock(mutex_);
29+
return cancelRequested_;
30+
}
31+
32+
void registerCancelCallback(std::function<void()> callback) override
33+
{
34+
}
35+
36+
void setCancelHandle(std::function<void()> handle)
37+
{
38+
bool cancelled{false};
39+
{
40+
std::lock_guard<std::mutex> lock(mutex_);
41+
if (cancelRequested_)
42+
{
43+
cancelled = true;
44+
}
45+
else
46+
{
47+
cancelHandle_ = std::move(handle);
48+
}
49+
}
50+
if (cancelled)
51+
{
52+
handle();
53+
}
54+
}
55+
56+
private:
57+
std::mutex mutex_;
58+
bool cancelRequested_{false};
59+
60+
std::shared_ptr<std::atomic_bool> flagPtr_;
61+
std::function<void()> cancelHandle_;
62+
};
63+
64+
CancelHandlePtr CancelHandle::create()
65+
{
66+
return std::make_shared<CancelHandleImpl>();
67+
}
68+
69+
void internal::CancellableTimeAwaiter::await_suspend(
70+
std::coroutine_handle<> handle)
71+
{
72+
auto execFlagPtr = std::make_shared<std::atomic_bool>(false);
73+
if (cancelHandle_)
74+
{
75+
static_cast<CancelHandleImpl *>(cancelHandle_.get())
76+
->setCancelHandle([this, handle, execFlagPtr, loop = loop_]() {
77+
if (!execFlagPtr->exchange(true))
78+
{
79+
setException(std::make_exception_ptr(
80+
TaskCancelledException("Task cancelled")));
81+
loop->queueInLoop([handle]() { handle.resume(); });
82+
return;
83+
}
84+
});
85+
}
86+
loop_->runAfter(delay_, [handle, execFlagPtr = std::move(execFlagPtr)]() {
87+
if (!execFlagPtr->exchange(true))
88+
{
89+
handle.resume();
90+
}
91+
});
92+
}
93+
} // namespace drogon
94+
95+
#endif

lib/tests/unittests/CoroutineTest.cc

+34
Original file line numberDiff line numberDiff line change
@@ -212,3 +212,37 @@ DROGON_TEST(SwitchThread)
212212
sync_wait(switch_thread());
213213
thread.wait();
214214
}
215+
216+
DROGON_TEST(Cancellation)
217+
{
218+
using namespace drogon::internal;
219+
220+
trantor::EventLoopThread thread; // helper thread
221+
thread.run();
222+
223+
auto testCancelTask = [TEST_CTX, loop = thread.getLoop()]() -> Task<> {
224+
auto cancelHandle = CancelHandle::create();
225+
226+
// wait coro for 10 seconds, but cancel after 1 second
227+
loop->runAfter(1, [cancelHandle]() { cancelHandle->cancel(); });
228+
229+
int64_t start = time(nullptr);
230+
try
231+
{
232+
LOG_INFO << "Waiting for 10 seconds...";
233+
co_await sleepCoro(loop, 10, cancelHandle);
234+
CHECK(false); // should not reach here
235+
}
236+
catch (const TaskCancelledException &ex)
237+
{
238+
int64_t waitTime = time(nullptr) - start;
239+
CHECK(waitTime < 2);
240+
LOG_INFO << "Oops... only waited for " << waitTime << " second(s)";
241+
}
242+
243+
loop->quit();
244+
};
245+
246+
sync_wait(testCancelTask());
247+
thread.wait();
248+
}

0 commit comments

Comments
 (0)