diff --git a/examples/ping.cpp b/examples/ping.cpp index 8c4e4086..c1146b60 100644 --- a/examples/ping.cpp +++ b/examples/ping.cpp @@ -25,10 +25,10 @@ net::awaitable ping() requests.back().push(command::ping); requests.back().push(command::quit); - resp3::stream s; + resp3::stream stream{std::move(socket)}; for (;;) { resp3::response resp; - co_await s.async_consume(socket, requests, resp); + co_await stream.async_consume(requests, resp); std::cout << requests.front().commands.front() << "\n" diff --git a/examples/pubsub.cpp b/examples/pubsub.cpp index 000f6a5a..47eed644 100644 --- a/examples/pubsub.cpp +++ b/examples/pubsub.cpp @@ -22,18 +22,14 @@ net::awaitable publisher() std::queue requests; requests.push({}); requests.back().push(command::hello, 3); + requests.back().push(command::publish, "channel1", "Message to channel1"); + requests.back().push(command::publish, "channel2", "Message to channel2"); + requests.back().push(command::quit); - resp3::stream s; + resp3::stream stream{std::move(socket)}; for (;;) { resp3::response resp; - co_await s.async_consume(socket, requests, resp); - - if (requests.front().commands.front() == command::hello) { - prepare_next(requests); - requests.back().push(command::publish, "channel1", "Message to channel1"); - requests.back().push(command::publish, "channel2", "Message to channel2"); - requests.back().push(command::quit); - } + co_await stream.async_consume(requests, resp); } } @@ -47,21 +43,29 @@ net::awaitable subscriber() std::queue requests; requests.push({}); requests.back().push(command::hello, "3"); + requests.back().push(command::subscribe, "channel1", "channel2"); - resp3::stream s; + resp3::stream stream{std::move(socket)}; for (;;) { resp3::response resp; - co_await s.async_consume(socket, requests, resp); + co_await stream.async_consume(requests, resp); if (resp.get_type() == resp3::type::push) { - std::cout << "Subscriber " << id << ":\n" << resp << std::endl; + std::cout + << "Subscriber " << id << "\n" + << resp << std::endl; continue; } - if (requests.front().commands.front() == command::hello) { - id = resp.raw().at(8).data; - prepare_next(requests); - requests.back().push(command::subscribe, "channel1", "channel2"); + auto const cmd = requests.front().commands.front(); + switch (cmd) { + case command::hello: + id = resp.raw().at(8).data; + break; + default: + std::cout + << cmd << "\n" + << resp << std::endl; } } } diff --git a/include/aedis/resp3/stream.hpp b/include/aedis/resp3/stream.hpp index 5e9a688c..34014e67 100644 --- a/include/aedis/resp3/stream.hpp +++ b/include/aedis/resp3/stream.hpp @@ -19,28 +19,73 @@ namespace resp3 { /** Reads and writes redis commands. */ -struct stream { - std::string buffer; - net::coroutine coro = net::coroutine(); - type t = type::invalid; +template +class stream { +public: + /// The type of the next layer. + using next_layer_type = typename std::remove_reference::type; - template< - class AsyncReadWriteStream, - class CompletionToken = - net::default_completion_token_t - > + /// The type of the executor associated with the object. + using executor_type = typename next_layer_type::executor_type; + +private: + std::string buffer_; + net::coroutine coro_ = net::coroutine(); + type type_ = type::invalid; + next_layer_type next_layer_; + +public: + template + stream(Arg&& arg) + : next_layer_(std::forward(arg)) + { } + + stream(stream&& other) = default; + stream& operator=(stream&& other) = delete; + + /// Get the executor associated with the object. + /** + * This function may be used to obtain the executor object that the stream + * uses to dispatch handlers for asynchronous operations. + * + * @return A copy of the executor that stream will use to dispatch handlers. + */ + executor_type get_executor() const noexcept + { return next_layer_.lowest_layer().get_executor(); } + + /// Get a reference to the next layer. + /** + * This function returns a reference to the next layer in a stack of + * stream layers. + * + * @return A reference to the next layer in the stack of stream + * layers. Ownership is not transferred to the caller. + */ + next_layer_type const& next_layer() const + { return next_layer_; } + + /// Get a reference to the next layer. + /** + * This function returns a reference to the next layer in a stack + * of stream layers. + * + * @return A reference to the next layer in the stack of stream + * layers. Ownership is not transferred to the caller. + */ + next_layer_type& next_layer() + { return next_layer_; } + + template> auto async_consume( - AsyncReadWriteStream& stream, std::queue& requests, response& resp, - CompletionToken&& token = - net::default_completion_token_t{}) + CompletionToken&& token = net::default_completion_token_t{}) { return net::async_compose< - CompletionToken, - void(boost::system::error_code, type)>( - detail::consumer_op{stream, buffer, requests, resp, t, coro}, - token, stream); + CompletionToken, void(boost::system::error_code, type)>( + detail::consumer_op + {next_layer_, buffer_, requests, resp, type_, coro_}, + token, next_layer_); } }; diff --git a/tests/general.cpp b/tests/general.cpp index 77cd83ac..bc931684 100644 --- a/tests/general.cpp +++ b/tests/general.cpp @@ -122,12 +122,12 @@ test_general(net::ip::tcp::resolver::results_type const& res) test_general_fill filler; resp3::response resp; - resp3::stream s; + resp3::stream stream{std::move(socket)}; int push_counter = 0; for (;;) { resp.clear(); - co_await s.async_consume(socket, requests, resp, net::use_awaitable); + co_await stream.async_consume(requests, resp, net::use_awaitable); if (resp.get_type() == resp3::type::push) { switch (push_counter) {