2
0
mirror of https://github.com/boostorg/mysql.git synced 2026-02-15 01:02:17 +00:00

message_reader impl cleanup

This commit is contained in:
Ruben Perez
2022-08-25 02:52:33 +02:00
parent 237ae21ef2
commit cb51442efa
2 changed files with 314 additions and 262 deletions

View File

@@ -0,0 +1,281 @@
//
// Copyright (c) 2019-2022 Ruben Perez Hidalgo (rubenperez038 at gmail dot com)
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#ifndef BOOST_MYSQL_DETAIL_CHANNEL_IMPL_MESSAGE_READER_HPP
#define BOOST_MYSQL_DETAIL_CHANNEL_IMPL_MESSAGE_READER_HPP
#pragma once
#include <boost/mysql/detail/channel/message_reader.hpp>
boost::asio::const_buffer boost::mysql::detail::message_reader::get_next_message(
std::uint8_t& seqnum,
error_code& ec
) noexcept
{
assert(has_message());
if (seqnum != result_.message.seqnum_first)
{
ec = make_error_code(errc::sequence_number_mismatch);
return {};
}
seqnum = result_.message.seqnum_last + 1;
auto res = buffer_.current_message();
buffer_.move_to_reserved(buffer_.current_message_size());
error_code next_ec = process_message();
if (next_ec)
{
result_.type = result_type::error;
result_.next_error = next_ec;
}
return res;
}
template <class Stream>
void boost::mysql::detail::message_reader::read_some(
Stream& stream,
error_code& ec
)
{
// If we already have a message, complete immediately
if (has_message())
return;
// Remove processed messages if we can
maybe_remove_processed_messages();
while (!has_message())
{
// If any previous process_message indicated that we need more
// buffer space, resize the buffer now
maybe_resize_buffer();
// Actually read bytes
std::size_t bytes_read = stream.read_some(buffer_.free_area(), ec);
if (ec) break;
valgrind_make_mem_defined(buffer_.free_first(), bytes_read);
// Process them
ec = on_read_bytes(bytes_read);
if (ec) break;
}
}
template <class Stream>
struct boost::mysql::detail::message_reader::read_op : boost::asio::coroutine
{
message_reader& reader_;
Stream& stream_;
read_op(message_reader& reader, Stream& stream) noexcept :
reader_(reader),
stream_(stream)
{
}
template<class Self>
void operator()(
Self& self,
error_code ec = {},
std::size_t bytes_read = 0
)
{
// Error handling
if (ec)
{
self.complete(ec);
return;
}
// Non-error path
BOOST_ASIO_CORO_REENTER(*this)
{
// If we already have a message, complete immediately
if (reader_.has_message())
{
BOOST_ASIO_CORO_YIELD boost::asio::post(std::move(self));
self.complete(error_code());
BOOST_ASIO_CORO_YIELD break;
}
// Remove processed messages if we can
reader_.maybe_remove_processed_messages();
while (!reader_.has_message())
{
// If any previous process_message indicated that we need more
// buffer space, resize the buffer now
reader_.maybe_resize_buffer();
// Actually read bytes
BOOST_ASIO_CORO_YIELD stream_.async_read_some(
reader_.buffer_.free_area(),
std::move(*this)
);
valgrind_make_mem_defined(reader_.buffer_.free_first(), bytes_read);
// Process them
ec = reader_.on_read_bytes(bytes_read);
if (ec)
{
self.complete(ec);
BOOST_ASIO_CORO_YIELD break;
}
}
self.complete(error_code());
}
}
};
template <class Stream, class CompletionToken>
BOOST_ASIO_INITFN_AUTO_RESULT_TYPE(CompletionToken, void(::boost::mysql::error_code))
boost::mysql::detail::message_reader::async_read_some(
Stream& stream,
CompletionToken&& token
)
{
return boost::asio::async_compose<CompletionToken, void(error_code)>(
read_op(*this, stream),
token,
stream
);
}
boost::mysql::error_code boost::mysql::detail::message_reader::process_message()
{
// If we have a message, the caller has already read the previous message
// and we need to parse another one. Reset the state
// If the last operation indicated an error, clear it in hope we can recover from it
if (!result_.is_state())
{
result_ = result_t();
}
while (true)
{
if (result_.state.reading_header)
{
// If there are not enough bytes to process a header, request more
if (buffer_.pending_size() < HEADER_SIZE)
{
result_.state.grow_buffer_to_fit = HEADER_SIZE;
return error_code();
}
// Mark the header as belonging to the current message
buffer_.move_to_current_message(HEADER_SIZE);
// Deserialize the header
packet_header header;
deserialization_context ctx (
boost::asio::buffer(buffer_.pending_first(), HEADER_SIZE),
capabilities(0) // unaffected by capabilities
);
errc err = deserialize(ctx, header);
if (err != errc::ok)
{
return make_error_code(err);
}
// Process the sequence number
if (result_.state.is_first_frame)
{
result_.state.is_first_frame = false;
result_.state.first_seqnum = header.sequence_number;
result_.state.last_seqnum = header.sequence_number;
}
else
{
std::uint8_t expected_seqnum = result_.state.last_seqnum + 1;
if (header.sequence_number != expected_seqnum)
{
return make_error_code(errc::sequence_number_mismatch);
}
result_.state.last_seqnum = expected_seqnum;
}
// Process the packet size
result_.state.remaining_bytes = header.packet_size.value;
result_.state.more_frames_follow = (result_.state.remaining_bytes == MAX_PACKET_SIZE);
// We are done with the header
if (result_.state.is_first_frame)
{
// If it's the 1st frame, we can just move the header bytes to the reserved
// area, avoiding a big memmove
buffer_.move_to_reserved(HEADER_SIZE);
}
else
{
buffer_.remove_current_message_last(HEADER_SIZE);
}
result_.state.reading_header = false;
}
if (!result_.state.reading_header)
{
// Get the number of bytes belonging to this message
std::size_t new_bytes = std::min(buffer_.pending_size(), result_.state.remaining_bytes);
// Mark them as belonging to the current message in the buffer
buffer_.move_to_current_message(new_bytes);
// Update remaining bytes
result_.state.remaining_bytes -= new_bytes;
if (result_.state.remaining_bytes == 0)
{
result_.state.reading_header = true;
}
else
{
result_.state.grow_buffer_to_fit = result_.state.remaining_bytes;
return error_code();
}
// If we've fully read a message, we're done
if (!result_.state.remaining_bytes && !result_.state.more_frames_follow)
{
result_.type = result_type::message;
result_.message = message_t {
result_.state.first_seqnum,
result_.state.last_seqnum
};
return error_code();
}
}
}
}
void boost::mysql::detail::message_reader::maybe_remove_processed_messages()
{
if (!keep_messages_)
{
buffer_.remove_reserved();
}
}
void boost::mysql::detail::message_reader::maybe_resize_buffer()
{
if (result_.is_state() && result_.state.grow_buffer_to_fit != 0)
{
buffer_.grow_to_fit(result_.state.grow_buffer_to_fit);
result_.state.grow_buffer_to_fit = 0;
}
}
boost::mysql::error_code boost::mysql::detail::message_reader::on_read_bytes(size_t num_bytes)
{
buffer_.move_to_pending(num_bytes);
return process_message();
}
#endif

View File

@@ -18,6 +18,7 @@
#include <boost/mysql/detail/auxiliar/valgrind.hpp>
#include <boost/mysql/detail/protocol/common_messages.hpp>
#include <boost/mysql/detail/protocol/constants.hpp>
#include <cassert>
#include <cstddef>
#include <cstdint>
@@ -29,145 +30,31 @@ namespace detail {
class message_reader
{
public:
message_reader(std::size_t initial_buffer_size) :
buffer_(initial_buffer_size)
{
}
message_reader(std::size_t initial_buffer_size) : buffer_(initial_buffer_size) {}
bool keep_messages() const noexcept { return keep_messages_; }
void set_keep_messages(bool v) noexcept { keep_messages_ = v; }
bool has_message() const noexcept { return result_type_ == result_type::message; }
bool has_message() const noexcept { return result_.is_message(); }
const std::uint8_t* buffer_first() const noexcept { return buffer_.reserved_first(); }
boost::asio::const_buffer get_next_message(std::uint8_t& seqnum, error_code& ec) noexcept
{
assert(has_message());
if (seqnum != message_.seqnum_first)
{
ec = make_error_code(errc::sequence_number_mismatch);
return {};
}
seqnum = message_.seqnum_last + 1;
auto res = buffer_.current_message();
buffer_.move_to_reserved(buffer_.current_message_size());
error_code next_ec = process_message();
if (next_ec)
{
result_type_ = result_type::error;
next_error_ = next_ec;
}
return res;
}
inline boost::asio::const_buffer get_next_message(std::uint8_t& seqnum, error_code& ec) noexcept;
// Reads some messages from stream, until there is at least one
template <class Stream>
void read_some(Stream& stream, error_code& ec)
{
// If we already have a message, complete immediately
if (has_message())
return;
// Remove processed messages if we can
maybe_remove_processed_messages();
while (!has_message())
{
// If any previous process_message indicated that we need more
// buffer space, resize the buffer now
maybe_resize_buffer();
// Actually read bytes
std::size_t bytes_read = stream.read_some(buffer_.free_area(), ec);
if (ec) break;
valgrind_make_mem_defined(buffer_.free_first(), bytes_read);
// Process them
ec = on_read_bytes(bytes_read);
if (ec) break;
}
}
void read_some(Stream& stream, error_code& ec);
template <class Stream, class CompletionToken>
BOOST_ASIO_INITFN_AUTO_RESULT_TYPE(CompletionToken, void(error_code))
async_read_some(Stream& stream, CompletionToken&& token)
{
return boost::asio::async_compose<CompletionToken, void(error_code)>(
read_op(*this, stream),
token,
stream
);
}
async_read_some(Stream& stream, CompletionToken&& token);
// Exposed for the sake of testing
read_buffer& buffer() noexcept { return buffer_; }
inline error_code process_message();
std::size_t grow_buffer_to_fit() const noexcept { assert(result_.is_state()); return result_.state.grow_buffer_to_fit; }
private:
template <class Stream>
struct read_op : boost::asio::coroutine
{
message_reader& reader_;
Stream& stream_;
read_op(message_reader& reader, Stream& stream) noexcept :
reader_(reader),
stream_(stream)
{
}
template<class Self>
void operator()(
Self& self,
error_code ec = {},
std::size_t bytes_read = 0
)
{
// Error handling
if (ec)
{
self.complete(ec);
return;
}
// Non-error path
BOOST_ASIO_CORO_REENTER(*this)
{
// If we already have a message, complete immediately
if (reader_.has_message())
{
BOOST_ASIO_CORO_YIELD boost::asio::post(std::move(self));
self.complete(error_code());
BOOST_ASIO_CORO_YIELD break;
}
// Remove processed messages if we can
reader_.maybe_remove_processed_messages();
while (!reader_.has_message())
{
// If any previous process_message indicated that we need more
// buffer space, resize the buffer now
reader_.maybe_resize_buffer();
// Actually read bytes
BOOST_ASIO_CORO_YIELD stream_.async_read_some(
reader_.buffer_.free_area(),
std::move(*this)
);
valgrind_make_mem_defined(reader_.buffer_.free_first(), bytes_read);
// Process them
ec = reader_.on_read_bytes(bytes_read);
if (ec)
{
self.complete(ec);
BOOST_ASIO_CORO_YIELD break;
}
}
self.complete(error_code());
}
}
};
struct read_op;
enum class result_type
{
@@ -184,154 +71,37 @@ private:
struct state_t
{
bool is_first_frame_ {true};
std::uint8_t first_seqnum_ {};
std::uint8_t last_seqnum_ {};
bool reading_header_ {false};
std::size_t remaining_bytes_ {0};
bool more_frames_follow_ {false};
std::size_t grow_buffer_to_fit_ {};
bool is_first_frame {true};
std::uint8_t first_seqnum {};
std::uint8_t last_seqnum {};
bool reading_header {false};
std::size_t remaining_bytes {0};
bool more_frames_follow {false};
std::size_t grow_buffer_to_fit {};
};
read_buffer buffer_;
bool keep_messages_ {false};
// Union-like; the result produced by process_message may be
// either a state (representing an incomplete message), a message
// (representing a complete message) or an error_code
result_type result_type_ {result_type::state};
state_t state_;
message_t message_;
error_code next_error_;
error_code process_message()
struct result_t
{
// If we have a message, the caller has already read the previous message
// and we need to parse another one. Reset the state
// If the last operation indicated an error, clear it in hope we can recover from it
if (result_type_ != result_type::state)
{
result_type_ = result_type::state;
state_ = state_t();
}
result_type type { result_type::state };
state_t state;
message_t message;
error_code next_error;
while (true)
{
if (state_.reading_header_)
{
// If there are not enough bytes to process a header, request more
if (buffer_.pending_size() < HEADER_SIZE)
{
state_.grow_buffer_to_fit_ = HEADER_SIZE;
return error_code();
}
bool is_state() const noexcept { return type == result_type::state; }
bool is_message() const noexcept { return type == result_type::message; }
};
// Mark the header as belonging to the current message
buffer_.move_to_current_message(HEADER_SIZE);
read_buffer buffer_;
bool keep_messages_ {false};
result_t result_;
// Deserialize the header
packet_header header;
deserialization_context ctx (
boost::asio::buffer(buffer_.pending_first(), HEADER_SIZE),
capabilities(0) // unaffected by capabilities
);
errc err = deserialize(ctx, header);
if (err != errc::ok)
{
return make_error_code(err);
}
// Process the sequence number
if (state_.is_first_frame_)
{
state_.is_first_frame_ = false;
state_.first_seqnum_ = header.sequence_number;
state_.last_seqnum_ = header.sequence_number;
}
else
{
std::uint8_t expected_seqnum = state_.last_seqnum_ + 1;
if (header.sequence_number != expected_seqnum)
{
return make_error_code(errc::sequence_number_mismatch);
}
state_.last_seqnum_ = expected_seqnum;
}
// Process the packet size
state_.remaining_bytes_ = header.packet_size.value;
state_.more_frames_follow_ = (state_.remaining_bytes_ == MAX_PACKET_SIZE);
// We are done with the header
if (state_.is_first_frame_)
{
// If it's the 1st frame, we can just move the header bytes to the reserved
// area, avoiding a big memmove
buffer_.move_to_reserved(HEADER_SIZE);
}
else
{
buffer_.remove_current_message_last(HEADER_SIZE);
}
state_.reading_header_ = false;
}
if (!state_.reading_header_)
{
// Get the number of bytes belonging to this message
std::size_t new_bytes = std::min(buffer_.pending_size(), state_.remaining_bytes_);
// Mark them as belonging to the current message in the buffer
buffer_.move_to_current_message(new_bytes);
// Update remaining bytes
state_.remaining_bytes_ -= new_bytes;
if (state_.remaining_bytes_ == 0)
{
state_.reading_header_ = true;
}
else
{
state_.grow_buffer_to_fit_ = state_.remaining_bytes_;
return error_code();
}
// If we've fully read a message, we're done
if (!state_.remaining_bytes_ && !state_.more_frames_follow_)
{
result_type_ = result_type::message;
message_ = message_t {
state_.first_seqnum_,
state_.last_seqnum_
};
return error_code();
}
}
}
}
void maybe_remove_processed_messages()
{
if (!keep_messages_)
{
buffer_.remove_reserved();
}
}
void maybe_resize_buffer()
{
if (result_type_ == result_type::state && state_.grow_buffer_to_fit_ != 0)
{
buffer_.grow_to_fit(state_.grow_buffer_to_fit_);
state_.grow_buffer_to_fit_ = 0;
}
}
error_code on_read_bytes(size_t num_bytes)
{
buffer_.move_to_pending(num_bytes);
return process_message();
}
inline void maybe_remove_processed_messages();
inline void maybe_resize_buffer();
inline error_code on_read_bytes(size_t num_bytes);
};
@@ -339,6 +109,7 @@ private:
} // mysql
} // boost
#include <boost/mysql/detail/channel/message_reader.hpp>
#endif /* INCLUDE_BOOST_MYSQL_DETAIL_AUXILIAR_STATIC_STRING_HPP_ */