diff --git a/TODO.txt b/TODO.txt index 8b546e38..ed0f91bb 100644 --- a/TODO.txt +++ b/TODO.txt @@ -1,4 +1,7 @@ Prepared statements + Add tests for (de)serialization of floats + Rename (De)SerializationContext + Zero dates? Multiresultset Text protocol Binary protocol (stored procedures) diff --git a/include/mysql/impl/binary_serialization.hpp b/include/mysql/impl/binary_serialization.hpp new file mode 100644 index 00000000..c83783b5 --- /dev/null +++ b/include/mysql/impl/binary_serialization.hpp @@ -0,0 +1,34 @@ +#ifndef INCLUDE_MYSQL_IMPL_BINARY_SERIALIZATION_HPP_ +#define INCLUDE_MYSQL_IMPL_BINARY_SERIALIZATION_HPP_ + +#include "mysql/impl/serialization.hpp" +#include "mysql/value.hpp" + +// (de)serialization overloads for date/time types and mysql::value + +namespace mysql +{ +namespace detail +{ + +inline std::size_t get_size(const date& input, const SerializationContext& ctx) noexcept; +inline void serialize(const date& input, SerializationContext& ctx) noexcept; +inline Error deserialize(date& output, DeserializationContext& ctx) noexcept; + +inline std::size_t get_size(const datetime& input, const SerializationContext& ctx) noexcept; +inline void serialize(const datetime& input, SerializationContext& ctx) noexcept; +inline Error deserialize(datetime& output, DeserializationContext& ctx) noexcept; + +inline std::size_t get_size(const time& input, const SerializationContext& ctx) noexcept; +inline void serialize(const time& input, SerializationContext& ctx) noexcept; +inline Error deserialize(time& output, DeserializationContext& ctx) noexcept; + +inline std::size_t get_size(const value& input, const SerializationContext& ctx) noexcept; +inline void serialize(const value& input, SerializationContext& ctx) noexcept; + +} +} + +#include "mysql/impl/binary_serialization.ipp" + +#endif /* INCLUDE_MYSQL_IMPL_BINARY_SERIALIZATION_HPP_ */ diff --git a/include/mysql/impl/binary_serialization.ipp b/include/mysql/impl/binary_serialization.ipp new file mode 100644 index 00000000..9183c6e0 --- /dev/null +++ b/include/mysql/impl/binary_serialization.ipp @@ -0,0 +1,343 @@ +#ifndef INCLUDE_MYSQL_IMPL_BINARY_SERIALIZATION_IPP_ +#define INCLUDE_MYSQL_IMPL_BINARY_SERIALIZATION_IPP_ + +#include + +namespace mysql +{ +namespace detail +{ + +// Performs a mapping from T to a type that can be serialized +template +struct get_serializable_type { using type = T; }; + +template +using get_serializable_type_t = typename get_serializable_type::type; + +template <> struct get_serializable_type { using type = int4; }; +template <> struct get_serializable_type { using type = int4_signed; }; +template <> struct get_serializable_type { using type = int8; }; +template <> struct get_serializable_type { using type = int8_signed; }; +template <> struct get_serializable_type { using type = value_holder; }; +template <> struct get_serializable_type { using type = value_holder; }; +template <> struct get_serializable_type { using type = string_lenenc; }; +template <> struct get_serializable_type { using type = int2; }; +template <> struct get_serializable_type { using type = dummy_serializable; }; + +template +inline get_serializable_type_t to_serializable_type(T input) noexcept +{ + return get_serializable_type_t(input); +} + +template <> +inline get_serializable_type_t to_serializable_type(year input) noexcept +{ + return get_serializable_type_t(static_cast(input)); +} + +inline Error deserialize_binary_date(date& output, std::uint8_t length, DeserializationContext& ctx) noexcept +{ + int2 year; + int1 month; + int1 day; + + if (length >= 4) // if length is zero, year, month and day are zero + { + auto err = deserialize_fields(ctx, year, month, day); + if (err != Error::ok) return err; + } + + // TODO: how does this handle zero dates? + ::date::year_month_day ymd (::date::year(year.value), ::date::month(month.value), ::date::day(day.value)); + output = date(ymd); + return Error::ok; +} + +// Does not add the length prefix byte +inline void serialize_binary_ymd( + const ::date::year_month_day& ymd, + SerializationContext& ctx +) noexcept +{ + serialize_fields( + ctx, + int2(static_cast(ymd.year())), + int1(static_cast(ymd.month())), + int1(static_cast(ymd.day())) + ); +} + +struct broken_datetime +{ + ::date::year_month_day ymd; + ::date::time_of_day tod; + + broken_datetime(const datetime& input) noexcept : + ymd(::date::floor<::date::days>(input)), + tod(input - ::date::floor<::date::days>(input)) + { + } + + // Doesn't count the first length byte + std::uint8_t binary_serialized_length() const noexcept + { + std::uint8_t res = 11; // base length + if (tod.subseconds() == 0) + { + res -= 4; + if (tod.seconds() == 0 && + tod.minutes() == 0 && + tod.hours() == 0) + { + res -= 4; + } + } + return res; + } +}; + +struct broken_time +{ + ::date::days days; + std::chrono::hours hours; + std::chrono::minutes minutes; + std::chrono::seconds seconds; + std::chrono::microseconds microseconds; + + broken_time(const time& input) noexcept : + days(std::chrono::duration_cast<::date::days>(input)), + hours(std::chrono::duration_cast(input % ::date::days(1))), + minutes(std::chrono::duration_cast(input % std::chrono::hours(1))), + seconds(std::chrono::duration_cast(input % std::chrono::minutes(1))), + microseconds(input % std::chrono::seconds(1)) + { + } + + // Doesn't count the first length byte + std::uint8_t binary_serialized_length() const noexcept + { + std::uint8_t res = 12; + if (microseconds == 0) + { + res -= 4; + if (seconds == 0 && minutes == 0 && hours == 0 && days == 0) + { + res -= 8; + } + } + return res; + } +}; + +} +} + +// date +inline std::size_t mysql::detail::get_size( + const date&, + const SerializationContext& +) noexcept +{ + // TODO: consider zero dates? + return 5; // length, year, month, day +} + +inline void mysql::detail::serialize( + const date& input, + SerializationContext& ctx +) noexcept +{ + // TODO: consider zero dates? + serialize(int1(4), ctx); // + serialize_binary_ymd(::date::year_month_day (input), ctx); +} + +inline mysql::Error mysql::detail::deserialize( + date& output, + DeserializationContext& ctx +) noexcept +{ + int1 length; + auto err = deserialize(length, ctx); + if (err != Error::ok) return err; + return deserialize_binary_date(output, length.value, ctx); +} + +// datetime +inline std::size_t mysql::detail::get_size( + const datetime& input, + const SerializationContext& +) noexcept +{ + return broken_datetime(input).binary_serialized_length() + 1; // extra length prefix byte +} + +inline void mysql::detail::serialize( + const datetime& input, + SerializationContext& ctx +) noexcept +{ + broken_datetime brokendt (input); + auto length = brokendt.binary_serialized_length(); + serialize(int1(length), ctx); + if (length >= 4) // TODO: refactor these magic constants + { + serialize_binary_ymd(brokendt.ymd, ctx); + } + if (length >= 7) + { + serialize_fields( + ctx, + int1(brokendt.tod.hours().count()), + int1(brokendt.tod.minutes().count()), + int1(brokendt.tod.seconds().count()) + ); + } + if (length >= 11) + { + serialize(int4(brokendt.tod.subseconds().count()), ctx); + } +} + +inline mysql::Error mysql::detail::deserialize( + datetime& output, + DeserializationContext& ctx +) noexcept +{ + int1 length; + date date_part; + int1 hours; + int1 minutes; + int1 seconds; + int4 micros; + + // Deserialize length + auto err = deserialize(length, ctx); + if (err != Error::ok) return err; + + // Based on length, deserialize the rest of the fields + err = deserialize_binary_date(date_part, length.value, ctx); + if (err != Error::ok) return err; + if (length.value >= 7) + { + err = deserialize_fields(ctx, hours, minutes, seconds); + if (err != Error::ok) return err; + } + if (length.value >= 11) + { + err = deserialize(micros, ctx); + if (err != Error::ok) return err; + } + + // Compose the final datetime + output = date_part + std::chrono::hours(hours.value) + std::chrono::minutes(minutes.value) + + std::chrono::seconds(seconds.value) + std::chrono::microseconds(micros.value); + return Error::ok; +} + +// time +inline std::size_t mysql::detail::get_size( + const time& input, + const SerializationContext& +) noexcept +{ + return broken_time(input).binary_serialized_length() + 1; // length byte +} + +inline void mysql::detail::serialize( + const time& input, + SerializationContext& ctx +) noexcept +{ + broken_time broken (input); + auto length = broken.binary_serialized_length(); + serialize(int1(length), ctx); + if (length >= 8) // TODO: magic constants + { + int1 is_negative (input.count() < 0 ? 1 : 0); + serialize_fields( + ctx, + is_negative, + int4(broken.days.count()), + int1(broken.hours.count()), + int1(broken.minutes.count()), + int1(broken.seconds.count()) + ); + } + if (length >= 12) + { + serialize(int4(broken.microseconds.count()), ctx); + } +} + +inline mysql::Error mysql::detail::deserialize( + time& output, + DeserializationContext& ctx +) noexcept +{ + // Length + int1 length; + auto err = deserialize(length, ctx); + if (err != Error::ok) return err; + + int1 is_negative (0); + int4 days (0); + int1 hours (0); + int1 minutes(0); + int1 seconds(0); + int4 microseconds(0); + + if (length.value >= 8) + { + err = deserialize_fields( + ctx, + is_negative, + days, + hours, + minutes, + seconds + ); + if (err != Error::ok) return err; + } + if (length.value >= 12) + { + err = deserialize(microseconds, ctx); + if (err != Error::ok) return err; + } + + output = (is_negative.value ? -1 : 1) * ( + ::date::days(days.value) + + std::chrono::hours(hours.value) + + std::chrono::minutes(minutes.value) + + std::chrono::seconds(seconds.value) + + std::chrono::microseconds(microseconds.value) + ); + return Error::ok; +} + +// mysql::value +inline std::size_t mysql::detail::get_size( + const value& input, + const SerializationContext& ctx +) noexcept +{ + return std::visit([&ctx](const auto& v) { + return get_size(to_serializable_type(v), ctx); + }, input); +} + +inline void mysql::detail::serialize( + const value& input, + SerializationContext& ctx +) noexcept +{ + std::visit([&ctx](const auto& v) { + serialize(to_serializable_type(v), ctx); + }, input); +} + + + +#endif /* INCLUDE_MYSQL_IMPL_BINARY_SERIALIZATION_IPP_ */ diff --git a/include/mysql/impl/messages.hpp b/include/mysql/impl/messages.hpp index 2d18f339..c5638e34 100644 --- a/include/mysql/impl/messages.hpp +++ b/include/mysql/impl/messages.hpp @@ -2,13 +2,14 @@ #define MYSQL_ASIO_IMPL_MESSAGES_HPP #include "mysql/impl/serialization.hpp" +#include "mysql/impl/basic_types.hpp" +#include "mysql/impl/constants.hpp" +#include "mysql/collation.hpp" +#include "mysql/value.hpp" #include #include #include #include -#include "mysql/impl/basic_types.hpp" -#include "mysql/impl/constants.hpp" -#include "mysql/collation.hpp" namespace mysql { @@ -243,6 +244,35 @@ struct get_struct_fields ); }; +struct com_stmt_execute_packet +{ + int4 statement_id; + int1 flags; + int4 iteration_count; + // int1 num_params; implicit: from the iterator distance + // if num_params > 0: NULL bitmap + int1 new_params_bind_flag; + const value* params_begin; // TODO: maybe change to a generic iterator + const value* params_end; + + static constexpr std::uint8_t command_id = 0x17; + + struct param_meta + { + protocol_field_type type; + int1 unsigned_flag; + }; +}; + +template <> +struct get_struct_fields +{ + static constexpr auto value = std::make_tuple( + &com_stmt_execute_packet::param_meta::type, + &com_stmt_execute_packet::param_meta::unsigned_flag + ); +}; + // serialization functions inline Error deserialize(ok_packet& output, DeserializationContext& ctx) noexcept; @@ -252,6 +282,10 @@ inline void serialize(const handshake_response_packet& value, SerializationConte inline Error deserialize(auth_switch_request_packet& output, DeserializationContext& ctx) noexcept; inline Error deserialize(column_definition_packet& output, DeserializationContext& ctx) noexcept; inline Error deserialize(com_stmt_prepare_ok_packet& output, DeserializationContext& ctx) noexcept; +inline std::size_t get_size(const com_stmt_execute_packet& value, const SerializationContext& ctx) noexcept; +inline void serialize(const com_stmt_execute_packet& input, SerializationContext& ctx) noexcept; + + // Helper to serialize top-level messages template diff --git a/include/mysql/impl/messages.ipp b/include/mysql/impl/messages.ipp index f7bb0714..cc2c1c03 100644 --- a/include/mysql/impl/messages.ipp +++ b/include/mysql/impl/messages.ipp @@ -2,7 +2,55 @@ #define MYSQL_ASIO_IMPL_MESSAGES_IPP #include "mysql/impl/serialization.hpp" +#include "mysql/impl/null_bitmap_traits.hpp" +#include "mysql/impl/binary_serialization.hpp" #include +#include + +namespace mysql +{ +namespace detail +{ + +// Maps from an actual value to a protocol_field_type. Only value's type is used +inline protocol_field_type get_protocol_field_type( + const value& input +) noexcept +{ + struct visitor + { + constexpr auto operator()(std::int32_t) const noexcept { return protocol_field_type::long_; } + constexpr auto operator()(std::uint32_t) const noexcept { return protocol_field_type::long_; } + constexpr auto operator()(std::int64_t) const noexcept { return protocol_field_type::longlong; } + constexpr auto operator()(std::uint64_t) const noexcept { return protocol_field_type::longlong; } + constexpr auto operator()(std::string_view) const noexcept { return protocol_field_type::var_string; } + constexpr auto operator()(float) const noexcept { return protocol_field_type::float_; } + constexpr auto operator()(double) const noexcept { return protocol_field_type::double_; } + constexpr auto operator()(date) const noexcept { return protocol_field_type::date; } + constexpr auto operator()(datetime) const noexcept { return protocol_field_type::datetime; } + constexpr auto operator()(time) const noexcept { return protocol_field_type::time; } + constexpr auto operator()(year) const noexcept { return protocol_field_type::year; } + constexpr auto operator()(std::nullptr_t) const noexcept { return protocol_field_type::null; } + }; + return std::visit(visitor(), input); +} + +// Whether to include the unsigned flag in the statement execute message +// for a given value or not. Only value's type is used +inline bool is_unsigned( + const value& input +) noexcept +{ + // By default, return false; just for integer types explicitly unsigned return true + return std::visit([](auto v) { + using type = decltype(v); + return std::is_same_v || + std::is_same_v; + }, input); +} + +} +} inline mysql::Error mysql::detail::deserialize( ok_packet& output, @@ -183,6 +231,74 @@ inline mysql::Error mysql::detail::deserialize( ); } +inline std::size_t mysql::detail::get_size( + const com_stmt_execute_packet& value, + const SerializationContext& ctx +) noexcept +{ + std::size_t res = 1 + // command ID + get_size(value.statement_id, ctx) + + get_size(value.flags, ctx) + + get_size(value.iteration_count, ctx) + + 1; // num_params + auto num_params = std::distance(value.params_begin, value.params_end); + assert(num_params >= 0 && num_params <= 255); + res += null_bitmap_traits(stmt_execute_null_bitmap_offset, num_params).byte_count(); + res += get_size(value.new_params_bind_flag, ctx); + res += get_size(com_stmt_execute_packet::param_meta{}, ctx) * num_params; + for (auto it = value.params_begin; it != value.params_end; ++it) + { + res += get_size(*it, ctx); + } + return res; +} + +inline void mysql::detail::serialize( + const com_stmt_execute_packet& input, + SerializationContext& ctx +) noexcept +{ + serialize(int1(com_stmt_execute_packet::command_id), ctx); + serialize(input.statement_id, ctx); + serialize(input.flags, ctx); + serialize(input.iteration_count, ctx); + + // Number of parameters + auto num_params = std::distance(input.params_begin, input.params_end); + assert(num_params >= 0 && num_params <= 255); + serialize(int1(static_cast(num_params)), ctx); + + // NULL bitmap (already size zero if num_params == 0) + null_bitmap_traits traits (stmt_execute_null_bitmap_offset, num_params); + std::size_t i = 0; + for (auto it = input.params_begin; it < input.params_end; ++it, ++i) + { + if (std::holds_alternative(*it)) + { + traits.set_null(ctx.first(), i); + } + } + ctx.advance(traits.byte_count()); + + // new parameters bind flag + serialize(input.new_params_bind_flag, ctx); + + // value metadata + com_stmt_execute_packet::param_meta meta; + for (auto it = input.params_begin; it < input.params_end; ++it) + { + meta.type = get_protocol_field_type(*it); + meta.unsigned_flag.value = is_unsigned(*it) ? 0x80 : 0; + serialize(meta, ctx); + } + + // actual values + for (auto it = input.params_begin; it < input.params_end; ++it) + { + serialize(*it, ctx); + } +} + template void mysql::detail::serialize_message( const Serializable& input, diff --git a/include/mysql/impl/query.ipp b/include/mysql/impl/query.ipp index 86e13506..44b189b4 100644 --- a/include/mysql/impl/query.ipp +++ b/include/mysql/impl/query.ipp @@ -2,7 +2,7 @@ #define MYSQL_ASIO_IMPL_QUERY_IPP #include "mysql/impl/messages.hpp" -#include "mysql/impl/deserialize_row.hpp" +#include "mysql/impl/text_deserialization.hpp" #include "mysql/impl/serialization.hpp" #include #include diff --git a/include/mysql/impl/serialization.hpp b/include/mysql/impl/serialization.hpp index 993e341c..da35b088 100644 --- a/include/mysql/impl/serialization.hpp +++ b/include/mysql/impl/serialization.hpp @@ -98,10 +98,14 @@ private: using value_type = typename get_value_type::type; public: static constexpr bool value = - std::is_integral_v && + std::is_arithmetic_v && // includes floating point types std::is_base_of_v, T>; }; +// Serialization of these types relies on this fact +static_assert(std::numeric_limits::is_iec559); +static_assert(std::numeric_limits::is_iec559); + template <> struct is_fixed_size : std::false_type {}; template struct is_fixed_size>: std::true_type {}; @@ -456,12 +460,12 @@ get_size(const T& input, const SerializationContext& ctx) noexcept return res; } -// Helper to write custom struct deserialize() +// Helper to write custom struct (de)serialize() template -Error deserialize_fields(DeserializationContext& ctx, FirstType& field) { return deserialize(field, ctx); } +Error deserialize_fields(DeserializationContext& ctx, FirstType& field) noexcept { return deserialize(field, ctx); } template -Error deserialize_fields(DeserializationContext& ctx, FirstType& field, Types&... fields_tail) +Error deserialize_fields(DeserializationContext& ctx, FirstType& field, Types&... fields_tail) noexcept { Error err = deserialize(field, ctx); if (err == Error::ok) @@ -471,6 +475,25 @@ Error deserialize_fields(DeserializationContext& ctx, FirstType& field, Types&.. return err; } +template +void serialize_fields(SerializationContext& ctx, const FirstType& field) noexcept { serialize(field, ctx); } + +template +void serialize_fields(SerializationContext& ctx, const FirstType& field, const Types&... fields_tail) +{ + serialize(field, ctx); + serialize_fields(ctx, fields_tail...); +} + +// Dummy type to indicate no (de)serialization is required +struct dummy_serializable +{ + dummy_serializable(...) {} // Make it constructible from anything +}; +inline std::size_t get_size(dummy_serializable, const SerializationContext&) noexcept { return 0; } +inline void serialize(dummy_serializable, SerializationContext&) noexcept {} +inline Error deserialize(dummy_serializable, DeserializationContext&) noexcept { return Error::ok; } + } } diff --git a/include/mysql/impl/deserialize_row.hpp b/include/mysql/impl/text_deserialization.hpp similarity index 91% rename from include/mysql/impl/deserialize_row.hpp rename to include/mysql/impl/text_deserialization.hpp index 537d4558..a6e33089 100644 --- a/include/mysql/impl/deserialize_row.hpp +++ b/include/mysql/impl/text_deserialization.hpp @@ -28,6 +28,6 @@ inline error_code deserialize_text_row( } } -#include "mysql/impl/deserialize_row.ipp" +#include "mysql/impl/text_deserialization.ipp" -#endif \ No newline at end of file +#endif diff --git a/include/mysql/impl/deserialize_row.ipp b/include/mysql/impl/text_deserialization.ipp similarity index 100% rename from include/mysql/impl/deserialize_row.ipp rename to include/mysql/impl/text_deserialization.ipp diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 79c665c7..1e751ed9 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -20,7 +20,7 @@ add_executable( unit/capabilities.cpp unit/auth.cpp unit/metadata.cpp - unit/deserialize_row.cpp + unit/text_deserialization.cpp unit/value.cpp unit/row.cpp unit/error.cpp diff --git a/test/unit/deserialize_row.cpp b/test/unit/text_deserialization.cpp similarity index 99% rename from test/unit/deserialize_row.cpp rename to test/unit/text_deserialization.cpp index 5d474bf4..4f5dc727 100644 --- a/test/unit/deserialize_row.cpp +++ b/test/unit/text_deserialization.cpp @@ -7,7 +7,7 @@ #include #include -#include "mysql/impl/deserialize_row.hpp" +#include "mysql/impl/text_deserialization.hpp" #include "test_common.hpp" using namespace mysql::detail;