diff --git a/include/boost/redis/connection.hpp b/include/boost/redis/connection.hpp index e23bf674..93b7e7fd 100644 --- a/include/boost/redis/connection.hpp +++ b/include/boost/redis/connection.hpp @@ -58,54 +58,157 @@ namespace boost::redis { namespace detail { template -using exec_notifier_type = asio::experimental::channel< - Executor, - void(system::error_code, std::size_t)>; +struct connection_impl { + using clock_type = std::chrono::steady_clock; + using clock_traits_type = asio::wait_traits; + using timer_type = asio::basic_waitable_timer; -template -struct exec_op { - using executor_type = typename Conn::executor_type; + using receive_channel_type = asio::experimental::channel< + Executor, + void(system::error_code, std::size_t)>; + using health_checker_type = detail::health_checker; + using resp3_handshaker_type = detail::resp3_handshaker; + using exec_notifier_type = asio::experimental::channel< + Executor, + void(system::error_code, std::size_t)>; - Conn* conn_ = nullptr; - std::shared_ptr> notifier_ = nullptr; - detail::exec_fsm fsm_; + redis_stream stream_; - template - void operator()(Self& self, system::error_code = {}, std::size_t = 0) - { - while (true) { - // Invoke the state machine - auto act = fsm_.resume(conn_->is_open(), self.get_cancellation_state().cancelled()); + // Notice we use a timer to simulate a condition-variable. It is + // also more suitable than a channel and the notify operation does + // not suspend. + timer_type writer_timer_; + timer_type reconnect_timer_; // to wait the reconnection period + receive_channel_type receive_channel_; + health_checker_type health_checker_; + resp3_handshaker_type handshaker_; - // Do what the FSM said - switch (act.type()) { - case detail::exec_action_type::setup_cancellation: - self.reset_cancellation_state(asio::enable_total_cancellation()); - continue; // this action does not require yielding - case detail::exec_action_type::immediate: - asio::async_immediate(self.get_io_executor(), std::move(self)); - return; - case detail::exec_action_type::notify_writer: - conn_->writer_timer_.cancel(); - continue; // this action does not require yielding - case detail::exec_action_type::wait_for_response: - notifier_->async_receive(std::move(self)); - return; - case detail::exec_action_type::cancel_run: - conn_->cancel(operation::run); - continue; // this action does not require yielding - case detail::exec_action_type::done: - notifier_.reset(); - self.complete(act.error(), act.bytes_read()); - return; + config cfg_; + multiplexer mpx_; + connection_logger logger_; + read_buffer read_buffer_; + + using executor_type = Executor; + + executor_type get_executor() noexcept { return writer_timer_.get_executor(); } + + struct exec_op { + connection_impl* obj_ = nullptr; + std::shared_ptr notifier_ = nullptr; + exec_fsm fsm_; + + template + void operator()(Self& self, system::error_code = {}, std::size_t = 0) + { + while (true) { + // Invoke the state machine + auto act = fsm_.resume(obj_->is_open(), self.get_cancellation_state().cancelled()); + + // Do what the FSM said + switch (act.type()) { + case exec_action_type::setup_cancellation: + self.reset_cancellation_state(asio::enable_total_cancellation()); + continue; // this action does not require yielding + case exec_action_type::immediate: + asio::async_immediate(self.get_io_executor(), std::move(self)); + return; + case exec_action_type::notify_writer: + obj_->writer_timer_.cancel(); + continue; // this action does not require yielding + case exec_action_type::wait_for_response: + notifier_->async_receive(std::move(self)); + return; + case exec_action_type::cancel_run: + obj_->cancel(operation::run); + continue; // this action does not require yielding + case exec_action_type::done: + notifier_.reset(); + self.complete(act.error(), act.bytes_read()); + return; + } } } + }; + + connection_impl(Executor&& ex, asio::ssl::context&& ctx, logger&& lgr) + : stream_{ex, std::move(ctx)} + , writer_timer_{ex} + , reconnect_timer_{ex} + , receive_channel_{ex, 256} + , health_checker_{ex} + , logger_{std::move(lgr)} + { + mpx_.set_receive_response(ignore); + writer_timer_.expires_at((std::chrono::steady_clock::time_point::max)()); + + // Reserve some memory to avoid excessive memory allocations in + // the first reads. + read_buffer_.reserve(4096u); + } + + void cancel(operation op) + { + switch (op) { + case operation::resolve: stream_.cancel_resolve(); break; + case operation::exec: mpx_.cancel_waiting(); break; + case operation::reconnection: + cfg_.reconnect_wait_interval = std::chrono::seconds::zero(); + break; + case operation::run: cancel_run(); break; + case operation::receive: receive_channel_.cancel(); break; + case operation::health_check: health_checker_.cancel(); break; + case operation::all: + stream_.cancel_resolve(); + cfg_.reconnect_wait_interval = std::chrono::seconds::zero(); + health_checker_.cancel(); + cancel_run(); // run + receive_channel_.cancel(); // receive + mpx_.cancel_waiting(); // exec + break; + default: /* ignore */; + } + } + + void cancel_run() + { + stream_.close(); + writer_timer_.cancel(); + receive_channel_.cancel(); + mpx_.cancel_on_conn_lost(); + } + + bool is_open() const noexcept { return stream_.is_open(); } + + bool will_reconnect() const noexcept + { + return cfg_.reconnect_wait_interval != std::chrono::seconds::zero(); + } + + template + auto async_exec(request const& req, any_adapter adapter, CompletionToken&& token) + { + auto& adapter_impl = adapter.impl_; + BOOST_ASSERT_MSG( + req.get_expected_responses() <= adapter_impl.supported_response_size, + "Request and response have incompatible sizes."); + + auto notifier = std::make_shared(writer_timer_.get_executor(), 1); + auto info = make_elem(req, std::move(adapter_impl.adapt_fn)); + + info->set_done_callback([notifier]() { + notifier->try_send(std::error_code{}, 0); + }); + + return asio::async_compose( + exec_op{this, notifier, exec_fsm(mpx_, std::move(info))}, + token, + writer_timer_); } }; -template +template struct writer_op { - Conn* conn_; + connection_impl* conn_; asio::coroutine coro{}; template @@ -157,13 +260,13 @@ struct writer_op { } }; -template +template struct reader_op { - Conn* conn_; - detail::reader_fsm fsm_; + connection_impl* conn_; + reader_fsm fsm_; public: - reader_op(Conn& conn) noexcept + reader_op(connection_impl& conn) noexcept : conn_{&conn} , fsm_{conn.read_buffer_, conn.mpx_} { } @@ -214,17 +317,35 @@ inline system::error_code check_config(const config& cfg) return system::error_code{}; } -template +template class run_op { private: - Conn* conn_ = nullptr; + connection_impl* conn_ = nullptr; asio::coroutine coro_{}; system::error_code stored_ec_; using order_t = std::array; + template + auto reader(CompletionToken&& token) + { + return asio::async_compose( + reader_op{*conn_}, + std::forward(token), + conn_->writer_timer_); + } + + template + auto writer(CompletionToken&& token) + { + return asio::async_compose( + writer_op{conn_}, + std::forward(token), + conn_->writer_timer_); + } + public: - run_op(Conn* conn) noexcept + run_op(connection_impl* conn) noexcept : conn_{conn} { } @@ -296,10 +417,10 @@ public: return conn_->health_checker_.async_check_timeout(*conn_, token); }, [this](auto token) { - return conn_->reader(token); + return this->reader(token); }, [this](auto token) { - return conn_->writer(token); + return this->writer(token); }) .async_wait(asio::experimental::wait_for_one_error(), std::move(self)); @@ -382,16 +503,12 @@ public: executor_type ex, asio::ssl::context ctx = asio::ssl::context{asio::ssl::context::tlsv12_client}, logger lgr = {}) - : stream_{ex, std::move(ctx)} - , writer_timer_{ex} - , reconnect_timer_{ex} - , receive_channel_{ex, 256} - , health_checker_{ex} - , logger_{std::move(lgr)} - { - set_receive_response(ignore); - writer_timer_.expires_at((std::chrono::steady_clock::time_point::max)()); - } + : impl_( + std::make_unique>( + std::move(ex), + std::move(ctx), + std::move(lgr))) + { } /** @brief Constructor from an executor and a logger. * @@ -441,7 +558,7 @@ public: { } /// Returns the associated executor. - executor_type get_executor() noexcept { return writer_timer_.get_executor(); } + executor_type get_executor() noexcept { return impl_->writer_timer_.get_executor(); } /** @brief Starts the underlying connection operations. * @@ -487,19 +604,15 @@ public: template > auto async_run(config const& cfg, CompletionToken&& token = {}) { - cfg_ = cfg; - health_checker_.set_config(cfg); - handshaker_.set_config(cfg); - read_buffer_.set_config({cfg.read_buffer_append_size, cfg.max_read_size}); - - // Reserve some memory to avoid excessive memory allocations in - // the first reads. - read_buffer_.reserve(4048u); + impl_->cfg_ = cfg; + impl_->health_checker_.set_config(cfg); + impl_->handshaker_.set_config(cfg); + impl_->read_buffer_.set_config({cfg.read_buffer_append_size, cfg.max_read_size}); return asio::async_compose( - detail::run_op{this}, + detail::run_op{impl_.get()}, token, - writer_timer_); + impl_->writer_timer_); } /** @@ -578,10 +691,10 @@ public: template > auto async_receive(CompletionToken&& token = {}) { - return receive_channel_.async_receive(std::forward(token)); + return impl_->receive_channel_.async_receive(std::forward(token)); } - /** @brief Receives server pushes synchronously without blocking. + /** @brief Receives server> pushes synchronously without blocking. * * Receives a server push synchronously by calling `try_receive` on * the underlying channel. If the operation fails because @@ -600,7 +713,7 @@ public: size = n; }; - auto const res = receive_channel_.try_receive(f); + auto const res = impl_->receive_channel_.try_receive(f); if (ec) return 0; @@ -696,24 +809,7 @@ public: template > auto async_exec(request const& req, any_adapter adapter, CompletionToken&& token = {}) { - auto& adapter_impl = adapter.impl_; - BOOST_ASSERT_MSG( - req.get_expected_responses() <= adapter_impl.supported_response_size, - "Request and response have incompatible sizes."); - - auto notifier = std::make_shared>( - get_executor(), - 1); - auto info = detail::make_elem(req, std::move(adapter_impl.adapt_fn)); - - info->set_done_callback([notifier]() { - notifier->try_send(std::error_code{}, 0); - }); - - return asio::async_compose( - detail::exec_op{this, notifier, detail::exec_fsm(mpx_, std::move(info))}, - token, - writer_timer_); + return impl_->async_exec(req, std::move(adapter), std::forward(token)); } /** @brief Cancel operations. @@ -727,36 +823,12 @@ public: * * @param op The operation to be cancelled. */ - void cancel(operation op = operation::all) - { - switch (op) { - case operation::resolve: stream_.cancel_resolve(); break; - case operation::exec: mpx_.cancel_waiting(); break; - case operation::reconnection: - cfg_.reconnect_wait_interval = std::chrono::seconds::zero(); - break; - case operation::run: cancel_run(); break; - case operation::receive: receive_channel_.cancel(); break; - case operation::health_check: health_checker_.cancel(); break; - case operation::all: - stream_.cancel_resolve(); - cfg_.reconnect_wait_interval = std::chrono::seconds::zero(); - health_checker_.cancel(); - cancel_run(); // run - receive_channel_.cancel(); // receive - mpx_.cancel_waiting(); // exec - break; - default: /* ignore */; - } - } + void cancel(operation op = operation::all) { impl_->cancel(op); } - auto run_is_canceled() const noexcept { return mpx_.get_cancel_run_state(); } + auto run_is_canceled() const noexcept { return impl_->mpx_.get_cancel_run_state(); } /// Returns true if the connection will try to reconnect if an error is encountered. - bool will_reconnect() const noexcept - { - return cfg_.reconnect_wait_interval != std::chrono::seconds::zero(); - } + bool will_reconnect() const noexcept { return impl_->will_reconnect(); } /** * @brief (Deprecated) Returns the ssl context. @@ -770,7 +842,10 @@ public: BOOST_DEPRECATED( "ssl::context has no const methods, so this function should not be called. Set up any " "required TLS configuration before passing the ssl::context to the connection's constructor.") - asio::ssl::context const& get_ssl_context() const noexcept { return stream_.get_ssl_context(); } + asio::ssl::context const& get_ssl_context() const noexcept + { + return impl_->stream_.get_ssl_context(); + } /** * @brief (Deprecated) Resets the underlying stream. @@ -796,7 +871,7 @@ public: BOOST_DEPRECATED( "Accessing the underlying stream is deprecated and will be removed in the next release. Use " "the other member functions to interact with the connection.") - auto& next_layer() noexcept { return stream_.next_layer(); } + auto& next_layer() noexcept { return impl_->stream_.next_layer(); } /** * @brief (Deprecated) Returns a reference to the next layer. @@ -812,17 +887,17 @@ public: BOOST_DEPRECATED( "Accessing the underlying stream is deprecated and will be removed in the next release. Use " "the other member functions to interact with the connection.") - auto const& next_layer() const noexcept { return stream_.next_layer(); } + auto const& next_layer() const noexcept { return impl_->stream_.next_layer(); } /// Sets the response object of @ref async_receive operations. template void set_receive_response(Response& response) { - mpx_.set_receive_response(response); + impl_->mpx_.set_receive_response(response); } /// Returns connection usage information. - usage get_usage() const noexcept { return mpx_.get_usage(); } + usage get_usage() const noexcept { return impl_->mpx_.get_usage(); } private: using clock_type = std::chrono::steady_clock; @@ -835,66 +910,17 @@ private: using health_checker_type = detail::health_checker; using resp3_handshaker_type = detail::resp3_handshaker; - auto use_ssl() const noexcept { return cfg_.use_ssl; } - - void cancel_run() - { - stream_.close(); - writer_timer_.cancel(); - receive_channel_.cancel(); - mpx_.cancel_on_conn_lost(); - } + auto use_ssl() const noexcept { return impl_->cfg_.use_ssl; } // Used by both this class and connection void set_stderr_logger(logger::level lvl, const config& cfg) { - logger_.reset(detail::make_stderr_logger(lvl, cfg.log_prefix)); + impl_->logger_.reset(detail::make_stderr_logger(lvl, cfg.log_prefix)); } - template friend struct detail::reader_op; - template friend struct detail::writer_op; - template friend struct detail::exec_op; - template friend struct detail::hello_op; - template friend class detail::ping_op; - template friend class detail::run_op; - template friend class detail::check_timeout_op; friend class connection; - template - auto reader(CompletionToken&& token) - { - return asio::async_compose( - detail::reader_op{*this}, - std::forward(token), - writer_timer_); - } - - template - auto writer(CompletionToken&& token) - { - return asio::async_compose( - detail::writer_op{this}, - std::forward(token), - writer_timer_); - } - - bool is_open() const noexcept { return stream_.is_open(); } - - detail::redis_stream stream_; - - // Notice we use a timer to simulate a condition-variable. It is - // also more suitable than a channel and the notify operation does - // not suspend. - timer_type writer_timer_; - timer_type reconnect_timer_; // to wait the reconnection period - receive_channel_type receive_channel_; - health_checker_type health_checker_; - resp3_handshaker_type handshaker_; - - config cfg_; - detail::read_buffer read_buffer_; - detail::multiplexer mpx_; - detail::connection_logger logger_; + std::unique_ptr> impl_; }; /** @brief A basic_connection that type erases the executor. @@ -1079,7 +1105,7 @@ public: "the other member functions to interact with the connection.") asio::ssl::stream& next_layer() noexcept { - return impl_.stream_.next_layer(); + return impl_.impl_->stream_.next_layer(); } /// (Deprecated) Calls @ref boost::redis::basic_connection::next_layer. @@ -1088,7 +1114,7 @@ public: "the other member functions to interact with the connection.") asio::ssl::stream const& next_layer() const noexcept { - return impl_.stream_.next_layer(); + return impl_.impl_->stream_.next_layer(); } /// @copydoc basic_connection::reset_stream @@ -1113,7 +1139,7 @@ public: "required TLS configuration before passing the ssl::context to the connection's constructor.") asio::ssl::context const& get_ssl_context() const noexcept { - return impl_.stream_.get_ssl_context(); + return impl_.impl_->stream_.get_ssl_context(); } private: diff --git a/include/boost/redis/detail/health_checker.hpp b/include/boost/redis/detail/health_checker.hpp index ccb942f9..7e4bde72 100644 --- a/include/boost/redis/detail/health_checker.hpp +++ b/include/boost/redis/detail/health_checker.hpp @@ -24,11 +24,11 @@ namespace boost::redis::detail { -template +template class ping_op { public: HealthChecker* checker_ = nullptr; - Connection* conn_ = nullptr; + ConnectionImpl* conn_ = nullptr; asio::coroutine coro_{}; template @@ -155,11 +155,11 @@ public: wait_timer_.cancel(); } - template - auto async_ping(Connection& conn, CompletionToken token) + template + auto async_ping(ConnectionImpl& conn, CompletionToken token) { return asio::async_compose( - ping_op{this, &conn}, + ping_op{this, &conn}, token, conn, ping_timer_); diff --git a/include/boost/redis/detail/resp3_handshaker.hpp b/include/boost/redis/detail/resp3_handshaker.hpp index 41949dfe..05edc795 100644 --- a/include/boost/redis/detail/resp3_handshaker.hpp +++ b/include/boost/redis/detail/resp3_handshaker.hpp @@ -26,10 +26,10 @@ void push_hello(config const& cfg, request& req); // TODO: Can we avoid this whole function whose only purpose is to // check for an error in the hello response and complete with an error // so that the parallel group that starts it can exit? -template +template struct hello_op { Handshaker* handshaker_ = nullptr; - Connection* conn_ = nullptr; + ConnectionImpl* conn_ = nullptr; asio::coroutine coro_{}; template @@ -68,11 +68,11 @@ class resp3_handshaker { public: void set_config(config const& cfg) { cfg_ = cfg; } - template - auto async_hello(Connection& conn, CompletionToken token) + template + auto async_hello(ConnectionImpl& conn, CompletionToken token) { return asio::async_compose( - hello_op{this, &conn}, + hello_op{this, &conn}, token, conn); } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index ef7bd4a9..24cf9410 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -55,6 +55,7 @@ make_test(test_conn_reconnect) make_test(test_conn_exec_cancel) make_test(test_conn_exec_cancel2) make_test(test_conn_echo_stress) +make_test(test_conn_move) make_test(test_issue_50) make_test(test_issue_181) make_test(test_conversions) diff --git a/test/test_conn_logging.cpp b/test/test_conn_logging.cpp index ab67155e..4c8fe1b5 100644 --- a/test/test_conn_logging.cpp +++ b/test/test_conn_logging.cpp @@ -28,10 +28,6 @@ using namespace boost::redis; namespace { -// user tests -// logging can be disabled -// logging can be changed verbosity - template void run_with_invalid_config(net::io_context& ioc, Conn& conn) { diff --git a/test/test_conn_move.cpp b/test/test_conn_move.cpp new file mode 100644 index 00000000..e7726d6f --- /dev/null +++ b/test/test_conn_move.cpp @@ -0,0 +1,112 @@ +// +// Copyright (c) 2025 Marcelo Zimbres Silva (mzimbres@gmail.com), +// Ruben Perez Hidalgo (rubenperez038 at gmail dot com) +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "common.hpp" + +#include +#include + +using boost::system::error_code; +namespace net = boost::asio; +using namespace boost::redis; + +namespace { + +// Move constructing a connection doesn't leave dangling pointers +void test_conn_move_construct() +{ + // Setup + net::io_context ioc; + connection conn_prev(ioc); + connection conn(std::move(conn_prev)); + request req; + req.push("PING", "something"); + response res; + + bool run_finished = false, exec_finished = false; + + // Run the connection + conn.async_run(make_test_config(), [&](error_code ec) { + run_finished = true; + BOOST_TEST_EQ(ec, net::error::operation_aborted); + }); + + // Launch a PING + conn.async_exec(req, res, [&](error_code ec, std::size_t) { + exec_finished = true; + BOOST_TEST_EQ(ec, error_code()); + conn.cancel(); + }); + + ioc.run_for(test_timeout); + + // Check + BOOST_TEST(run_finished); + BOOST_TEST(exec_finished); + BOOST_TEST_EQ(std::get<0>(res).value(), "something"); +} + +// Moving a connection is safe even when it's running, +// and it doesn't leave dangling pointers +void test_conn_move_assign_while_running() +{ + // Setup + net::io_context ioc; + connection conn(ioc); + connection conn2(ioc); // will be assigned to + request req; + req.push("PING", "something"); + response res; + + bool run_finished = false, exec_finished = false; + + // Run the connection + conn.async_run(make_test_config(), [&](error_code ec) { + run_finished = true; + BOOST_TEST_EQ(ec, net::error::operation_aborted); + }); + + // Launch a PING. When it finishes, conn will be moved-from, and conn2 will be valid + conn.async_exec(req, res, [&](error_code ec, std::size_t) { + exec_finished = true; + BOOST_TEST_EQ(ec, error_code()); + conn2.cancel(); + }); + + // While the operations are running, perform a move + net::post(net::bind_executor(ioc.get_executor(), [&] { + conn2 = std::move(conn); + })); + + ioc.run_for(test_timeout); + + // Check + BOOST_TEST(run_finished); + BOOST_TEST(exec_finished); + BOOST_TEST_EQ(std::get<0>(res).value(), "something"); +} + +} // namespace + +int main() +{ + test_conn_move_construct(); + test_conn_move_assign_while_running(); + + return boost::report_errors(); +}