diff --git a/README.md b/README.md index a8778441..bce9f5f7 100644 --- a/README.md +++ b/README.md @@ -9,34 +9,30 @@ The general form of the read and write operations of a redis client that support push notifications and pipelines looks like the following ```cpp -net::awaitable reader(net::ip::tcp::resolver::results_type const& res) +net::awaitable +example(net::ip::tcp::socket& socket, std::queue& pipelines) { - auto ex = co_await net::this_coro::executor; - - net::ip::tcp::socket socket{ex}; - co_await net::async_connect(socket, res, net::use_awaitable); - - std::string read_buffer; - response_buffers buffers; - std::queue pipelines; - pipelines.push({}); pipelines.back().hello("3"); + std::string buffer; + response_buffers buffers; + response_adapters adapters{buffers}; + consumer_state cs; + for (;;) { - co_await async_write_some(socket, pipelines, net::use_awaitable); + auto const type = + co_await async_consume( + socket, buffer, pipelines, adapters, cs, net::use_awaitable); - do { - do { - auto const event = co_await async_read_one(socket, read_buffer, buffers, pipelines); - if (event.second != resp3::type::push) - pipelines.front().commands.pop(); + if (type == resp3::type::push) { + // Push received. + continue; + } - // Your code comes here. + auto const cmd = pipelines.front().commands.front(); - } while (!std::empty(pipelines.front().commands)); - pipelines.pop(); - } while (std::empty(pipelines)); + // Response to a specific command. } } ``` @@ -52,7 +48,12 @@ int main() net::io_context ioc; net::ip::tcp::resolver resolver{ioc}; auto const res = resolver.resolve("127.0.0.1", "6379"); - co_spawn(ioc, reader(res), net::detached); + + net::ip::tcp::socket socket{ioc}; + net::connect(socket, res); + + std::queue pipelines; + co_spawn(ioc, example(socket, pipelines), net::detached); ioc.run(); } ``` diff --git a/examples/async_basic.cpp b/examples/async_basic.cpp index 60fac9f7..945e304f 100644 --- a/examples/async_basic.cpp +++ b/examples/async_basic.cpp @@ -9,25 +9,47 @@ using namespace aedis; -void receiver(command cmd, resp3::type type, std::queue& pipelines) +net::awaitable +example(net::ip::tcp::socket& socket, std::queue& pipelines) { - std::cout << "Event: " << cmd << " (" << type << ")" << std::endl; - switch (cmd) { - case command::hello: - { - prepare_queue(pipelines); - pipelines.back().ping(); - pipelines.back().subscribe("some-channel"); - } break; - case command::publish: break; - case command::quit: break; - case command::ping: - { - prepare_queue(pipelines); - pipelines.back().publish("some-channel", "Some message"); - pipelines.back().quit(); - } break; - default: { } + pipelines.push({}); + pipelines.back().hello("3"); + + std::string buffer; + response_buffers buffers; + response_adapters adapters{buffers}; + consumer_state cs; + + for (;;) { + auto const type = + co_await async_consume( + socket, buffer, pipelines, adapters, cs, net::use_awaitable); + + if (type == resp3::type::push) { + std::cout << "Event: " << "(" << type << ")" << std::endl; + continue; + } + + auto const cmd = pipelines.front().commands.front(); + + std::cout << "Event: " << cmd << " (" << type << ")" << std::endl; + switch (cmd) { + case command::hello: + { + prepare_queue(pipelines); + pipelines.back().ping(); + pipelines.back().subscribe("some-channel"); + } break; + case command::publish: break; + case command::quit: break; + case command::ping: + { + prepare_queue(pipelines); + pipelines.back().publish("some-channel", "Some message"); + pipelines.back().quit(); + } break; + default: { } + } } } @@ -36,17 +58,11 @@ int main() net::io_context ioc; net::ip::tcp::resolver resolver{ioc}; auto const res = resolver.resolve("127.0.0.1", "6379"); + net::ip::tcp::socket socket{ioc}; net::connect(socket, res); - std::string buffer; - response_buffers buffers; std::queue pipelines; - - pipelines.push({}); - pipelines.back().hello("3"); - - auto f = [&](auto cmd, auto type) {receiver(cmd, type, pipelines);}; - co_spawn(ioc, async_read(socket, buffer, buffers, pipelines, f), net::detached); + co_spawn(ioc, example(socket, pipelines), net::detached); ioc.run(); } diff --git a/include/aedis/read.hpp b/include/aedis/read.hpp index bea06db7..4c56bab9 100644 --- a/include/aedis/read.hpp +++ b/include/aedis/read.hpp @@ -7,6 +7,8 @@ #pragma once +#include + #include #include #include @@ -16,6 +18,8 @@ #include #include +#include + namespace aedis { response_adapter_base* select_adapter(response_adapters& buffers, resp3::type t, command cmd); @@ -291,39 +295,84 @@ async_consume( co_return res; } -template -net::awaitable -async_read( +struct consume_op { + net::ip::tcp::socket& socket; + std::string& buffer; + std::queue& pipelines; + response_adapters& adapters; + resp3::type& m_type; + net::coroutine& coro; + + template + void operator()( + Self& self, + boost::system::error_code const& ec = {}, + resp3::type type = resp3::type::invalid) + { + reenter (coro) for (;;) + { + yield async_write_some(socket, pipelines, std::move(self)); + if (ec) { + self.complete(ec, resp3::type::invalid); + return; + } + + do { + do { + yield async_read_type(socket, buffer, std::move(self)); + if (ec) { + self.complete(ec, resp3::type::invalid); + return; + } + + m_type = type; + + yield + { + auto cmd = command::unknown; + if (m_type != resp3::type::push) + cmd = pipelines.front().commands.front(); + + auto* adapter = select_adapter(adapters, m_type, cmd); + async_read_one_impl(socket, buffer, *adapter, std::move(self)); + } + + if (ec) { + self.complete(ec, resp3::type::invalid); + return; + } + + yield self.complete(ec, m_type); + + if (m_type != resp3::type::push) + pipelines.front().commands.pop(); + + } while (!std::empty(pipelines.front().commands)); + pipelines.pop(); + } while (std::empty(pipelines)); + } + } +}; + +struct consumer_state { + net::coroutine coro = net::coroutine(); + resp3::type type = resp3::type::invalid; +}; + +template +auto async_consume( net::ip::tcp::socket& socket, std::string& buffer, - response_buffers& buffers, std::queue& pipelines, - Receiver receiver) + response_adapters& adapters, + consumer_state& cs, + CompletionToken&& token) { - for (;;) { - co_await async_write_some(socket, pipelines, net::use_awaitable); - - do { - do { - response_adapters adapters{buffers}; - - auto const type = - co_await async_read_type(socket, buffer, net::use_awaitable); - - auto cmd = command::unknown; - if (type != resp3::type::push) { - cmd = pipelines.front().commands.front(); - pipelines.front().commands.pop(); - } - - auto* adapter = select_adapter(adapters, type, cmd); - co_await async_read_one_impl(socket, buffer, *adapter, net::use_awaitable); - receiver(cmd, type); - - } while (!std::empty(pipelines.front().commands)); - pipelines.pop(); - } while (std::empty(pipelines)); - } + return net::async_compose< + CompletionToken, + void(boost::system::error_code, resp3::type)>( + consume_op{socket, buffer, pipelines, adapters, cs.type, cs.coro}, token, socket); } } // aedis +#include diff --git a/include/aedis/write.hpp b/include/aedis/write.hpp index 13cf909f..a61c3e2f 100644 --- a/include/aedis/write.hpp +++ b/include/aedis/write.hpp @@ -55,7 +55,6 @@ template struct write_some_op { AsyncWriteStream& stream; std::queue& pipelines; - std::size_t counter = 0; net::coroutine coro = net::coroutine(); void @@ -78,7 +77,6 @@ struct write_some_op { break; pipelines.front().sent = true; - ++counter; if (std::empty(pipelines.front().commands)) { // We only pop when all commands in the pipeline has push @@ -88,7 +86,7 @@ struct write_some_op { } } while (!std::empty(pipelines) && std::empty(pipelines.front().commands)); - self.complete(ec, counter); + self.complete(ec); } } }; @@ -104,7 +102,7 @@ async_write_some( { return net::async_compose< CompletionToken, - void(std::error_code, std::size_t)>( + void(boost::system::error_code)>( write_some_op{stream, pipelines}, token, stream); } diff --git a/tests/general.cpp b/tests/general.cpp index 67ea0e09..ceafc027 100644 --- a/tests/general.cpp +++ b/tests/general.cpp @@ -105,20 +105,21 @@ test_general(net::ip::tcp::resolver::results_type const& res) net::ip::tcp::socket socket{ex}; co_await net::async_connect(socket, res, net::use_awaitable); - std::queue reqs; std::string buffer; - prepare_queue(reqs); - reqs.back().hello("3"); + std::queue pipelines; + pipelines.push({}); + pipelines.back().hello("3"); test_general_fill filler; - co_await async_write(socket, net::buffer(reqs.back().payload), net::use_awaitable); + auto tmp = net::buffer(pipelines.back().payload); + co_await async_write(socket, tmp, net::use_awaitable); int push_counter = 0; response_buffers bufs; for (;;) { - auto const event = co_await async_consume(socket, buffer, bufs, reqs); + auto const event = co_await async_consume(socket, buffer, bufs, pipelines); switch (event.second) { case resp3::type::simple_string: @@ -216,10 +217,10 @@ test_general(net::ip::tcp::resolver::results_type const& res) case command::hgetall: check_equal(bufs.map, {"field1", "value1", "field2", "value2"}, "hgetall (value)"); break; case command::hello: { - auto const empty = prepare_queue(reqs); - filler(reqs.back()); + auto const empty = prepare_queue(pipelines); + filler(pipelines.back()); if (empty) - co_await async_write_some(socket, reqs, net::use_awaitable); + co_await async_write_some(socket, pipelines, net::use_awaitable); } break; default: { std::cout << "Error: " << event.first << " " << event.second << std::endl;