From 1c12be9aca344a591130eeb18ab05804901ac379 Mon Sep 17 00:00:00 2001 From: ruben Date: Sun, 17 Nov 2019 21:01:32 +0000 Subject: [PATCH] Modified execute_query to allow async code reuse --- include/mysql/impl/connection_impl.hpp | 14 ++ include/mysql/impl/query.hpp | 17 +- include/mysql/impl/query_impl.hpp | 303 ++++++++++++++++++++----- 3 files changed, 270 insertions(+), 64 deletions(-) diff --git a/include/mysql/impl/connection_impl.hpp b/include/mysql/impl/connection_impl.hpp index 84d9d3a7..d1f35e78 100644 --- a/include/mysql/impl/connection_impl.hpp +++ b/include/mysql/impl/connection_impl.hpp @@ -66,5 +66,19 @@ mysql::resultset mysql::connection::query( return res; } +template +template +BOOST_ASIO_INITFN_RESULT_TYPE(CompletionToken, void(mysql::error_code, mysql::resultset)) +mysql::connection::async_query( + std::string_view query_string, + CompletionToken&& token +) +{ + return detail::async_execute_query( + query_string, + std::forward(token) + ); +} + #endif /* INCLUDE_MYSQL_IMPL_CONNECTION_IMPL_HPP_ */ diff --git a/include/mysql/impl/query.hpp b/include/mysql/impl/query.hpp index 906b9bff..d01f1bb9 100644 --- a/include/mysql/impl/query.hpp +++ b/include/mysql/impl/query.hpp @@ -10,6 +10,12 @@ namespace mysql namespace detail { +template +using channel_stream_type = typename ChannelType::stream_type; + +template +using channel_resultset_type = resultset, Allocator>; + enum class fetch_result { error, @@ -21,10 +27,19 @@ template void execute_query( ChannelType& channel, std::string_view query, - resultset& output, + channel_resultset_type& output, error_code& err ); +template +BOOST_ASIO_INITFN_RESULT_TYPE(CompletionToken, void(error_code, channel_resultset_type)) +async_execute_query( + ChannelType& channel, + std::string_view query, + CompletionToken&& token +); + + template fetch_result fetch_text_row( ChannelType& channel, diff --git a/include/mysql/impl/query_impl.hpp b/include/mysql/impl/query_impl.hpp index f57298be..3424229d 100644 --- a/include/mysql/impl/query_impl.hpp +++ b/include/mysql/impl/query_impl.hpp @@ -4,17 +4,114 @@ #include "mysql/impl/messages.hpp" #include "mysql/impl/basic_serialization.hpp" #include "mysql/impl/deserialize_row.hpp" +#include namespace mysql { namespace detail { -template -using channel_stream_type = typename ChannelType::stream_type; -} -} +template +class query_processor +{ + ChannelType& channel_; + bytestring buffer_; + std::vector fields_; + std::vector> field_buffers_; +public: + query_processor(ChannelType& channel): channel_(channel) {}; + void process_query_request( + std::string_view query + ) + { + // Compose a com_query message + msgs::com_query query_msg; + query_msg.query.value = query; + + // Serialize it + capabilities caps = channel_.current_capabilities(); + serialize_message(query_msg, caps, buffer_); + + // Prepare the channel + channel_.reset_sequence_number(); + } + + std::optional // ok, err, or a number of fields + process_query_response( + channel_resultset_type& output, + error_code& err + ) + { + // Response may be: ok_packet, err_packet, local infile request (TODO) + // If it is none of this, then the message type itself is the beginning of + // a length-encoded int containing the field count + DeserializationContext ctx (boost::asio::buffer(buffer_), channel_.current_capabilities()); + std::uint8_t msg_type; + err = deserialize_message_type(msg_type, ctx); + if (err) return {}; + if (msg_type == ok_packet_header) + { + msgs::ok_packet ok_packet; + err = deserialize_message(ok_packet, ctx); + if (err) return {}; + output = channel_resultset_type(channel_, std::move(buffer_), ok_packet); + err.clear(); + return {}; + } + else if (msg_type == error_packet_header) + { + err = process_error_packet(ctx); + return {}; + } + else + { + // Resultset with metadata. First packet is an int_lenenc with + // the number of field definitions to expect. Message type is part + // of this packet, so we must rewind the context + ctx.rewind(1); + int_lenenc num_fields; + err = deserialize_message(num_fields, ctx); + if (err) return {}; + + fields_.reserve(num_fields.value); + field_buffers_.reserve(num_fields.value); + + return num_fields.value; + } + } + + error_code process_field_definition() + { + msgs::column_definition field_definition; + DeserializationContext ctx (boost::asio::buffer(buffer_), channel_.current_capabilities()); + auto err = deserialize_message(field_definition, ctx); + if (err) return err; + + // Add it to our array + fields_.push_back(field_definition); + field_buffers_.push_back(std::move(buffer_)); + buffer_ = bytestring(); + + return error_code(); + } + + void create_resultset( + channel_resultset_type& output + ) && + { + output = channel_resultset_type( + channel_, + resultset_metadata(std::move(field_buffers_), std::move(fields_)) + ); + } + + auto& channel() { return channel_; } + auto& buffer() { return buffer_; } +}; + +} // detail +} // mysql template void mysql::detail::execute_query( @@ -24,86 +121,166 @@ void mysql::detail::execute_query( error_code& err ) { - // Compose a com_query message - msgs::com_query query_msg; - query_msg.query.value = query; - - // Serialize it - capabilities caps = channel.current_capabilities(); - bytestring buffer; - serialize_message(query_msg, caps, buffer); + // Compose a com_query message, reset seq num + query_processor processor (channel); + processor.process_query_request(query); // Send it - channel.reset_sequence_number(); - channel.write(boost::asio::buffer(buffer), err); + channel.write(boost::asio::buffer(processor.buffer()), err); if (err) return; // Read the response - channel.read(buffer, err); + channel.read(processor.buffer(), err); if (err) return; - // Response may be: ok_packet, err_packet, local infile request (TODO) - // If it is none of this, then the message type itself is the beginning of - // a length-encoded int containing the field count - DeserializationContext ctx (boost::asio::buffer(buffer), caps); - std::uint8_t msg_type; - err = deserialize_message_type(msg_type, ctx); - if (err) return; - if (msg_type == ok_packet_header) + // Response may be: ok_packet, err_packet, local infile request (TODO), or response with fields + auto num_fields = processor.process_query_response(output, err); + if (!num_fields) // ok or err { - msgs::ok_packet ok_packet; - err = deserialize_message(ok_packet, ctx); - if (err) return; - output = resultset, Allocator>(channel, std::move(buffer), ok_packet); - err.clear(); - return; - } - else if (msg_type == error_packet_header) - { - err = process_error_packet(ctx); return; } - // Resultset with metadata. First packet is an int_lenenc with - // the number of field definitions to expect. Message type is part - // of this packet, so we must rewind the context - ctx.set_first(buffer.data()); - int_lenenc num_fields; - err = deserialize_message(num_fields, ctx); - if (err) return; - - std::vector fields; - std::vector> field_buffers; - fields.reserve(num_fields.value); - field_buffers.reserve(num_fields.value); - - // Read all of the field definitions - for (std::uint64_t i = 0; i < num_fields.value; ++i) + // We have a response with metadata, read all of the field definitions + for (std::uint64_t i = 0; i < *num_fields; ++i) { // Read the field definition packet - bytestring field_definition_buffer; - channel.read(field_definition_buffer, err); + channel.read(processor.buffer(), err); if (err) return; - // Deserialize the message - msgs::column_definition field_definition; - ctx = DeserializationContext(boost::asio::buffer(field_definition_buffer), caps); - err = deserialize_message(field_definition, ctx); + // Process the message + err = processor.process_field_definition(); if (err) return; - - // Add it to our array - fields.push_back(field_definition); - field_buffers.push_back(std::move(field_definition_buffer)); } // No EOF packet is expected here, as we require deprecate EOF capabilities - output = resultset, Allocator>( - channel, - resultset_metadata(std::move(field_buffers), std::move(fields)) - ); + std::move(processor).create_resultset(output); err.clear(); } + +template +BOOST_ASIO_INITFN_RESULT_TYPE( + CompletionToken, + void(mysql::error_code, mysql::detail::channel_resultset_type) +) +mysql::detail::async_execute_query( + ChannelType& channel, + std::string_view query, + CompletionToken&& token +) +{ + using HandlerSignature = void(error_code, channel_resultset_type); + using HandlerType = BOOST_ASIO_HANDLER_TYPE(CompletionToken, HandlerSignature); + using StreamType = typename ChannelType::stream_type; + using BaseType = boost::beast::async_base; + using ResultsetType = channel_resultset_type; + + boost::asio::async_completion initiator(token); + + struct Op: BaseType, boost::asio::coroutine + { + query_processor processor_; + + Op( + HandlerType&& handler, + ChannelType& channel, + std::string_view query + ): + BaseType(std::move(handler), channel.next_layer().get_executor()), + processor_(channel) + { + processor_.process_query_request(query); + } + + std::optional process_query_response(bool cont) + { + ResultsetType resultset; + error_code err; + auto num_fields = processor_.process_query_response(resultset, err); + if (!num_fields) // ok or err + { + complete(cont, err, std::move(resultset)); + } + return num_fields; + } + + void complete_with_fields(bool cont) + { + ResultsetType resultset; + std::move(processor_).create_resultset(resultset); + complete(cont, error_code(), std::move(resultset)); + } + + void operator()( + error_code err, + bool cont=true + ) + { + std::optional num_fields; + reenter(*this) + { + // The request message has already been composed in the ctor. Send it + yield processor_.channel().async_write( + boost::asio::buffer(processor_.buffer()), + std::move(*this) + ); + if (err) + { + complete(cont, err, ResultsetType()); + yield break; + } + + // Read the response + yield processor_.channel().read(processor_.buffer(), std::move(*this)); + if (err) + { + complete(cont, err, ResultsetType()); + yield break; + } + + // Response may be: ok_packet, err_packet, local infile request (TODO), or response with fields + num_fields = process_query_response(cont); + if (!num_fields) + { + // Not a response with fields. complete() already called + yield break; + } + + // Read all of the field definitions + for (std::uint64_t i = 0; i < *num_fields; ++i) + { + // Read the field definition packet + yield processor_.channel().async_read(processor_.buffer(), std::move(*this)); + if (!err) + { + // Process the message + err = processor_.process_field_definition(); + } + + if (err) + { + complete(cont, err, ResultsetType()); + yield break; + } + } + + // No EOF packet is expected here, as we require deprecate EOF capabilities + complete_with_fields(cont); + yield break; + } + } + }; + + Op( + std::move(initiator.completion_handler), + channel, + query + )(error_code(), false); + return initiator.result.get(); +} + + + template mysql::detail::fetch_result mysql::detail::fetch_text_row( ChannelType& channel, @@ -146,7 +323,7 @@ mysql::detail::fetch_result mysql::detail::fetch_text_row( } } - +#include #endif /* INCLUDE_MYSQL_IMPL_QUERY_IMPL_HPP_ */