diff --git a/include/basic_serialization.hpp b/include/basic_serialization.hpp index d0168ec2..310ef90b 100644 --- a/include/basic_serialization.hpp +++ b/include/basic_serialization.hpp @@ -88,6 +88,8 @@ deserialize(ReadIterator from, ReadIterator last, T& to) return res; } +inline ReadIterator deserialize(ReadIterator from, ReadIterator last, nullptr_t&) { return from; } + template ReadIterator deserialize(const std::vector& from, T& to) { @@ -168,7 +170,7 @@ serialize(DynamicBuffer& buffer, T value) serialize(buffer, static_cast>(value)); } - +inline void serialize(DynamicBuffer&, nullptr_t) {}; } diff --git a/include/message_serialization.hpp b/include/message_serialization.hpp index a0dfbcec..777ef59d 100644 --- a/include/message_serialization.hpp +++ b/include/message_serialization.hpp @@ -25,6 +25,8 @@ ReadIterator deserialize(ReadIterator from, ReadIterator last, ColumnDefinition& // Prepared statements void serialize(DynamicBuffer& buffer, const StmtPrepare& value); ReadIterator deserialize(ReadIterator from, ReadIterator last, StmtPrepareResponseHeader& output); +void serialize(DynamicBuffer& buffer, const BinaryValue& value); +void serialize(DynamicBuffer& buffer, const StmtExecute& value); // Text serialization std::ostream& operator<<(std::ostream& os, const Handshake& value); diff --git a/include/messages.hpp b/include/messages.hpp index 17b909e8..6bde7628 100644 --- a/include/messages.hpp +++ b/include/messages.hpp @@ -3,6 +3,7 @@ #include #include +#include #include "basic_types.hpp" namespace mysql @@ -210,6 +211,32 @@ struct StmtPrepareResponseHeader // TODO: int1 metadata_follows when CLIENT_OPTIONAL_RESULTSET_METADATA }; +using BinaryValue = std::variant< + string_lenenc, + int8, + int4, + int2, + int1, + // TODO: double, float, dates/times + nullptr_t // aka NULL +>; + +struct StmtParamValue +{ + FieldType field_type; + bool is_signed; + BinaryValue value; +}; + +struct StmtExecute +{ + //int1 message_type: COM_STMT_EXECUTE + int4 statement_id; + int1 flags; + // int4 iteration_count: always 1 + int1 new_params_bind_flag; + std::vector param_values; +}; } diff --git a/include/null_bitmap.hpp b/include/null_bitmap.hpp new file mode 100644 index 00000000..37cac8b4 --- /dev/null +++ b/include/null_bitmap.hpp @@ -0,0 +1,28 @@ +#ifndef INCLUDE_NULL_BITMAP_HPP_ +#define INCLUDE_NULL_BITMAP_HPP_ + +#include + +namespace mysql +{ + +template +class NullBitmapTraits +{ + std::size_t num_fields_; +public: + constexpr NullBitmapTraits(std::size_t num_fields): num_fields_ {num_fields} {}; + std::size_t byte_count() const { return (num_fields_ + 7 + offset) / 8; } + std::size_t byte_pos(std::size_t field_pos) const { return (field_pos + offset) / 8; } + std::size_t bit_pos(std::size_t field_pos) const { return (field_pos + offset) % 8; } +}; + +using StmtExecuteNullBitmapTraits = NullBitmapTraits<0>; +using ResultsetRowNullBitmapTraits = NullBitmapTraits<2>; + + +} + + + +#endif /* INCLUDE_NULL_BITMAP_HPP_ */ diff --git a/include/prepared_statement.hpp b/include/prepared_statement.hpp index a44f8b1e..3df19d30 100644 --- a/include/prepared_statement.hpp +++ b/include/prepared_statement.hpp @@ -3,6 +3,7 @@ #include "messages.hpp" #include "mysql_stream.hpp" +#include namespace mysql { @@ -19,6 +20,8 @@ class PreparedStatement int4 statement_id_; std::vector params_; std::vector columns_; + + void do_execute(const StmtExecute& message); public: PreparedStatement(MysqlStream& stream, int4 statement_id, std::vector&& params, std::vector&& columns): @@ -33,6 +36,9 @@ public: int4 statement_id() const { return statement_id_; } const std::vector& params() const { return params_; } const std::vector& columns() const { return columns_; } + + template + void execute(Params&&... params); // execute(Something&&... params): StatementResponse // execute() // uses already-bound params // close(Connection) @@ -41,8 +47,77 @@ public: static PreparedStatement prepare(MysqlStream& stream, std::string_view query); }; + +// Implementations +namespace detail +{ + +std::size_t field_type_to_variant_index(FieldType value); + +template +void set_param_value(const ParamDefinition& definition, StmtParamValue& output, Param&& param) +{ + FieldType type = definition.value.type; + output.field_type = type; + output.is_signed = true; // TODO: where can we take this from? + output.value = BinaryValue(std::forward(param)); + if (output.value.index() != field_type_to_variant_index(type)) + { + throw std::invalid_argument {"Wrong parameter type passed to prepared statement"}; + } +} + +template +void fill_execute_msg_impl( + std::vector::const_iterator param_def, + std::vector::iterator param_output, + Param0&& param, + Params&&... tail +) +{ + set_param_value(*param_def, *param_output, std::forward(param)); + fill_execute_msg_impl(std::next(param_def), std::next(param_output), std::forward(tail)...); +} + +template +void fill_execute_msg_impl( + std::vector::const_iterator param_def, + std::vector::iterator param_output, + Param&& param +) +{ + set_param_value(*param_def, *param_output, std::forward(param)); } +template +void fill_execute_msg(const std::vector& param_defs, StmtExecute& output, Args&&... args) +{ + if (sizeof...(args) != param_defs.size()) + { + throw std::out_of_range {"Wrong number of parameters passed to prepared statement"}; + } + output.new_params_bind_flag = 1; + output.param_values.resize(param_defs.size()); + fill_execute_msg_impl(param_defs.begin(), output.param_values.begin(), std::forward(args)...); +} + +} + +} + + +template +void mysql::PreparedStatement::execute(Params&&... actual_params) +{ + StmtExecute message + { + statement_id_, + 0 // TODO: what is this parameter? cursor type?? + }; + detail::fill_execute_msg(params_, message, std::forward(actual_params)...); + do_execute(message); +} + #endif /* INCLUDE_PREPARED_STATEMENT_HPP_ */ diff --git a/main.cpp b/main.cpp index 8a8badd0..b52c9e9d 100644 --- a/main.cpp +++ b/main.cpp @@ -40,13 +40,14 @@ int main() CharacterSetLowerByte::utf8_general_ci, "root", "root", - "mysql" + "awesome" }); // Prepare a statement mysql::PreparedStatement stmt { mysql::PreparedStatement::prepare( - stream, "SELECT host FROM user WHERE user = ?") }; + stream, "SELECT first_name, age FROM users WHERE last_name = ?") }; + stmt.execute(string_lenenc {"user"}); } diff --git a/src/message_serialization.cpp b/src/message_serialization.cpp index edf36b91..c6bd30a2 100644 --- a/src/message_serialization.cpp +++ b/src/message_serialization.cpp @@ -6,8 +6,10 @@ */ #include "message_serialization.hpp" +#include "null_bitmap.hpp" #include #include +#include using namespace std; @@ -126,6 +128,52 @@ mysql::ReadIterator mysql::deserialize(ReadIterator from, ReadIterator last, Stm return from; } +void mysql::serialize(DynamicBuffer& buffer, const BinaryValue& value) +{ + std::visit([&buffer](auto v) { + mysql::serialize(buffer, v); + }, value); +} + +void mysql::serialize(DynamicBuffer& buffer, const StmtExecute& value) +{ + serialize(buffer, Command::COM_STMT_EXECUTE); + serialize(buffer, value.statement_id); + serialize(buffer, value.flags); + serialize(buffer, int4(1)); // iteration_count + + // NULL bitmap + if (!value.param_values.empty()) + { + StmtExecuteNullBitmapTraits traits { value.param_values.size() }; + std::vector null_bitmap (traits.byte_count(), 0); + for (std::size_t i = 0; i < value.param_values.size(); ++i) + { + if (value.param_values[i].field_type == FieldType::NULL_) + { + null_bitmap[traits.byte_pos(i)] |= (1 << traits.bit_pos(i)); + } + } + buffer.add(null_bitmap.data(), null_bitmap.size()); + + serialize(buffer, value.new_params_bind_flag); + + if (value.new_params_bind_flag) + { + for (const auto& param: value.param_values) + { + serialize(buffer, param.field_type); + serialize(buffer, int1(param.is_signed ? 0 : 0x80)); + } + + for (const auto& param: value.param_values) + { + serialize(buffer, param.value); + } + } + } +} + // Text serialization std::ostream& mysql::operator<<(std::ostream& os, const Handshake& value) { diff --git a/src/prepared_statement.cpp b/src/prepared_statement.cpp index 3e0eb24e..055937c0 100644 --- a/src/prepared_statement.cpp +++ b/src/prepared_statement.cpp @@ -38,3 +38,48 @@ mysql::PreparedStatement mysql::PreparedStatement::prepare(MysqlStream& stream, return PreparedStatement {stream, response.statement_id, move(params), move(columns)}; } + +void mysql::PreparedStatement::do_execute(const StmtExecute& message) +{ + DynamicBuffer write_buffer; + serialize(write_buffer, message); + stream_->reset_sequence_number(); + stream_->write(write_buffer.get()); + std::vector read_buffer; + stream_->read(read_buffer); + // TODO: do sth with response +} + +std::size_t mysql::detail::field_type_to_variant_index(FieldType value) +{ + switch (value) + { + case FieldType::STRING: + case FieldType::VARCHAR: + case FieldType::VAR_STRING: + case FieldType::ENUM: + case FieldType::SET: + case FieldType::LONG_BLOB: + case FieldType::MEDIUM_BLOB: + case FieldType::BLOB: + case FieldType::TINY_BLOB: + case FieldType::GEOMETRY: + case FieldType::BIT: + case FieldType::DECIMAL: + case FieldType::NEWDECIMAL: + return 0; // TODO: this is not very good + case FieldType::LONGLONG: + return 1; + case FieldType::LONG: + case FieldType::INT24: + return 2; + case FieldType::SHORT: + case FieldType::YEAR: + return 3; + case FieldType::TINY: + return 4; + case FieldType::NULL_: + return 5; + default: throw std::logic_error {"Not implemented"}; + } +}