diff --git a/include/basic_serialization.hpp b/include/basic_serialization.hpp index 310ef90b..d637a96e 100644 --- a/include/basic_serialization.hpp +++ b/include/basic_serialization.hpp @@ -25,6 +25,10 @@ template <> struct get_size { static constexpr std::size_t value = 3; }; template <> struct get_size { static constexpr std::size_t value = 4; }; template <> struct get_size { static constexpr std::size_t value = 6; }; template <> struct get_size { static constexpr std::size_t value = 8; }; +template <> struct get_size { static constexpr std::size_t value = 1; }; +template <> struct get_size { static constexpr std::size_t value = 2; }; +template <> struct get_size { static constexpr std::size_t value = 4; }; +template <> struct get_size { static constexpr std::size_t value = 8; }; template struct get_size> { static constexpr std::size_t value = size; }; template constexpr std::size_t get_size_v = get_size::value; diff --git a/include/message_serialization.hpp b/include/message_serialization.hpp index 777ef59d..2b0ab6dc 100644 --- a/include/message_serialization.hpp +++ b/include/message_serialization.hpp @@ -19,14 +19,16 @@ ReadIterator deserialize(ReadIterator from, ReadIterator last, ErrPacket& output ReadIterator deserialize(ReadIterator from, ReadIterator last, Handshake& output); void serialize(DynamicBuffer& buffer, const HandshakeResponse& value); -// Command phase, general +// Resultsets ReadIterator deserialize(ReadIterator from, ReadIterator last, ColumnDefinition& output); +void serialize_binary_value(DynamicBuffer& buffer, const BinaryValue& value); // Prepared statements void serialize(DynamicBuffer& buffer, const StmtPrepare& value); ReadIterator deserialize(ReadIterator from, ReadIterator last, StmtPrepareResponseHeader& output); -void serialize(DynamicBuffer& buffer, const BinaryValue& value); void serialize(DynamicBuffer& buffer, const StmtExecute& value); +ReadIterator deserialize(ReadIterator from, ReadIterator last, StmtExecuteResponseHeader& output); +std::pair compute_field_type(const BinaryValue&); // Text serialization std::ostream& operator<<(std::ostream& os, const Handshake& value); diff --git a/include/messages.hpp b/include/messages.hpp index 6bde7628..037036b0 100644 --- a/include/messages.hpp +++ b/include/messages.hpp @@ -212,22 +212,19 @@ struct StmtPrepareResponseHeader }; using BinaryValue = std::variant< + std::int8_t, + std::int16_t, + std::int32_t, + std::int64_t, + std::uint8_t, + std::uint16_t, + std::uint32_t, + std::uint64_t, string_lenenc, - int8, - int4, - int2, - int1, + std::nullptr_t // NULL // TODO: double, float, dates/times - nullptr_t // aka NULL >; -struct StmtParamValue -{ - FieldType field_type; - bool is_signed; - BinaryValue value; -}; - struct StmtExecute { //int1 message_type: COM_STMT_EXECUTE @@ -235,7 +232,12 @@ struct StmtExecute int1 flags; // int4 iteration_count: always 1 int1 new_params_bind_flag; - std::vector param_values; + std::vector param_values; +}; + +struct StmtExecuteResponseHeader +{ + int1 num_fields; }; diff --git a/include/null_bitmap.hpp b/include/null_bitmap.hpp index 37cac8b4..c6f0332e 100644 --- a/include/null_bitmap.hpp +++ b/include/null_bitmap.hpp @@ -15,6 +15,10 @@ public: std::size_t byte_count() const { return (num_fields_ + 7 + offset) / 8; } std::size_t byte_pos(std::size_t field_pos) const { return (field_pos + offset) / 8; } std::size_t bit_pos(std::size_t field_pos) const { return (field_pos + offset) % 8; } + bool is_null(ReadIterator null_bitmap_begin, std::size_t field_pos) const + { + return null_bitmap_begin[byte_pos(field_pos)] & (1 >> bit_pos(field_pos)); + } }; using StmtExecuteNullBitmapTraits = NullBitmapTraits<0>; diff --git a/include/prepared_statement.hpp b/include/prepared_statement.hpp index 3df19d30..9fcabd4f 100644 --- a/include/prepared_statement.hpp +++ b/include/prepared_statement.hpp @@ -4,6 +4,7 @@ #include "messages.hpp" #include "mysql_stream.hpp" #include +#include namespace mysql { @@ -14,6 +15,17 @@ struct ParamDefinition ColumnDefinition value; }; +class BinaryResultsetRow +{ + const std::vector& fields_; + std::vector packet_; + std::vector values_; +public: + BinaryResultsetRow(const std::vector& fields, std::vector&& packet); + const std::vector& fields() const { return fields_; } + const std::vector& values() const { return values_; } +}; + class PreparedStatement { MysqlStream* stream_; @@ -21,24 +33,23 @@ class PreparedStatement std::vector params_; std::vector columns_; - void do_execute(const StmtExecute& message); + std::vector do_execute(const StmtExecute& message); public: PreparedStatement(MysqlStream& stream, int4 statement_id, - std::vector&& params, std::vector&& columns): - stream_ {&stream}, statement_id_ {statement_id}, params_{std::move(params)}, - columns_ {std::move(columns)} {}; + std::vector&& params, std::vector&& columns); PreparedStatement(const PreparedStatement&) = delete; PreparedStatement(PreparedStatement&&) = default; const PreparedStatement& operator=(const PreparedStatement&) = delete; PreparedStatement& operator=(PreparedStatement&&) = default; ~PreparedStatement() = default; + MysqlStream& next_layer() const { return *stream_; } int4 statement_id() const { return statement_id_; } const std::vector& params() const { return params_; } const std::vector& columns() const { return columns_; } template - void execute(Params&&... params); + std::vector execute(Params&&... params); // execute(Something&&... params): StatementResponse // execute() // uses already-bound params // close(Connection) @@ -52,54 +63,32 @@ public: namespace detail { -std::size_t field_type_to_variant_index(FieldType value); - -template -void set_param_value(const ParamDefinition& definition, StmtParamValue& output, Param&& param) -{ - FieldType type = definition.value.type; - output.field_type = type; - output.is_signed = true; // TODO: where can we take this from? - output.value = BinaryValue(std::forward(param)); - if (output.value.index() != field_type_to_variant_index(type)) - { - throw std::invalid_argument {"Wrong parameter type passed to prepared statement"}; - } -} +inline void fill_execute_msg_impl(std::vector::iterator) {} template void fill_execute_msg_impl( - std::vector::const_iterator param_def, - std::vector::iterator param_output, + std::vector::iterator param_output, Param0&& param, Params&&... tail ) { - set_param_value(*param_def, *param_output, std::forward(param)); - fill_execute_msg_impl(std::next(param_def), std::next(param_output), std::forward(tail)...); + *param_output = std::forward(param); + fill_execute_msg_impl(std::next(param_output), std::forward(tail)...); } -template -void fill_execute_msg_impl( - std::vector::const_iterator param_def, - std::vector::iterator param_output, - Param&& param -) -{ - set_param_value(*param_def, *param_output, std::forward(param)); -} + template -void fill_execute_msg(const std::vector& param_defs, StmtExecute& output, Args&&... args) +void fill_execute_msg(StmtExecute& output, std::size_t num_params, Args&&... args) { - if (sizeof...(args) != param_defs.size()) + if (sizeof...(args) != num_params) { throw std::out_of_range {"Wrong number of parameters passed to prepared statement"}; } output.new_params_bind_flag = 1; - output.param_values.resize(param_defs.size()); - fill_execute_msg_impl(param_defs.begin(), output.param_values.begin(), std::forward(args)...); + output.param_values.resize(num_params); + fill_execute_msg_impl(output.param_values.begin(), std::forward(args)...); } } @@ -108,15 +97,15 @@ void fill_execute_msg(const std::vector& param_defs, StmtExecut template -void mysql::PreparedStatement::execute(Params&&... actual_params) +std::vector mysql::PreparedStatement::execute(Params&&... actual_params) { StmtExecute message { statement_id_, - 0 // TODO: what is this parameter? cursor type?? + 0 // Cursor type: no cursor. TODO: allow execution with different cursor types }; - detail::fill_execute_msg(params_, message, std::forward(actual_params)...); - do_execute(message); + detail::fill_execute_msg(message, params_.size(), std::forward(actual_params)...); + return do_execute(message); } diff --git a/main.cpp b/main.cpp index b52c9e9d..f135e8cf 100644 --- a/main.cpp +++ b/main.cpp @@ -11,6 +11,15 @@ using namespace mysql; constexpr auto HOSTNAME = "localhost"sv; constexpr auto PORT = "3306"sv; +struct VariantPrinter +{ + template + void operator()(T v) const { cout << v << ", "; } + + void operator()(string_lenenc v) const { (*this)(v.value); } + void operator()(std::nullptr_t) const { (*this)("NULL"); } +}; + int main() { // Basic @@ -45,9 +54,21 @@ int main() // Prepare a statement - mysql::PreparedStatement stmt { mysql::PreparedStatement::prepare( + /*mysql::PreparedStatement stmt { mysql::PreparedStatement::prepare( stream, "SELECT first_name, age FROM users WHERE last_name = ?") }; - stmt.execute(string_lenenc {"user"}); + stmt.execute(string_lenenc {"user"});*/ + mysql::PreparedStatement stmt { mysql::PreparedStatement::prepare( + stream, "SELECT * from users WHERE age < ? and first_name <> ?") }; + auto res = stmt.execute(22, string_lenenc{"hola"}); + for (const auto& row: res) + { + for (const auto& field: row.values()) + { + std::visit(VariantPrinter(), field); + } + std::cout << "\n"; + } + } diff --git a/src/message_serialization.cpp b/src/message_serialization.cpp index c6bd30a2..e973c35e 100644 --- a/src/message_serialization.cpp +++ b/src/message_serialization.cpp @@ -89,7 +89,7 @@ void mysql::serialize(DynamicBuffer& buffer, const HandshakeResponse& value) serialize(buffer, value.client_plugin_name); } -// Command phase, general +// Resultsets mysql::ReadIterator mysql::deserialize(ReadIterator from, ReadIterator last, ColumnDefinition& output) { int_lenenc length_fixed_length_fields; @@ -108,6 +108,13 @@ mysql::ReadIterator mysql::deserialize(ReadIterator from, ReadIterator last, Col return from; } +void mysql::serialize_binary_value(DynamicBuffer& buffer, const BinaryValue& value) +{ + std::visit([&buffer](auto v) { + mysql::serialize(buffer, v); + }, value); +} + // Prepared statements void mysql::serialize(DynamicBuffer& buffer, const StmtPrepare& value) { @@ -128,13 +135,6 @@ mysql::ReadIterator mysql::deserialize(ReadIterator from, ReadIterator last, Stm return from; } -void mysql::serialize(DynamicBuffer& buffer, const BinaryValue& value) -{ - std::visit([&buffer](auto v) { - mysql::serialize(buffer, v); - }, value); -} - void mysql::serialize(DynamicBuffer& buffer, const StmtExecute& value) { serialize(buffer, Command::COM_STMT_EXECUTE); @@ -149,7 +149,7 @@ void mysql::serialize(DynamicBuffer& buffer, const StmtExecute& value) std::vector null_bitmap (traits.byte_count(), 0); for (std::size_t i = 0; i < value.param_values.size(); ++i) { - if (value.param_values[i].field_type == FieldType::NULL_) + if (std::holds_alternative(value.param_values[i])) { null_bitmap[traits.byte_pos(i)] |= (1 << traits.bit_pos(i)); } @@ -162,18 +162,51 @@ void mysql::serialize(DynamicBuffer& buffer, const StmtExecute& value) { for (const auto& param: value.param_values) { - serialize(buffer, param.field_type); - serialize(buffer, int1(param.is_signed ? 0 : 0x80)); + auto type = compute_field_type(param); + serialize(buffer, type.first); + serialize(buffer, int1(type.second ? 0 : 0x80)); } for (const auto& param: value.param_values) { - serialize(buffer, param.value); + serialize_binary_value(buffer, param); } } } } +mysql::ReadIterator mysql::deserialize(ReadIterator from, ReadIterator last, StmtExecuteResponseHeader& output) +{ + // TODO: int1 status: must be 0 to be deserialized as part of this? + return deserialize(from, last, output.num_fields); +} + + +// TODO: refactor this +#define MYSQL_COMPUTE_FIELD_TYPE_IMPL_ENTRY(type, typenum, issigned) \ + template <> \ + constexpr std::pair \ + compute_field_type_impl() { return { mysql::FieldType::typenum, issigned }; }; + +template constexpr std::pair compute_field_type_impl(); +MYSQL_COMPUTE_FIELD_TYPE_IMPL_ENTRY(std::int8_t, TINY, true); +MYSQL_COMPUTE_FIELD_TYPE_IMPL_ENTRY(std::uint8_t, TINY, false); +MYSQL_COMPUTE_FIELD_TYPE_IMPL_ENTRY(std::int16_t, SHORT, true); +MYSQL_COMPUTE_FIELD_TYPE_IMPL_ENTRY(std::uint16_t, SHORT, false); +MYSQL_COMPUTE_FIELD_TYPE_IMPL_ENTRY(std::int32_t, LONG, true); +MYSQL_COMPUTE_FIELD_TYPE_IMPL_ENTRY(std::uint32_t, LONG, false); +MYSQL_COMPUTE_FIELD_TYPE_IMPL_ENTRY(std::int64_t, LONGLONG, true); +MYSQL_COMPUTE_FIELD_TYPE_IMPL_ENTRY(std::uint64_t, LONGLONG, false); +MYSQL_COMPUTE_FIELD_TYPE_IMPL_ENTRY(mysql::string_lenenc, STRING, true); +MYSQL_COMPUTE_FIELD_TYPE_IMPL_ENTRY(std::nullptr_t, NULL_, true); + +std::pair mysql::compute_field_type(const BinaryValue& v) +{ + return std::visit([](auto elm) { + return compute_field_type_impl(); + }, v); +} + // Text serialization std::ostream& mysql::operator<<(std::ostream& os, const Handshake& value) { diff --git a/src/prepared_statement.cpp b/src/prepared_statement.cpp index 055937c0..724a147c 100644 --- a/src/prepared_statement.cpp +++ b/src/prepared_statement.cpp @@ -1,6 +1,7 @@ #include "prepared_statement.hpp" #include "message_serialization.hpp" +#include "null_bitmap.hpp" using namespace std; @@ -10,6 +11,108 @@ static void read_param(mysql::MysqlStream& stream, mysql::ParamDefinition& outpu mysql::deserialize(output.packet, output.value); } +template +static mysql::BinaryValue deserialize_field_impl( + mysql::ReadIterator& first, + mysql::ReadIterator last +) +{ + T value; + first = mysql::deserialize(first, last, value); + return mysql::BinaryValue {value}; +} + +mysql::BinaryValue not_implemented() +{ + throw std::runtime_error{"Not implemented"}; +} + +static mysql::BinaryValue deserialize_field( + mysql::FieldType type, + mysql::ReadIterator& first, + mysql::ReadIterator last +) +{ + switch (type) + { + case mysql::FieldType::DECIMAL: + case mysql::FieldType::VARCHAR: + case mysql::FieldType::BIT: + case mysql::FieldType::NEWDECIMAL: + case mysql::FieldType::ENUM: + case mysql::FieldType::SET: + case mysql::FieldType::TINY_BLOB: + case mysql::FieldType::MEDIUM_BLOB: + case mysql::FieldType::LONG_BLOB: + case mysql::FieldType::BLOB: + case mysql::FieldType::VAR_STRING: + case mysql::FieldType::STRING: + case mysql::FieldType::GEOMETRY: + return deserialize_field_impl(first, last); + case mysql::FieldType::TINY: + return deserialize_field_impl(first, last); + case mysql::FieldType::SHORT: + return deserialize_field_impl(first, last); + case mysql::FieldType::INT24: + case mysql::FieldType::LONG: + return deserialize_field_impl(first, last); + case mysql::FieldType::LONGLONG: + return deserialize_field_impl(first, last); + case mysql::FieldType::FLOAT: + return not_implemented(); + case mysql::FieldType::DOUBLE: + return not_implemented(); + case mysql::FieldType::NULL_: + return deserialize_field_impl(first, last); + case mysql::FieldType::TIMESTAMP: + case mysql::FieldType::DATE: + case mysql::FieldType::TIME: + case mysql::FieldType::DATETIME: + case mysql::FieldType::YEAR: + default: + return not_implemented(); + } +} + +mysql::BinaryResultsetRow::BinaryResultsetRow( + const std::vector& fields, + std::vector&& packet +) : + fields_ {fields}, + packet_ {std::move(packet)} +{ + values_.reserve(fields_.size()); + StmtExecuteNullBitmapTraits traits {fields_.size()}; + ReadIterator null_bitmap_first = packet_.data() + 1; // Skip header + ReadIterator first = null_bitmap_first + traits.byte_count(); + ReadIterator last = packet_.data() + packet_.size(); + + for (std::size_t i = 0; i < fields_.size(); ++i) + { + if (traits.is_null(null_bitmap_first, i)) + values_.emplace_back(nullptr); + else + values_.push_back(deserialize_field(fields_[i].value.type, first, last)); + } + if (first != last) + { + throw std::out_of_range {"Leftover data after binary row"}; + } +} + +mysql::PreparedStatement::PreparedStatement( + MysqlStream& stream, + int4 statement_id, + std::vector&& params, + std::vector&& columns +) : + stream_ {&stream}, + statement_id_ {statement_id}, + params_ {std::move(params)}, + columns_ {std::move(columns)} +{ +}; + mysql::PreparedStatement mysql::PreparedStatement::prepare(MysqlStream& stream, std::string_view query) { // Write the prepare request @@ -39,47 +142,36 @@ mysql::PreparedStatement mysql::PreparedStatement::prepare(MysqlStream& stream, return PreparedStatement {stream, response.statement_id, move(params), move(columns)}; } -void mysql::PreparedStatement::do_execute(const StmtExecute& message) +std::vector mysql::PreparedStatement::do_execute(const StmtExecute& message) { + // TODO: other cursor types DynamicBuffer write_buffer; serialize(write_buffer, message); stream_->reset_sequence_number(); stream_->write(write_buffer.get()); + + // Execute response header std::vector read_buffer; stream_->read(read_buffer); - // TODO: do sth with response -} + StmtExecuteResponseHeader response_header; + deserialize(read_buffer, response_header); -std::size_t mysql::detail::field_type_to_variant_index(FieldType value) -{ - switch (value) + // Read the parameters. Ignore the packets + for (int1 i = 0; i < response_header.num_fields; ++i) + stream_->read(read_buffer); + + // Read the result + std::vector res; + + while (true) { - case FieldType::STRING: - case FieldType::VARCHAR: - case FieldType::VAR_STRING: - case FieldType::ENUM: - case FieldType::SET: - case FieldType::LONG_BLOB: - case FieldType::MEDIUM_BLOB: - case FieldType::BLOB: - case FieldType::TINY_BLOB: - case FieldType::GEOMETRY: - case FieldType::BIT: - case FieldType::DECIMAL: - case FieldType::NEWDECIMAL: - return 0; // TODO: this is not very good - case FieldType::LONGLONG: - return 1; - case FieldType::LONG: - case FieldType::INT24: - return 2; - case FieldType::SHORT: - case FieldType::YEAR: - return 3; - case FieldType::TINY: - return 4; - case FieldType::NULL_: - return 5; - default: throw std::logic_error {"Not implemented"}; + read_buffer.clear(); + stream_->read(read_buffer); + auto msg_type = get_message_type(read_buffer); + if (msg_type == eof_packet_header) + break; + res.emplace_back(columns_, std::move(read_buffer)); } + + return res; }