diff --git a/include/messages.hpp b/include/messages.hpp index 037036b0..46c82324 100644 --- a/include/messages.hpp +++ b/include/messages.hpp @@ -231,8 +231,9 @@ struct StmtExecute int4 statement_id; int1 flags; // int4 iteration_count: always 1 + int1 num_params; int1 new_params_bind_flag; - std::vector param_values; + std::vector param_values; // empty if !new_params_bind_flag }; struct StmtExecuteResponseHeader diff --git a/include/null_bitmap.hpp b/include/null_bitmap.hpp index 0d9129c6..d30fef63 100644 --- a/include/null_bitmap.hpp +++ b/include/null_bitmap.hpp @@ -12,13 +12,17 @@ class NullBitmapTraits std::size_t num_fields_; public: constexpr NullBitmapTraits(std::size_t num_fields): num_fields_ {num_fields} {}; - std::size_t byte_count() const { return (num_fields_ + 7 + offset) / 8; } - std::size_t byte_pos(std::size_t field_pos) const { return (field_pos + offset) / 8; } - std::size_t bit_pos(std::size_t field_pos) const { return (field_pos + offset) % 8; } - bool is_null(ReadIterator null_bitmap_begin, std::size_t field_pos) const + constexpr std::size_t byte_count() const { return (num_fields_ + 7 + offset) / 8; } + constexpr std::size_t byte_pos(std::size_t field_pos) const { return (field_pos + offset) / 8; } + constexpr std::size_t bit_pos(std::size_t field_pos) const { return (field_pos + offset) % 8; } + bool is_null(const std::uint8_t* null_bitmap_begin, std::size_t field_pos) const { return null_bitmap_begin[byte_pos(field_pos)] & (1 << bit_pos(field_pos)); } + void set_null(std::uint8_t* null_bitmap_begin, std::size_t field_pos) const + { + null_bitmap_begin[byte_pos(field_pos)] |= (1 << bit_pos(field_pos)); + } }; using StmtExecuteNullBitmapTraits = NullBitmapTraits<0>; diff --git a/include/prepared_statement.hpp b/include/prepared_statement.hpp index 9fcabd4f..15d824bc 100644 --- a/include/prepared_statement.hpp +++ b/include/prepared_statement.hpp @@ -13,17 +13,35 @@ struct ParamDefinition { std::vector packet; ColumnDefinition value; + // TODO: copies should be disallowed }; -class BinaryResultsetRow +class BinaryResultset { - const std::vector& fields_; - std::vector packet_; - std::vector values_; + enum class State { initial, data_available, exhausted }; + + MysqlStream* stream_; + std::vector fields_; + std::vector current_packet_; + std::vector current_values_; + OkPacket ok_packet_; + State state_; + + void read_metadata(); + void process_ok(); public: - BinaryResultsetRow(const std::vector& fields, std::vector&& packet); + BinaryResultset(MysqlStream& stream): + stream_ {&stream}, ok_packet_ {}, + state_ {State::initial} { read_metadata(); }; + BinaryResultset(const BinaryResultset&) = delete; + BinaryResultset(BinaryResultset&&) = default; + BinaryResultset& operator=(const BinaryResultset&) = delete; + BinaryResultset& operator=(BinaryResultset&&) = default; + bool retrieve_next(); const std::vector& fields() const { return fields_; } - const std::vector& values() const { return values_; } + const std::vector& values() const; + const OkPacket& ok_packet() const; // Can only be called after exhausted + bool more_data() const { return state_ != State::exhausted; } }; class PreparedStatement @@ -33,7 +51,7 @@ class PreparedStatement std::vector params_; std::vector columns_; - std::vector do_execute(const StmtExecute& message); + BinaryResultset do_execute(const StmtExecute& message); public: PreparedStatement(MysqlStream& stream, int4 statement_id, std::vector&& params, std::vector&& columns); @@ -49,9 +67,7 @@ public: const std::vector& columns() const { return columns_; } template - std::vector execute(Params&&... params); - // execute(Something&&... params): StatementResponse - // execute() // uses already-bound params + BinaryResultset execute(Params&&... params); // close(Connection) // Destructor should try to auto-close @@ -86,6 +102,7 @@ void fill_execute_msg(StmtExecute& output, std::size_t num_params, Args&&... arg { throw std::out_of_range {"Wrong number of parameters passed to prepared statement"}; } + output.num_params = static_cast(num_params); output.new_params_bind_flag = 1; output.param_values.resize(num_params); fill_execute_msg_impl(output.param_values.begin(), std::forward(args)...); @@ -97,7 +114,7 @@ void fill_execute_msg(StmtExecute& output, std::size_t num_params, Args&&... arg template -std::vector mysql::PreparedStatement::execute(Params&&... actual_params) +mysql::BinaryResultset mysql::PreparedStatement::execute(Params&&... actual_params) { StmtExecute message { diff --git a/main.cpp b/main.cpp index 965b6637..1f617f43 100644 --- a/main.cpp +++ b/main.cpp @@ -20,6 +20,23 @@ struct VariantPrinter void operator()(std::nullptr_t) const { (*this)("NULL"); } }; +void print(mysql::BinaryResultset& res) +{ + for (bool ok = res.more_data(); ok; ok = res.retrieve_next()) + { + for (const auto& field: res.values()) + { + std::visit(VariantPrinter(), field); + } + std::cout << "\n"; + } + const auto& ok = res.ok_packet(); + std::cout << "affected_rows=" << ok.affected_rows.value + << ", last_insert_id=" << ok.last_insert_id.value + << ", warnings=" << ok.warnings + << ", info=" << ok.info.value << endl; +} + int main() { // Basic @@ -53,22 +70,14 @@ int main() }); // Prepare a statement - - /*mysql::PreparedStatement stmt { mysql::PreparedStatement::prepare( - stream, "SELECT first_name, age FROM users WHERE last_name = ?") }; - stmt.execute(string_lenenc {"user"});*/ - - - mysql::PreparedStatement stmt { mysql::PreparedStatement::prepare( - stream, "SELECT * from users WHERE age < ? and first_name <> ?") }; + auto stmt = mysql::PreparedStatement::prepare( + stream, "SELECT * from users WHERE age < ? and first_name <> ?"); auto res = stmt.execute(40, string_lenenc{"hola"}); - for (const auto& row: res) - { - for (const auto& field: row.values()) - { - std::visit(VariantPrinter(), field); - } - std::cout << "\n"; - } - + print(res); + auto make_older = mysql::PreparedStatement::prepare(stream, "UPDATE users SET age = age + 1"); + res = make_older.execute(); + print(res); + res = stmt.execute(40, string_lenenc{"hola"}); + cout << "\n\n"; + print(res); } diff --git a/src/message_serialization.cpp b/src/message_serialization.cpp index e973c35e..94249109 100644 --- a/src/message_serialization.cpp +++ b/src/message_serialization.cpp @@ -142,16 +142,19 @@ void mysql::serialize(DynamicBuffer& buffer, const StmtExecute& value) serialize(buffer, value.flags); serialize(buffer, int4(1)); // iteration_count - // NULL bitmap - if (!value.param_values.empty()) + if (value.num_params > 0) { - StmtExecuteNullBitmapTraits traits { value.param_values.size() }; + // NULL bitmap + StmtExecuteNullBitmapTraits traits { value.num_params }; std::vector null_bitmap (traits.byte_count(), 0); - for (std::size_t i = 0; i < value.param_values.size(); ++i) + if (value.new_params_bind_flag) { - if (std::holds_alternative(value.param_values[i])) + for (std::size_t i = 0; i < value.param_values.size(); ++i) { - null_bitmap[traits.byte_pos(i)] |= (1 << traits.bit_pos(i)); + if (std::holds_alternative(value.param_values[i])) + { + traits.set_null(null_bitmap.data(), i); + } } } buffer.add(null_bitmap.data(), null_bitmap.size()); diff --git a/src/prepared_statement.cpp b/src/prepared_statement.cpp index 84ad6c7e..6c9424f3 100644 --- a/src/prepared_statement.cpp +++ b/src/prepared_statement.cpp @@ -2,13 +2,22 @@ #include "prepared_statement.hpp" #include "message_serialization.hpp" #include "null_bitmap.hpp" +#include using namespace std; -static void read_param(mysql::MysqlStream& stream, mysql::ParamDefinition& output) +static std::vector read_fields( + mysql::MysqlStream& stream, + std::size_t quantity +) { - stream.read(output.packet); - mysql::deserialize(output.packet, output.value); + std::vector res (quantity); + for (auto& elm: res) + { + stream.read(elm.packet); + mysql::deserialize(elm.packet, elm.value); + } + return res; } template @@ -74,27 +83,27 @@ static mysql::BinaryValue deserialize_field( } } -mysql::BinaryResultsetRow::BinaryResultsetRow( - const std::vector& fields, - std::vector&& packet -) : - fields_ {fields}, - packet_ {std::move(packet)} +static void deserialize_binary_row( + const std::vector& packet, + const std::vector& fields, + std::vector& output +) { - values_.reserve(fields_.size()); - ResultsetRowNullBitmapTraits traits {fields_.size()}; - ReadIterator null_bitmap_first = packet_.data() + 1; // Skip header - ReadIterator first = null_bitmap_first + traits.byte_count(); - ReadIterator last = packet_.data() + packet_.size(); + output.clear(); + output.reserve(fields.size()); + mysql::ResultsetRowNullBitmapTraits traits {fields.size()}; + mysql::ReadIterator null_bitmap_first = packet.data() + 1; // skip header + mysql::ReadIterator current = null_bitmap_first + traits.byte_count(); + mysql::ReadIterator last = packet.data() + packet.size(); - for (std::size_t i = 0; i < fields_.size(); ++i) + for (std::size_t i = 0; i < fields.size(); ++i) { if (traits.is_null(null_bitmap_first, i)) - values_.emplace_back(nullptr); + output.emplace_back(nullptr); else - values_.push_back(deserialize_field(fields_[i].value.type, first, last)); + output.push_back(deserialize_field(fields[i].value.type, current, last)); } - if (first != last) + if (current != last) { throw std::out_of_range {"Leftover data after binary row"}; } @@ -132,46 +141,87 @@ mysql::PreparedStatement mysql::PreparedStatement::prepare(MysqlStream& stream, deserialize(read_buffer.data() + 1, read_buffer.data() + read_buffer.size(), response); // Read the parameters and columns if any - std::vector params (response.num_params); - for (int2 i = 0; i < response.num_params; ++i) - read_param(stream, params[i]); - std::vector columns (response.num_columns); - for (int2 i = 0; i < response.num_columns; ++i) - read_param(stream, columns[i]); + auto params = read_fields(stream, response.num_params); + auto fields = read_fields(stream, response.num_columns); - return PreparedStatement {stream, response.statement_id, move(params), move(columns)}; + return PreparedStatement {stream, response.statement_id, move(params), move(fields)}; } -std::vector mysql::PreparedStatement::do_execute(const StmtExecute& message) +void mysql::BinaryResultset::read_metadata() { + stream_->read(current_packet_); + if (get_message_type(current_packet_) == ok_packet_header) // Implicitly checks for errors + { + process_ok(); + } + else + { + // Header containing number of fields + StmtExecuteResponseHeader response_header; + deserialize(current_packet_, response_header); + + // Fields + fields_ = read_fields(*stream_, response_header.num_fields); + + // First row + retrieve_next(); + } +} + +void mysql::BinaryResultset::process_ok() +{ + deserialize(current_packet_.data() + 1, + current_packet_.data() + current_packet_.size(), + ok_packet_); + if (ok_packet_.status_flags & SERVER_STATUS_CURSOR_EXISTS) + { + // TODO: handle cursor semantics + } + state_ = State::exhausted; +} + +bool mysql::BinaryResultset::retrieve_next() +{ + if (state_ == State::exhausted) + return false; + + stream_->read(current_packet_); + auto msg_type = get_message_type(current_packet_); + if (msg_type == eof_packet_header) + { + process_ok(); + } + else + { + deserialize_binary_row(current_packet_, fields_, current_values_); + state_ = State::data_available; + } + return more_data(); +} + +const mysql::OkPacket& mysql::BinaryResultset::ok_packet() const +{ + assert(state_ == State::exhausted); + return ok_packet_; +} + +const std::vector& mysql::BinaryResultset::values() const +{ + assert(state_ == State::data_available); + return current_values_; +} + +mysql::BinaryResultset mysql::PreparedStatement::do_execute(const StmtExecute& message) +{ + std::vector read_buffer; + // TODO: other cursor types DynamicBuffer write_buffer; serialize(write_buffer, message); stream_->reset_sequence_number(); stream_->write(write_buffer.get()); - // Execute response header - std::vector read_buffer; - stream_->read(read_buffer); - StmtExecuteResponseHeader response_header; - deserialize(read_buffer, response_header); - - // Read the parameters. Ignore the packets - for (int1 i = 0; i < response_header.num_fields; ++i) - stream_->read(read_buffer); - - // Read the result - std::vector res; - - while (true) - { - read_buffer.clear(); - stream_->read(read_buffer); - auto msg_type = get_message_type(read_buffer); - if (msg_type == eof_packet_header) - break; - res.emplace_back(columns_, std::move(read_buffer)); - } - - return res; + return mysql::BinaryResultset {*stream_}; } + +