From f24e917d9cf2146bdf754feb2c2dee8a12469941 Mon Sep 17 00:00:00 2001 From: Klemens Morgenstern Date: Fri, 27 Jun 2025 14:12:24 +0800 Subject: [PATCH] channel handles close & cancel can occur before await_suspend. --- include/boost/cobalt/channel.hpp | 13 ++-- include/boost/cobalt/impl/channel.hpp | 38 +++++++++++- src/channel.cpp | 23 ++++++- test/channel.cpp | 88 +++++++++++++++++++++++++++ test/test.hpp | 43 +++++++++++++ 5 files changed, 196 insertions(+), 9 deletions(-) diff --git a/include/boost/cobalt/channel.hpp b/include/boost/cobalt/channel.hpp index 23d779f..768a7fc 100644 --- a/include/boost/cobalt/channel.hpp +++ b/include/boost/cobalt/channel.hpp @@ -96,7 +96,7 @@ struct channel } struct cancel_impl; - bool await_ready() { return !chn->buffer_.empty(); } + bool await_ready() { return !chn->buffer_.empty() || chn->is_closed_; } template BOOST_COBALT_MSVC_NOINLINE std::coroutine_handle await_suspend(std::coroutine_handle h); @@ -140,7 +140,7 @@ struct channel struct cancel_impl; - bool await_ready() { return !chn->buffer_.full(); } + bool await_ready() { return !chn->buffer_.full() || chn->is_closed_; } template BOOST_COBALT_MSVC_NOINLINE std::coroutine_handle await_suspend(std::coroutine_handle h); @@ -257,9 +257,12 @@ struct channel } struct cancel_impl; - bool await_ready() { return (chn->n_ > 0); } - template + bool await_ready() + { + return (chn->n_ > 0) || chn->is_closed_; + } + template BOOST_COBALT_MSVC_NOINLINE std::coroutine_handle await_suspend(std::coroutine_handle h); BOOST_COBALT_DECL void await_resume(); @@ -297,7 +300,7 @@ struct channel struct cancel_impl; bool await_ready() { - return chn->n_ < chn->limit_; + return chn->n_ < chn->limit_ || chn->is_closed_; } template diff --git a/include/boost/cobalt/impl/channel.hpp b/include/boost/cobalt/impl/channel.hpp index caec65f..18ab98b 100644 --- a/include/boost/cobalt/impl/channel.hpp +++ b/include/boost/cobalt/impl/channel.hpp @@ -96,6 +96,9 @@ template template std::coroutine_handle channel::read_op::await_suspend(std::coroutine_handle h) { + if (cancelled) + return h; // already interrupted. + if constexpr (requires (Promise p) {p.get_cancellation_slot();}) if ((cancel_slot = h.promise().get_cancellation_slot()).is_connected()) cancel_slot.emplace(this); @@ -162,8 +165,17 @@ system::result channel::read_op::await_resume(const struct as_result_tag & if (cancel_slot.is_connected()) cancel_slot.clear(); + if (chn->is_closed_ && chn->buffer_.empty() && !direct) + { + constexpr static boost::source_location loc{BOOST_CURRENT_LOCATION}; + return {system::in_place_error, asio::error::broken_pipe, &loc}; + } + if (cancelled) - return {system::in_place_error, asio::error::operation_aborted}; + { + constexpr static boost::source_location loc{BOOST_CURRENT_LOCATION}; + return {system::in_place_error, asio::error::operation_aborted, &loc}; + } T value = chn->buffer_.empty() ? std::move(*direct) : std::move(chn->buffer_.front()); if (!chn->buffer_.empty()) @@ -207,6 +219,10 @@ template template std::coroutine_handle channel::write_op::await_suspend(std::coroutine_handle h) { + if (cancelled) + return h; // already interrupted. + + if constexpr (requires (Promise p) {p.get_cancellation_slot();}) if ((cancel_slot = h.promise().get_cancellation_slot()).is_connected()) cancel_slot.emplace(this); @@ -263,8 +279,19 @@ system::result channel::write_op::await_resume(const struct as_result_ { if (cancel_slot.is_connected()) cancel_slot.clear(); + + if (chn->is_closed_) + { + constexpr static boost::source_location loc{BOOST_CURRENT_LOCATION}; + return {system::in_place_error, asio::error::broken_pipe, &loc}; + } + + if (cancelled) - return {system::in_place_error, asio::error::operation_aborted}; + { + constexpr static boost::source_location loc{BOOST_CURRENT_LOCATION}; + return {system::in_place_error, asio::error::operation_aborted, &loc}; + } if (!direct) { @@ -326,6 +353,10 @@ struct channel::write_op::cancel_impl template std::coroutine_handle channel::read_op::await_suspend(std::coroutine_handle h) { + + if (cancelled) + return h; // already interrupted. + if constexpr (requires (Promise p) {p.get_cancellation_slot();}) if ((cancel_slot = h.promise().get_cancellation_slot()).is_connected()) cancel_slot.emplace(this); @@ -362,6 +393,9 @@ std::coroutine_handle channel::read_op::await_suspend(std::coroutine template std::coroutine_handle channel::write_op::await_suspend(std::coroutine_handle h) { + if (cancelled) + return h; // already interrupted. + if constexpr (requires (Promise p) {p.get_cancellation_slot();}) if ((cancel_slot = h.promise().get_cancellation_slot()).is_connected()) cancel_slot.emplace(this); diff --git a/src/channel.cpp b/src/channel.cpp index fa4c8dd..a9974c8 100644 --- a/src/channel.cpp +++ b/src/channel.cpp @@ -49,8 +49,17 @@ system::result channel::read_op::await_resume(const struct as_resul if (cancel_slot.is_connected()) cancel_slot.clear(); + if (chn->is_closed_) + { + constexpr static boost::source_location loc{BOOST_CURRENT_LOCATION}; + return {system::in_place_error, asio::error::broken_pipe, &loc}; + } + if (cancelled) - return {system::in_place_error, asio::error::operation_aborted}; + { + constexpr static boost::source_location loc{BOOST_CURRENT_LOCATION}; + return {system::in_place_error, asio::error::operation_aborted, &loc}; + } if (!direct) chn->n_--; @@ -83,8 +92,18 @@ system::result channel::write_op::await_resume(const struct as_resul { if (cancel_slot.is_connected()) cancel_slot.clear(); + + if (chn->is_closed_) + { + constexpr static boost::source_location loc{BOOST_CURRENT_LOCATION}; + return {system::in_place_error, asio::error::broken_pipe, &loc}; + } + if (cancelled) - return {system::in_place_error, asio::error::operation_aborted}; + { + constexpr static boost::source_location loc{BOOST_CURRENT_LOCATION}; + return {system::in_place_error, asio::error::operation_aborted, &loc}; + } if (!direct) chn->n_++; diff --git a/test/channel.cpp b/test/channel.cpp index d8b2156..4486125 100644 --- a/test/channel.cpp +++ b/test/channel.cpp @@ -393,4 +393,92 @@ CO_TEST_CASE(interrupt_void_1) BOOST_CHECK(rl == 1); } + + +cobalt::promise do_write(cobalt::channel & c, int times = 1) +{ + while (times --> 0) + co_await c.write(); +}; + +CO_TEST_CASE(interrupt_0_void) +{ + cobalt::channel c{0}; + auto w = do_write(c); + + BOOST_CHECK(!w.ready()); + auto [ec] = co_await cobalt::as_tuple(test_interrupt(c.read())); + BOOST_CHECK_MESSAGE(ec == asio::error::operation_aborted, ec.to_string()); + co_await asio::post(co_await this_coro::executor); + BOOST_CHECK(!w.ready()); + co_await c.read(); + co_await asio::post(co_await this_coro::executor); + BOOST_CHECK(w.ready()); +} + + +CO_TEST_CASE(interrupt_1_void) +{ + cobalt::channel c{1}; + auto w = do_write(c, 2); + + BOOST_CHECK(!w.ready()); + auto [ec] = co_await cobalt::as_tuple(test_interrupt(c.read())); + BOOST_CHECK_MESSAGE(ec == asio::error::operation_aborted, ec.to_string()); + co_await asio::post(co_await this_coro::executor); + BOOST_CHECK(!w.ready()); + co_await c.read(); + co_await asio::post(co_await this_coro::executor); + BOOST_CHECK(w.ready()); + co_await c.read(); + co_await asio::post(co_await this_coro::executor); + BOOST_CHECK(w.ready()); +} + +cobalt::promise do_write(cobalt::channel & c, int times = 1) +{ + int i = 0; + while (times --> 0) + co_await c.write(i++); +}; + + +CO_TEST_CASE(interrupt_0_int) +{ + cobalt::channel c{0}; + auto w = do_write(c); + + BOOST_CHECK(!w.ready()); + auto [ec, i] = co_await cobalt::as_tuple(test_interrupt(c.read())); + + BOOST_CHECK_MESSAGE(ec == asio::error::operation_aborted, ec.to_string()); + co_await asio::post(co_await this_coro::executor); + BOOST_CHECK(!w.ready()); + i = co_await c.read(); + BOOST_CHECK_EQUAL(i, 0); + co_await asio::post(co_await this_coro::executor); + BOOST_CHECK(w.ready()); +} + + +CO_TEST_CASE(interrupt_1_int) +{ + cobalt::channel c{1}; + auto w = do_write(c, 2); + + BOOST_CHECK(!w.ready()); + auto [ec, i] = co_await cobalt::as_tuple(test_interrupt(c.read())); + BOOST_CHECK_MESSAGE(ec == asio::error::operation_aborted, ec.to_string()); + co_await asio::post(co_await this_coro::executor); + BOOST_CHECK(!w.ready()); + i = co_await c.read(); + BOOST_CHECK_EQUAL(i, 0); + co_await asio::post(co_await this_coro::executor); + BOOST_CHECK(w.ready()); + i = co_await c.read(); + BOOST_CHECK_EQUAL(i, 1); + co_await asio::post(co_await this_coro::executor); + BOOST_CHECK(w.ready()); +} + } diff --git a/test/test.hpp b/test/test.hpp index 8ff006d..69f17b5 100644 --- a/test/test.hpp +++ b/test/test.hpp @@ -173,4 +173,47 @@ struct posted_handle } }; +template +struct test_interrupt +{ + Aw aw; + + test_interrupt(Aw && aw) : aw(std::move(aw)) {} + + bool await_ready() + { + auto res = aw.await_ready(); + aw.interrupt_await(); + return res; + } + + template + auto await_suspend(std::coroutine_handle h) + { + using type = decltype(aw.await_suspend(h)); + if constexpr (std::is_void_v) + { + aw.await_suspend(h); + aw.interrupt_await(); + } + else + { + auto r = aw.await_suspend(h); + aw.interrupt_await(); + return r; + } + } + + template + auto await_resume(const T & tag) + { + return aw.await_resume(tag); + } + + auto await_resume() + { + return aw.await_resume(); + } +}; + #endif //BOOST_COBALT_TEST2_HPP