From f6fffb80daab432c0b76307a92b44bc3e47c29ba Mon Sep 17 00:00:00 2001 From: ruben Date: Tue, 10 Sep 2019 20:26:47 +0100 Subject: [PATCH] Made MysqlStream, PreparedStatement be templates Made MysqlStream::write a template accepting any ConstBufferSequence --- CMakeLists.txt | 2 - include/basic_serialization.hpp | 1 + .../impl/mysql_stream_impl.hpp | 97 ++++-- include/impl/prepared_statement_impl.hpp | 327 ++++++++++++++++++ include/mysql_stream.hpp | 14 +- include/prepared_statement.hpp | 76 +--- main.cpp | 10 +- src/prepared_statement.cpp | 254 -------------- 8 files changed, 419 insertions(+), 362 deletions(-) rename src/mysql_stream.cpp => include/impl/mysql_stream_impl.hpp (61%) create mode 100644 include/impl/prepared_statement_impl.hpp delete mode 100644 src/prepared_statement.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 8e8960cb..74672eb9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,9 +13,7 @@ add_library( SHARED src/basic_serialization.cpp src/message_serialization.cpp - src/mysql_stream.cpp src/auth.cpp - src/prepared_statement.cpp ) target_link_libraries( mysql_asio diff --git a/include/basic_serialization.hpp b/include/basic_serialization.hpp index d637a96e..8ed36885 100644 --- a/include/basic_serialization.hpp +++ b/include/basic_serialization.hpp @@ -118,6 +118,7 @@ public: void* data() { return buffer_.data(); } std::size_t size() const { return buffer_.size(); } const std::vector& get() const { return buffer_; } + void clear() { buffer_.clear(); } }; template void native_to_little(T& value) { boost::endian::native_to_little_inplace(value); } diff --git a/src/mysql_stream.cpp b/include/impl/mysql_stream_impl.hpp similarity index 61% rename from src/mysql_stream.cpp rename to include/impl/mysql_stream_impl.hpp index e7c3905f..07bfc859 100644 --- a/src/mysql_stream.cpp +++ b/include/impl/mysql_stream_impl.hpp @@ -1,11 +1,6 @@ -/* - * mysql_stream.cpp - * - * Created on: Jul 7, 2019 - * Author: ruben - */ +#ifndef INCLUDE_IMPL_MYSQL_STREAM_IMPL_HPP_ +#define INCLUDE_IMPL_MYSQL_STREAM_IMPL_HPP_ -#include "mysql_stream.hpp" #include "message_serialization.hpp" #include "auth.hpp" #include @@ -14,7 +9,10 @@ #include #include -using namespace std; +namespace mysql +{ +namespace detail +{ template constexpr bool all_set(T1 input, Flags... flags) @@ -22,7 +20,7 @@ constexpr bool all_set(T1 input, Flags... flags) return ((input & flags) && ...); } -static void check_capabilities(mysql::int4 server_capabilities) +inline void check_capabilities(mysql::int4 server_capabilities) { bool ok = all_set(server_capabilities, mysql::CLIENT_CONNECT_WITH_DB, @@ -35,7 +33,7 @@ static void check_capabilities(mysql::int4 server_capabilities) throw std::runtime_error { "Missing server capabilities, server not supported" }; } -static void check_authentication_method(const mysql::Handshake& handshake) +inline void check_authentication_method(const mysql::Handshake& handshake) { if (handshake.auth_plugin_name.value != "mysql_native_password") throw std::runtime_error { "Unsupported authentication method" }; // TODO: we should be responding with our method @@ -50,7 +48,11 @@ constexpr mysql::int4 BASIC_CAPABILITIES_FLAGS = mysql::CLIENT_DEPRECATE_EOF | mysql::CLIENT_CONNECT_WITH_DB; -mysql::int1 mysql::get_message_type(const std::vector& buffer, bool check_err) +} +} + + +inline mysql::int1 mysql::get_message_type(const std::vector& buffer, bool check_err) { mysql::int1 res; ReadIterator current = mysql::deserialize(buffer, res); @@ -66,14 +68,15 @@ mysql::int1 mysql::get_message_type(const std::vector& buffer, boo return res; } - -void mysql::MysqlStream::process_sequence_number(int1 got) +template +void mysql::MysqlStream::process_sequence_number(int1 got) { if (got != sequence_number_++) throw std::runtime_error { "Mismatched sequence number" }; } -void mysql::MysqlStream::read(std::vector& buffer) +template +void mysql::MysqlStream::read(std::vector& buffer) { PacketHeader header; uint8_t header_buffer [4]; @@ -89,27 +92,51 @@ void mysql::MysqlStream::read(std::vector& buffer) } while (header.packet_size.value == 0xffffff); } -void mysql::MysqlStream::write(const std::vector& buffer) +template +void mysql::MysqlStream::write(const std::vector& buffer) +{ + write(boost::asio::buffer(buffer.data(), buffer.size())); +} + +template +template +void mysql::MysqlStream::write(ConstBufferSequence&& buffers) { PacketHeader header; - DynamicBuffer header_buffer; // TODO: change to a plain uint8_t when we generalize this + DynamicBuffer header_buffer; // TODO: change to a plain uint8_t when we generalize serialization std::size_t current_size = 0; - while (current_size < buffer.size()) + constexpr std::size_t MAX_PACKET_SIZE = 0xffffff; + + auto first = boost::asio::buffer_sequence_begin(buffers); + auto last = boost::asio::buffer_sequence_end(buffers); + + // TODO: we can do this better - for a multi-element + // buffer sequence, we could merge some of the data into + // a single packet + for (auto it = first; it != last; ++it) { - auto size_to_write = static_cast(std::min( - std::vector::size_type(0xffffff), - buffer.size() - current_size - )); - header.packet_size.value = size_to_write; - header.sequence_number = sequence_number_++; - serialize(header_buffer, header); - boost::asio::write(next_layer_, boost::asio::buffer(header_buffer.data(), header_buffer.size())); - boost::asio::write(next_layer_, boost::asio::buffer(buffer.data() + current_size, size_to_write)); - current_size += size_to_write; + current_size = 0; + auto bufsize = it->size(); + while (current_size < bufsize) + { + auto size_to_write = static_cast(std::min( + MAX_PACKET_SIZE, + bufsize - current_size + )); + header.packet_size.value = size_to_write; + header.sequence_number = sequence_number_++; + header_buffer.clear(); + serialize(header_buffer, header); + // TODO: we could use a buffer sequence to write these two + boost::asio::write(next_layer_, boost::asio::buffer(header_buffer.data(), header_buffer.size())); + boost::asio::write(next_layer_, boost::asio::buffer(*it + current_size, size_to_write)); + current_size += size_to_write; + } } } -void mysql::MysqlStream::handshake(const HandshakeParams& params) +template +void mysql::MysqlStream::handshake(const HandshakeParams& params) { std::vector read_buffer; DynamicBuffer write_buffer; @@ -128,22 +155,22 @@ void mysql::MysqlStream::handshake(const HandshakeParams& params) deserialize(read_buffer.data()+1, read_buffer.data() + read_buffer.size(), handshake); // Process the handshake - check_capabilities(handshake.capability_falgs); - check_authentication_method(handshake); - cout << handshake << "\n\n"; + detail::check_capabilities(handshake.capability_falgs); + detail::check_authentication_method(handshake); + std::cout << handshake << "\n\n"; // Response mysql::HandshakeResponse handshake_response; mysql_native_password::response_buffer auth_response; mysql_native_password::compute_auth_string(params.password, handshake.auth_plugin_data.data(), auth_response); - handshake_response.client_flag = BASIC_CAPABILITIES_FLAGS; + handshake_response.client_flag = detail::BASIC_CAPABILITIES_FLAGS; handshake_response.max_packet_size = 0xffff; handshake_response.character_set = params.character_set; handshake_response.username.value = params.username; - handshake_response.auth_response.value = string_view {(const char*)auth_response, sizeof(auth_response)}; + handshake_response.auth_response.value = std::string_view {(const char*)auth_response, sizeof(auth_response)}; handshake_response.client_plugin_name.value = "mysql_native_password"; handshake_response.database.value = params.database; - cout << handshake_response << "\n\n"; + std::cout << handshake_response << "\n\n"; // Serialize and send serialize(write_buffer, handshake_response); @@ -162,4 +189,4 @@ void mysql::MysqlStream::handshake(const HandshakeParams& params) std::cout << "Connected to server\n"; } - +#endif diff --git a/include/impl/prepared_statement_impl.hpp b/include/impl/prepared_statement_impl.hpp new file mode 100644 index 00000000..2fbe90f7 --- /dev/null +++ b/include/impl/prepared_statement_impl.hpp @@ -0,0 +1,327 @@ +#ifndef INCLUDE_IMPL_PREPARED_STATEMENT_IMPL_HPP_ +#define INCLUDE_IMPL_PREPARED_STATEMENT_IMPL_HPP_ + +#include "message_serialization.hpp" +#include "mysql_stream.hpp" +#include "null_bitmap.hpp" +#include + +namespace mysql +{ +namespace detail +{ + +inline void fill_execute_msg_impl(std::vector::iterator) {} + +template +void fill_execute_msg_impl( + std::vector::iterator param_output, + Param0&& param, + Params&&... tail +) +{ + *param_output = std::forward(param); + fill_execute_msg_impl(std::next(param_output), std::forward(tail)...); +} + + +template +void fill_execute_msg(StmtExecute& output, std::size_t num_params, Args&&... args) +{ + if (sizeof...(args) != num_params) + { + throw std::out_of_range {"Wrong number of parameters passed to prepared statement"}; + } + output.num_params = static_cast(num_params); + output.new_params_bind_flag = 1; + output.param_values.resize(num_params); + fill_execute_msg_impl(output.param_values.begin(), std::forward(args)...); +} + +template +std::vector read_fields( + MysqlStream& stream, + std::size_t quantity +) +{ + std::vector res (quantity); + for (auto& elm: res) + { + stream.read(elm.packet); + deserialize(elm.packet, elm.value); + } + return res; +} + +template +BinaryValue deserialize_field_impl( + ReadIterator& first, + ReadIterator last +) +{ + T value; + first = deserialize(first, last, value); + return BinaryValue {value}; +} + +inline BinaryValue not_implemented() +{ + throw std::runtime_error{"Not implemented"}; +} + +inline BinaryValue deserialize_field( + FieldType type, + ReadIterator& first, + ReadIterator last +) +{ + switch (type) + { + case FieldType::DECIMAL: + case FieldType::VARCHAR: + case FieldType::BIT: + case FieldType::NEWDECIMAL: + case FieldType::ENUM: + case FieldType::SET: + case FieldType::TINY_BLOB: + case FieldType::MEDIUM_BLOB: + case FieldType::LONG_BLOB: + case FieldType::BLOB: + case FieldType::VAR_STRING: + case FieldType::STRING: + case FieldType::GEOMETRY: + return deserialize_field_impl(first, last); + case FieldType::TINY: + return deserialize_field_impl(first, last); + case FieldType::SHORT: + return deserialize_field_impl(first, last); + case FieldType::INT24: + case FieldType::LONG: + return deserialize_field_impl(first, last); + case FieldType::LONGLONG: + return deserialize_field_impl(first, last); + case FieldType::FLOAT: + return not_implemented(); + case FieldType::DOUBLE: + return not_implemented(); + case FieldType::NULL_: + return deserialize_field_impl(first, last); + case FieldType::TIMESTAMP: + case FieldType::DATE: + case FieldType::TIME: + case FieldType::DATETIME: + case FieldType::YEAR: + default: + return not_implemented(); + } +} + +inline void deserialize_binary_row( + const std::vector& packet, + const std::vector& fields, + std::vector& output +) +{ + output.clear(); + output.reserve(fields.size()); + ResultsetRowNullBitmapTraits traits {fields.size()}; + ReadIterator null_bitmap_first = packet.data() + 1; // skip header + ReadIterator current = null_bitmap_first + traits.byte_count(); + ReadIterator last = packet.data() + packet.size(); + + for (std::size_t i = 0; i < fields.size(); ++i) + { + if (traits.is_null(null_bitmap_first, i)) + output.emplace_back(nullptr); + else + output.push_back(deserialize_field(fields[i].value.type, current, last)); + } + if (current != last) + { + throw std::out_of_range {"Leftover data after binary row"}; + } +} + +} // detail +} // mysql + + +template +mysql::PreparedStatement::PreparedStatement( + MysqlStream& stream, + int4 statement_id, + std::vector&& params, + std::vector&& columns +) : + stream_ {&stream}, + statement_id_ {statement_id}, + params_ {std::move(params)}, + columns_ {std::move(columns)} +{ +}; + +template +mysql::PreparedStatement mysql::PreparedStatement::prepare( + MysqlStream& stream, + std::string_view query +) +{ + // Write the prepare request + StmtPrepare packet {{query}}; + DynamicBuffer write_buffer; + serialize(write_buffer, packet); + stream.reset_sequence_number(); + stream.write(write_buffer.get()); + + // Get the prepare response + std::vector read_buffer; + stream.read(read_buffer); + int1 status = get_message_type(read_buffer); + if (status != 0) + throw std::runtime_error {"Error preparing statement" + std::string{query}}; + StmtPrepareResponseHeader response; + deserialize(read_buffer.data() + 1, read_buffer.data() + read_buffer.size(), response); + + // Read the parameters and columns if any + auto params = detail::read_fields(stream, response.num_params); + auto fields = detail::read_fields(stream, response.num_columns); + + return PreparedStatement { + stream, + response.statement_id, + std::move(params), + std::move(fields) + }; +} + +template +template +mysql::BinaryResultset mysql::PreparedStatement::execute_with_cursor( + int4 fetch_count, + Params&&... actual_params +) +{ + int1 flags = fetch_count == MAX_FETCH_COUNT ? CURSOR_TYPE_NO_CURSOR : CURSOR_TYPE_READ_ONLY; + StmtExecute message + { + statement_id_, + flags + }; + detail::fill_execute_msg(message, params_.size(), std::forward(actual_params)...); + return do_execute(message, fetch_count); +} + +template +void mysql::BinaryResultset::read_metadata() +{ + stream_->read(current_packet_); + if (get_message_type(current_packet_) == ok_packet_header) // Implicitly checks for errors + { + process_ok(); + } + else + { + // Header containing number of fields + StmtExecuteResponseHeader response_header; + deserialize(current_packet_, response_header); + + // Fields + fields_ = detail::read_fields(*stream_, response_header.num_fields); + + // First row + retrieve_next(); + } +} + +template +void mysql::BinaryResultset::process_ok() +{ + deserialize(current_packet_.data() + 1, + current_packet_.data() + current_packet_.size(), + ok_packet_); + if (cursor_exists() && + !(ok_packet_.status_flags & SERVER_STATUS_LAST_ROW_SENT)) + { + send_fetch(); + retrieve_next(); + } + else + { + state_ = State::exhausted; + } +} + +template +void mysql::BinaryResultset::send_fetch() +{ + mysql::StmtFetch msg { statement_id_, fetch_count_ }; + DynamicBuffer buffer; + serialize(buffer, msg); + stream_->reset_sequence_number(); + stream_->write(buffer.get()); +} + +template +bool mysql::BinaryResultset::retrieve_next() +{ + if (state_ == State::exhausted) + return false; + + stream_->read(current_packet_); + auto msg_type = get_message_type(current_packet_); + if (msg_type == eof_packet_header) + { + process_ok(); + } + else + { + detail::deserialize_binary_row(current_packet_, fields_, current_values_); + state_ = State::data_available; + } + return more_data(); +} + +template +const mysql::OkPacket& mysql::BinaryResultset::ok_packet() const +{ + assert(state_ == State::exhausted || + (state_ == State::data_available && cursor_exists())); + return ok_packet_; +} + +template +const std::vector& mysql::BinaryResultset::values() const +{ + assert(state_ == State::data_available); + return current_values_; +} + +template +mysql::BinaryResultset mysql::PreparedStatement::do_execute( + const StmtExecute& message, + int4 fetch_count +) +{ + std::vector read_buffer; + + DynamicBuffer write_buffer; + serialize(write_buffer, message); + stream_->reset_sequence_number(); + stream_->write(write_buffer.get()); + + return mysql::BinaryResultset {*stream_, statement_id_, fetch_count}; +} + +template +void mysql::PreparedStatement::close() +{ + assert(statement_id_ != 0); + StmtClose msg { statement_id_ }; + + DynamicBuffer write_buffer; + serialize(write_buffer, msg); + stream_->reset_sequence_number(); + stream_->write(write_buffer.get()); +} + +#endif diff --git a/include/mysql_stream.hpp b/include/mysql_stream.hpp index 76b8ffee..40579e98 100644 --- a/include/mysql_stream.hpp +++ b/include/mysql_stream.hpp @@ -1,7 +1,6 @@ #ifndef INCLUDE_MYSQL_STREAM_HPP_ #define INCLUDE_MYSQL_STREAM_HPP_ -#include #include #include #include "basic_types.hpp" @@ -20,11 +19,13 @@ struct HandshakeParams std::string_view database; }; -int1 get_message_type(const std::vector& buffer, bool check_err=true); +inline int1 get_message_type(const std::vector& buffer, bool check_err=true); +template class MysqlStream { - boost::asio::ip::tcp::socket next_layer_; // to be converted to an async stream + // TODO: static asserts + AsyncStream next_layer_; // to be converted to an async stream int1 sequence_number_ {0}; void process_sequence_number(int1 got); @@ -34,12 +35,17 @@ public: void handshake(const HandshakeParams&); void read(std::vector& buffer); + void write(const std::vector& buffer); + + template + void write(ConstBufferSequence&& buffers); + void reset_sequence_number() { sequence_number_ = 0; } }; } - +#include "impl/mysql_stream_impl.hpp" #endif /* INCLUDE_MYSQL_STREAM_HPP_ */ diff --git a/include/prepared_statement.hpp b/include/prepared_statement.hpp index eaa00da1..0982bfe5 100644 --- a/include/prepared_statement.hpp +++ b/include/prepared_statement.hpp @@ -17,11 +17,12 @@ struct ParamDefinition // TODO: copies should be disallowed }; +template class BinaryResultset { enum class State { initial, data_available, exhausted }; - MysqlStream* stream_; + MysqlStream* stream_; int4 statement_id_; int4 fetch_count_; std::vector fields_; @@ -35,7 +36,7 @@ class BinaryResultset void send_fetch(); bool cursor_exists() const { return ok_packet_.status_flags & SERVER_STATUS_CURSOR_EXISTS; } public: - BinaryResultset(MysqlStream& stream, int4 stmt_id, int4 fetch_count): + BinaryResultset(MysqlStream& stream, int4 stmt_id, int4 fetch_count): stream_ {&stream}, statement_id_ {stmt_id}, fetch_count_ {fetch_count}, ok_packet_ {}, state_ {State::initial} { read_metadata(); }; BinaryResultset(const BinaryResultset&) = delete; @@ -49,18 +50,19 @@ public: bool more_data() const { return state_ != State::exhausted; } }; +template class PreparedStatement { - MysqlStream* stream_; + MysqlStream* stream_; int4 statement_id_; std::vector params_; std::vector columns_; - BinaryResultset do_execute(const StmtExecute& message, int4 fetch_count); + BinaryResultset do_execute(const StmtExecute& message, int4 fetch_count); public: static constexpr int4 MAX_FETCH_COUNT = std::numeric_limits::max(); - PreparedStatement(MysqlStream& stream, int4 statement_id, + PreparedStatement(MysqlStream& stream, int4 statement_id, std::vector&& params, std::vector&& columns); PreparedStatement(const PreparedStatement&) = delete; PreparedStatement(PreparedStatement&&) = default; @@ -68,77 +70,25 @@ public: PreparedStatement& operator=(PreparedStatement&&) = default; ~PreparedStatement() = default; - MysqlStream& next_layer() const { return *stream_; } + auto& next_layer() const { return *stream_; } int4 statement_id() const { return statement_id_; } const std::vector& params() const { return params_; } const std::vector& columns() const { return columns_; } template - BinaryResultset execute(Params&&... params) { return execute_with_cursor(MAX_FETCH_COUNT, std::forward(params)...); } + BinaryResultset execute(Params&&... params) { return execute_with_cursor(MAX_FETCH_COUNT, std::forward(params)...); } template - BinaryResultset execute_with_cursor(int4 fetch_count, Params&&... params); + BinaryResultset execute_with_cursor(int4 fetch_count, Params&&... params); void close(); // Destructor should try to auto-close - static PreparedStatement prepare(MysqlStream& stream, std::string_view query); + static PreparedStatement prepare(MysqlStream& stream, std::string_view query); }; +} // namespace mysql -// Implementations -namespace detail -{ - -inline void fill_execute_msg_impl(std::vector::iterator) {} - -template -void fill_execute_msg_impl( - std::vector::iterator param_output, - Param0&& param, - Params&&... tail -) -{ - *param_output = std::forward(param); - fill_execute_msg_impl(std::next(param_output), std::forward(tail)...); -} - - - - -template -void fill_execute_msg(StmtExecute& output, std::size_t num_params, Args&&... args) -{ - if (sizeof...(args) != num_params) - { - throw std::out_of_range {"Wrong number of parameters passed to prepared statement"}; - } - output.num_params = static_cast(num_params); - output.new_params_bind_flag = 1; - output.param_values.resize(num_params); - fill_execute_msg_impl(output.param_values.begin(), std::forward(args)...); -} - -} - -} - - -template -mysql::BinaryResultset mysql::PreparedStatement::execute_with_cursor( - int4 fetch_count, - Params&&... actual_params -) -{ - int1 flags = fetch_count == MAX_FETCH_COUNT ? CURSOR_TYPE_NO_CURSOR : CURSOR_TYPE_READ_ONLY; - StmtExecute message - { - statement_id_, - flags - }; - detail::fill_execute_msg(message, params_.size(), std::forward(actual_params)...); - return do_execute(message, fetch_count); -} - +#include "impl/prepared_statement_impl.hpp" #endif /* INCLUDE_PREPARED_STATEMENT_HPP_ */ diff --git a/main.cpp b/main.cpp index b92ee327..5f336484 100644 --- a/main.cpp +++ b/main.cpp @@ -20,7 +20,8 @@ struct VariantPrinter void operator()(std::nullptr_t) const { (*this)("NULL"); } }; -void print(mysql::BinaryResultset& res) +template +void print(mysql::BinaryResultset& res) { for (bool ok = res.more_data(); ok; ok = res.retrieve_next()) { @@ -56,7 +57,7 @@ int main() cout << "Connecting to: " << endpoint << endl; // MYSQL stream - MysqlStream stream {ctx}; + MysqlStream stream {ctx}; // TCP connection stream.next_layer().connect(endpoint); @@ -70,11 +71,12 @@ int main() }); // Prepare a statement - auto stmt = mysql::PreparedStatement::prepare( + auto stmt = mysql::PreparedStatement::prepare( stream, "SELECT * from users WHERE age < ? and first_name <> ?"); auto res = stmt.execute_with_cursor(2, 200, string_lenenc{"hola"}); print(res); - auto make_older = mysql::PreparedStatement::prepare(stream, "UPDATE users SET age = age + 1"); + auto make_older = mysql::PreparedStatement::prepare( + stream, "UPDATE users SET age = age + 1"); res = make_older.execute(); print(res); make_older.close(); diff --git a/src/prepared_statement.cpp b/src/prepared_statement.cpp deleted file mode 100644 index 80ebfbe3..00000000 --- a/src/prepared_statement.cpp +++ /dev/null @@ -1,254 +0,0 @@ - -#include "prepared_statement.hpp" -#include "message_serialization.hpp" -#include "null_bitmap.hpp" -#include - -using namespace std; - -static std::vector read_fields( - mysql::MysqlStream& stream, - std::size_t quantity -) -{ - std::vector res (quantity); - for (auto& elm: res) - { - stream.read(elm.packet); - mysql::deserialize(elm.packet, elm.value); - } - return res; -} - -template -static mysql::BinaryValue deserialize_field_impl( - mysql::ReadIterator& first, - mysql::ReadIterator last -) -{ - T value; - first = mysql::deserialize(first, last, value); - return mysql::BinaryValue {value}; -} - -mysql::BinaryValue not_implemented() -{ - throw std::runtime_error{"Not implemented"}; -} - -static mysql::BinaryValue deserialize_field( - mysql::FieldType type, - mysql::ReadIterator& first, - mysql::ReadIterator last -) -{ - switch (type) - { - case mysql::FieldType::DECIMAL: - case mysql::FieldType::VARCHAR: - case mysql::FieldType::BIT: - case mysql::FieldType::NEWDECIMAL: - case mysql::FieldType::ENUM: - case mysql::FieldType::SET: - case mysql::FieldType::TINY_BLOB: - case mysql::FieldType::MEDIUM_BLOB: - case mysql::FieldType::LONG_BLOB: - case mysql::FieldType::BLOB: - case mysql::FieldType::VAR_STRING: - case mysql::FieldType::STRING: - case mysql::FieldType::GEOMETRY: - return deserialize_field_impl(first, last); - case mysql::FieldType::TINY: - return deserialize_field_impl(first, last); - case mysql::FieldType::SHORT: - return deserialize_field_impl(first, last); - case mysql::FieldType::INT24: - case mysql::FieldType::LONG: - return deserialize_field_impl(first, last); - case mysql::FieldType::LONGLONG: - return deserialize_field_impl(first, last); - case mysql::FieldType::FLOAT: - return not_implemented(); - case mysql::FieldType::DOUBLE: - return not_implemented(); - case mysql::FieldType::NULL_: - return deserialize_field_impl(first, last); - case mysql::FieldType::TIMESTAMP: - case mysql::FieldType::DATE: - case mysql::FieldType::TIME: - case mysql::FieldType::DATETIME: - case mysql::FieldType::YEAR: - default: - return not_implemented(); - } -} - -static void deserialize_binary_row( - const std::vector& packet, - const std::vector& fields, - std::vector& output -) -{ - output.clear(); - output.reserve(fields.size()); - mysql::ResultsetRowNullBitmapTraits traits {fields.size()}; - mysql::ReadIterator null_bitmap_first = packet.data() + 1; // skip header - mysql::ReadIterator current = null_bitmap_first + traits.byte_count(); - mysql::ReadIterator last = packet.data() + packet.size(); - - for (std::size_t i = 0; i < fields.size(); ++i) - { - if (traits.is_null(null_bitmap_first, i)) - output.emplace_back(nullptr); - else - output.push_back(deserialize_field(fields[i].value.type, current, last)); - } - if (current != last) - { - throw std::out_of_range {"Leftover data after binary row"}; - } -} - -mysql::PreparedStatement::PreparedStatement( - MysqlStream& stream, - int4 statement_id, - std::vector&& params, - std::vector&& columns -) : - stream_ {&stream}, - statement_id_ {statement_id}, - params_ {std::move(params)}, - columns_ {std::move(columns)} -{ -}; - -mysql::PreparedStatement mysql::PreparedStatement::prepare(MysqlStream& stream, std::string_view query) -{ - // Write the prepare request - StmtPrepare packet {{query}}; - DynamicBuffer write_buffer; - serialize(write_buffer, packet); - stream.reset_sequence_number(); - stream.write(write_buffer.get()); - - // Get the prepare response - std::vector read_buffer; - stream.read(read_buffer); - int1 status = get_message_type(read_buffer); - if (status != 0) - throw std::runtime_error {"Error preparing statement" + std::string{query}}; - StmtPrepareResponseHeader response; - deserialize(read_buffer.data() + 1, read_buffer.data() + read_buffer.size(), response); - - // Read the parameters and columns if any - auto params = read_fields(stream, response.num_params); - auto fields = read_fields(stream, response.num_columns); - - return PreparedStatement {stream, response.statement_id, move(params), move(fields)}; -} - -void mysql::BinaryResultset::read_metadata() -{ - stream_->read(current_packet_); - if (get_message_type(current_packet_) == ok_packet_header) // Implicitly checks for errors - { - process_ok(); - } - else - { - // Header containing number of fields - StmtExecuteResponseHeader response_header; - deserialize(current_packet_, response_header); - - // Fields - fields_ = read_fields(*stream_, response_header.num_fields); - - // First row - retrieve_next(); - } -} - -void mysql::BinaryResultset::process_ok() -{ - deserialize(current_packet_.data() + 1, - current_packet_.data() + current_packet_.size(), - ok_packet_); - if (cursor_exists() && - !(ok_packet_.status_flags & SERVER_STATUS_LAST_ROW_SENT)) - { - send_fetch(); - retrieve_next(); - } - else - { - state_ = State::exhausted; - } -} - -void mysql::BinaryResultset::send_fetch() -{ - mysql::StmtFetch msg { statement_id_, fetch_count_ }; - DynamicBuffer buffer; - serialize(buffer, msg); - stream_->reset_sequence_number(); - stream_->write(buffer.get()); -} - -bool mysql::BinaryResultset::retrieve_next() -{ - if (state_ == State::exhausted) - return false; - - stream_->read(current_packet_); - auto msg_type = get_message_type(current_packet_); - if (msg_type == eof_packet_header) - { - process_ok(); - } - else - { - deserialize_binary_row(current_packet_, fields_, current_values_); - state_ = State::data_available; - } - return more_data(); -} - -const mysql::OkPacket& mysql::BinaryResultset::ok_packet() const -{ - assert(state_ == State::exhausted || - (state_ == State::data_available && cursor_exists())); - return ok_packet_; -} - -const std::vector& mysql::BinaryResultset::values() const -{ - assert(state_ == State::data_available); - return current_values_; -} - -mysql::BinaryResultset mysql::PreparedStatement::do_execute( - const StmtExecute& message, - int4 fetch_count -) -{ - std::vector read_buffer; - - DynamicBuffer write_buffer; - serialize(write_buffer, message); - stream_->reset_sequence_number(); - stream_->write(write_buffer.get()); - - return mysql::BinaryResultset {*stream_, statement_id_, fetch_count}; -} - -void mysql::PreparedStatement::close() -{ - assert(statement_id_ != 0); - StmtClose msg { statement_id_ }; - - DynamicBuffer write_buffer; - serialize(write_buffer, msg); - stream_->reset_sequence_number(); - stream_->write(write_buffer.get()); -} -