diff --git a/CMakeLists.txt b/CMakeLists.txt index 74672eb9..b7603650 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,36 +8,41 @@ find_package(Threads REQUIRED) find_package(OpenSSL REQUIRED) # Library +#add_library( +# mysql_asio +# SHARED +# src/basic_serialization.cpp +# src/message_serialization.cpp +# src/auth.cpp +#) add_library( - mysql_asio - SHARED - src/basic_serialization.cpp - src/message_serialization.cpp - src/auth.cpp + mysql_asio + INTERFACE ) + target_link_libraries( - mysql_asio - PUBLIC + mysql_asio + INTERFACE Boost::system Threads::Threads OpenSSL::Crypto ) target_include_directories( mysql_asio - PUBLIC + INTERFACE include ) # Main -add_executable( - main - main.cpp -) -target_link_libraries( - main - PRIVATE - mysql_asio -) +#add_executable( +# main +# main.cpp +#) +#target_link_libraries( +# main +# PRIVATE +# mysql_asio +#) # Unit testing set(GTEST_ROOT /opt/gtest) diff --git a/include/basic_serialization.hpp b/include/basic_serialization.hpp deleted file mode 100644 index 8ed36885..00000000 --- a/include/basic_serialization.hpp +++ /dev/null @@ -1,185 +0,0 @@ -#ifndef DESERIALIZATION_H_ -#define DESERIALIZATION_H_ - -#include -#include -#include -#include "basic_types.hpp" - -namespace mysql -{ - -// Utility -void check_size(ReadIterator from, ReadIterator last, std::size_t sz); - -inline std::string_view get_string(ReadIterator from, std::size_t size) -{ - return std::string_view {reinterpret_cast(from), size}; -} - -// Fixed size -template struct get_size { static constexpr std::size_t value = std::size_t(-1); }; -template <> struct get_size { static constexpr std::size_t value = 1; }; -template <> struct get_size { static constexpr std::size_t value = 2; }; -template <> struct get_size { static constexpr std::size_t value = 3; }; -template <> struct get_size { static constexpr std::size_t value = 4; }; -template <> struct get_size { static constexpr std::size_t value = 6; }; -template <> struct get_size { static constexpr std::size_t value = 8; }; -template <> struct get_size { static constexpr std::size_t value = 1; }; -template <> struct get_size { static constexpr std::size_t value = 2; }; -template <> struct get_size { static constexpr std::size_t value = 4; }; -template <> struct get_size { static constexpr std::size_t value = 8; }; -template struct get_size> { static constexpr std::size_t value = size; }; - -template constexpr std::size_t get_size_v = get_size::value; -template constexpr bool is_fixed_size_v = get_size_v != std::size_t(-1); - -template void little_to_native(T& value) { boost::endian::little_to_native_inplace(value); } -template <> inline void little_to_native(int3& value) { boost::endian::little_to_native_inplace(value.value); } -template <> inline void little_to_native(int6& value) { boost::endian::little_to_native_inplace(value.value); } -template void little_to_native(string_fixed&) {} - -// Deserialization functions -template -std::enable_if_t, ReadIterator> -deserialize(ReadIterator from, ReadIterator last, T& output) -{ - check_size(from, last, get_size_v); - memset(&output, 0, sizeof(T)); - memcpy(&output, from, get_size_v); - little_to_native(output); - return from + get_size_v; -} - -ReadIterator deserialize(ReadIterator from, ReadIterator last, int_lenenc& output); -ReadIterator deserialize(ReadIterator from, ReadIterator last, string_null& output); - -inline ReadIterator deserialize(ReadIterator from, ReadIterator last, string_eof& output) -{ - output.value = get_string(from, last-from); - return last; -} - -inline ReadIterator deserialize(ReadIterator from, ReadIterator last, std::string_view& output, std::size_t size) -{ - check_size(from, last, size); - output = get_string(from, size); - return from + size; -} - -inline ReadIterator deserialize(ReadIterator from, ReadIterator last, void* output, std::size_t size) -{ - check_size(from, last, size); - memcpy(output, from, size); - return from + size; -} - -inline ReadIterator deserialize(ReadIterator from, ReadIterator last, string_lenenc& output) -{ - int_lenenc length; - from = deserialize(from, last, length); - from = deserialize(from, last, output.value, length.value); - return from; -} - -template -std::enable_if_t, ReadIterator> -deserialize(ReadIterator from, ReadIterator last, T& to) -{ - std::underlying_type_t value; - ReadIterator res = deserialize(from, last, value); - to = static_cast(value); - return res; -} - -inline ReadIterator deserialize(ReadIterator from, ReadIterator last, nullptr_t&) { return from; } - -template -ReadIterator deserialize(const std::vector& from, T& to) -{ - return deserialize(from.data(), from.data() + from.size(), to); -} - -// SERIALIZATION - -class DynamicBuffer -{ - std::vector buffer_; -public: - DynamicBuffer() = default; - void add(const void* data, std::size_t size) - { - auto current_size = buffer_.size(); - buffer_.resize(current_size+size); - memcpy(buffer_.data() + current_size, data, size); - } - void add(std::uint8_t value) { buffer_.push_back(value); } - const void* data() const { return buffer_.data(); } - void* data() { return buffer_.data(); } - std::size_t size() const { return buffer_.size(); } - const std::vector& get() const { return buffer_; } - void clear() { buffer_.clear(); } -}; - -template void native_to_little(T& value) { boost::endian::native_to_little_inplace(value); } -template <> inline void native_to_little(int3& value) { boost::endian::native_to_little_inplace(value.value); } -template <> inline void native_to_little(int6& value) { boost::endian::native_to_little_inplace(value.value); } - - - -template -std::enable_if_t> -serialize(DynamicBuffer& buffer, T value) -{ - native_to_little(value); - buffer.add(&value, get_size_v); -} - -template <> -inline void serialize(DynamicBuffer& buffer, int1 value) { buffer.add(value); } - -template -void serialize(DynamicBuffer& buffer, const string_fixed& value) -{ - buffer.add(value, sizeof(value)); -} - -void serialize(DynamicBuffer& buffer, int_lenenc value); - -inline void serialize(DynamicBuffer& buffer, const std::string_view& value) -{ - buffer.add(value.data(), value.size()); -} - -inline void serialize(DynamicBuffer& buffer, const string_null& value) -{ - serialize(buffer, value.value); - serialize(buffer, int1(0)); -} - -inline void serialize(DynamicBuffer& buffer, const string_eof& value) -{ - serialize(buffer, value.value); -} - -inline void serialize(DynamicBuffer& buffer, const string_lenenc& value) -{ - serialize(buffer, int_lenenc {value.value.size()}); - serialize(buffer, value.value); -} - -template -std::enable_if_t> -serialize(DynamicBuffer& buffer, T value) -{ - serialize(buffer, static_cast>(value)); -} - -inline void serialize(DynamicBuffer&, nullptr_t) {}; - - -} - - - -#endif /* DESERIALIZATION_H_ */ diff --git a/include/basic_types.hpp b/include/basic_types.hpp deleted file mode 100644 index cc74fc2d..00000000 --- a/include/basic_types.hpp +++ /dev/null @@ -1,29 +0,0 @@ -#ifndef BASIC_TYPES_H_ -#define BASIC_TYPES_H_ - -#include -#include - -namespace mysql -{ - -using ReadIterator = const std::uint8_t*; -using WriteIterator = std::uint8_t*; - -using int1 = std::uint8_t; -using int2 = std::uint16_t; -struct int3 { std::uint32_t value; }; -using int4 = std::uint32_t; -struct int6 { std::uint64_t value; }; -using int8 = std::uint64_t; -struct int_lenenc { std::uint64_t value; }; -template using string_fixed = char[size]; -struct string_null { std::string_view value; }; -struct string_eof { std::string_view value; }; -struct string_lenenc { std::string_view value; }; - -} - - - -#endif /* BASIC_TYPES_H_ */ diff --git a/include/messages.hpp b/include/messages.hpp deleted file mode 100644 index e9892126..00000000 --- a/include/messages.hpp +++ /dev/null @@ -1,268 +0,0 @@ -#ifndef MESSAGES_H_ -#define MESSAGES_H_ - -#include -#include -#include -#include "basic_types.hpp" - -namespace mysql -{ - -// Server/client capabilities -constexpr int4 CLIENT_LONG_PASSWORD = 1; // Use the improved version of Old Password Authentication -constexpr int4 CLIENT_FOUND_ROWS = 2; // Send found rows instead of affected rows in EOF_Packet -constexpr int4 CLIENT_LONG_FLAG = 4; // Get all column flags -constexpr int4 CLIENT_CONNECT_WITH_DB = 8; // Database (schema) name can be specified on connect in Handshake Response Packet -constexpr int4 CLIENT_NO_SCHEMA = 16; // Don't allow database.table.column -constexpr int4 CLIENT_COMPRESS = 32; // Compression protocol supported -constexpr int4 CLIENT_ODBC = 64; // Special handling of ODBC behavior -constexpr int4 CLIENT_LOCAL_FILES = 128; // Can use LOAD DATA LOCAL -constexpr int4 CLIENT_IGNORE_SPACE = 256; // Ignore spaces before '(' -constexpr int4 CLIENT_PROTOCOL_41 = 512; // New 4.1 protocol -constexpr int4 CLIENT_INTERACTIVE = 1024; // This is an interactive client -constexpr int4 CLIENT_SSL = 2048; // Use SSL encryption for the session -constexpr int4 CLIENT_IGNORE_SIGPIPE = 4096; // Client only flag -constexpr int4 CLIENT_TRANSACTIONS = 8192; // Client knows about transactions -constexpr int4 CLIENT_RESERVED = 16384; // DEPRECATED: Old flag for 4.1 protocol -constexpr int4 CLIENT_RESERVED2 = 32768; // DEPRECATED: Old flag for 4.1 authentication \ CLIENT_SECURE_CONNECTION -constexpr int4 CLIENT_MULTI_STATEMENTS = (1UL << 16); // Enable/disable multi-stmt support -constexpr int4 CLIENT_MULTI_RESULTS = (1UL << 17); // Enable/disable multi-results -constexpr int4 CLIENT_PS_MULTI_RESULTS = (1UL << 18); // Multi-results and OUT parameters in PS-protocol -constexpr int4 CLIENT_PLUGIN_AUTH = (1UL << 19); // Client supports plugin authentication -constexpr int4 CLIENT_CONNECT_ATTRS = (1UL << 20); // Client supports connection attributes -constexpr int4 CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA = (1UL << 21); // Enable authentication response packet to be larger than 255 bytes -constexpr int4 CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS = (1UL << 22); // Don't close the connection for a user account with expired password -constexpr int4 CLIENT_SESSION_TRACK = (1UL << 23); // Capable of handling server state change information -constexpr int4 CLIENT_DEPRECATE_EOF = (1UL << 24); // Client no longer needs EOF_Packet and will use OK_Packet instead -constexpr int4 CLIENT_SSL_VERIFY_SERVER_CERT = (1UL << 30); // Verify server certificate -constexpr int4 CLIENT_OPTIONAL_RESULTSET_METADATA = (1UL << 25); // The client can handle optional metadata information in the resultset -constexpr int4 CLIENT_REMEMBER_OPTIONS = (1UL << 31); // Don't reset the options after an unsuccessful connect - -// Server status flags -constexpr int4 SERVER_STATUS_IN_TRANS = 1; -constexpr int4 SERVER_STATUS_AUTOCOMMIT = 2; -constexpr int4 SERVER_MORE_RESULTS_EXISTS = 8; -constexpr int4 SERVER_QUERY_NO_GOOD_INDEX_USED = 16; -constexpr int4 SERVER_QUERY_NO_INDEX_USED = 32; -constexpr int4 SERVER_STATUS_CURSOR_EXISTS = 64; -constexpr int4 SERVER_STATUS_LAST_ROW_SENT = 128; -constexpr int4 SERVER_STATUS_DB_DROPPED = 256; -constexpr int4 SERVER_STATUS_NO_BACKSLASH_ESCAPES = 512; -constexpr int4 SERVER_STATUS_METADATA_CHANGED = 1024; -constexpr int4 SERVER_QUERY_WAS_SLOW = 2048; -constexpr int4 SERVER_PS_OUT_PARAMS = 4096; -constexpr int4 SERVER_STATUS_IN_TRANS_READONLY = 8192; -constexpr int4 SERVER_SESSION_STATE_CHANGED = (1UL << 14) ; - -enum class CharacterSetLowerByte : int1 -{ - latin1_swedish_ci = 0x08, - utf8_general_ci = 0x21, - binary = 0x3f -}; - -// Packet type constants -constexpr int1 handshake_protocol_version_9 = 9; -constexpr int1 handshake_protocol_version_10 = 10; -constexpr int1 error_packet_header = 0xff; -constexpr int1 ok_packet_header = 0x00; -constexpr int1 eof_packet_header = 0xfe; - -struct PacketHeader -{ - int3 packet_size; - int1 sequence_number; -}; - -struct OkPacket -{ - // header: int<1> header 0x00 or 0xFE the OK packet header - int_lenenc affected_rows; - int_lenenc last_insert_id; - int2 status_flags; // server_status_flags - int2 warnings; - // TODO: CLIENT_SESSION_TRACK - string_eof info; -}; - -struct ErrPacket -{ - // int<1> header 0xFF ERR packet header - int2 error_code; - string_fixed<1> sql_state_marker; - string_fixed<5> sql_state; - string_eof error_message; -}; - -struct Handshake -{ - // int<1> protocol version Always 10 - string_null server_version; - int4 connection_id; - std::string auth_plugin_data; // merge of the two parts - not an actual field - int4 capability_falgs; // merge of the two parts - not an actual field - // string[8] auth-plugin-data-part-1 first 8 bytes of the plugin provided data (scramble) - // int<1> filler 0x00 byte, terminating the first part of a scramble - // int<2> capability_flags_1 The lower 2 bytes of the Capabilities Flags - CharacterSetLowerByte character_set; // default server a_protocol_character_set, only the lower 8-bits - int2 status_flags; // server_status_flags - // int<2> capability_flags_2 The upper 2 bytes of the Capabilities Flags - // int<1> auth_plugin_data_len - // string[10] reserved reserved. All 0s. - // $length auth-plugin-data-part-2 - string_null auth_plugin_name; -}; - -struct HandshakeResponse -{ - int4 client_flag; // capabilities - int4 max_packet_size; - CharacterSetLowerByte character_set; - // string[23] filler filler to the size of the handhshake response packet. All 0s. - string_null username; - string_lenenc auth_response; // we should set CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA - string_null database; // we should set CLIENT_CONNECT_WITH_DB - string_null client_plugin_name; // we should set CLIENT_PLUGIN_AUTH - // TODO: CLIENT_CONNECT_ATTRS -}; - -enum class Command : int1 -{ - COM_QUIT = 1, - COM_INIT_DB = 2, - COM_QUERY = 3, - COM_STATISTICS = 8, - COM_DEBUG = 0x0d, - COM_PING = 0x0e, - COM_CHANGE_USER = 0x11, - COM_BINLOG_DUMP = 0x12, - COM_STMT_PREPARE = 0x16, - COM_STMT_EXECUTE = 0x17, - COM_STMT_SEND_LONG_DATA = 0x18, - COM_STMT_CLOSE = 0x19, - COM_STMT_RESET = 0x1a, - COM_SET_OPTION = 0x1b, - COM_STMT_FETCH = 0x1c, - COM_RESET_CONNECTION = 0x1f -}; - -// Column definitions -enum class FieldType : int1 -{ - DECIMAL = 0x00, - TINY = 0x01, - SHORT = 0x02, - LONG = 0x03, - FLOAT = 0x04, - DOUBLE = 0x05, - NULL_ = 0x06, - TIMESTAMP = 0x07, - LONGLONG = 0x08, - INT24 = 0x09, - DATE = 0x0a, - TIME = 0x0b, - DATETIME = 0x0c, - YEAR = 0x0d, - VARCHAR = 0x0f, - BIT = 0x10, - NEWDECIMAL = 0xf6, - ENUM = 0xf7, - SET = 0xf8, - TINY_BLOB = 0xf9, - MEDIUM_BLOB = 0xfa, - LONG_BLOB = 0xfb, - BLOB = 0xfc, - VAR_STRING = 0xfd, - STRING = 0xfe, - GEOMETRY = 0xff -}; - -struct ColumnDefinition -{ - string_lenenc catalog; // always "def" - string_lenenc schema; - string_lenenc table; // virtual table - string_lenenc org_table; // physical table - string_lenenc name; // virtual column name - string_lenenc org_name; // physical column name - // int length of fixed length fields [0x0c] - int2 character_set; // TODO: enum-erize this - int4 column_length; // maximum length of the field - FieldType type; // type of the column as defined in enum_field_types - int2 flags; // Flags as defined in Column Definition Flags - int1 decimals; // max shown decimal digits. 0x00 for int/static strings; 0x1f for dynamic strings, double, float -}; - -// Prepared statements -constexpr int1 CURSOR_TYPE_NO_CURSOR = 0; -constexpr int1 CURSOR_TYPE_READ_ONLY = 1; -constexpr int1 CURSOR_TYPE_FOR_UPDATE = 2; -constexpr int1 CURSOR_TYPE_SCROLLABLE = 4; - -struct StmtPrepare -{ - string_eof statement; -}; - -struct StmtPrepareResponseHeader -{ - // int1 status: must be 0 - int4 statement_id; - int2 num_columns; - int2 num_params; - // int1 reserved_1: must be 0 - int2 warning_count; // only if (packet_length > 12) - // TODO: int1 metadata_follows when CLIENT_OPTIONAL_RESULTSET_METADATA -}; - -using BinaryValue = std::variant< - std::int8_t, - std::int16_t, - std::int32_t, - std::int64_t, - std::uint8_t, - std::uint16_t, - std::uint32_t, - std::uint64_t, - string_lenenc, - std::nullptr_t // NULL - // TODO: double, float, dates/times ->; - -struct StmtExecute -{ - //int1 message_type: COM_STMT_EXECUTE - int4 statement_id; - int1 flags; - // int4 iteration_count: always 1 - int1 num_params; - int1 new_params_bind_flag; - std::vector param_values; // empty if !new_params_bind_flag -}; - -struct StmtExecuteResponseHeader -{ - int1 num_fields; -}; - -struct StmtFetch -{ - // int1 message_type: COM_STMT_FETCH - int4 statement_id; - int4 rows_to_fetch; -}; - -struct StmtClose -{ - int4 statement_id; -}; - - -} - - - - - -#endif /* MESSAGES_H_ */ diff --git a/include/auth.hpp b/include/mysql/auth.hpp similarity index 100% rename from include/auth.hpp rename to include/mysql/auth.hpp diff --git a/include/mysql/error.hpp b/include/mysql/error.hpp new file mode 100644 index 00000000..527980c4 --- /dev/null +++ b/include/mysql/error.hpp @@ -0,0 +1,19 @@ +#ifndef INCLUDE_ERROR_HPP_ +#define INCLUDE_ERROR_HPP_ + +#include + +namespace mysql +{ + +enum class Error : int +{ + ok = 0, + incomplete_message +}; + +} + +#include "mysql/impl/error_impl.hpp" + +#endif /* INCLUDE_ERROR_HPP_ */ diff --git a/include/mysql/impl/basic_serialization.hpp b/include/mysql/impl/basic_serialization.hpp new file mode 100644 index 00000000..9396cbf3 --- /dev/null +++ b/include/mysql/impl/basic_serialization.hpp @@ -0,0 +1,381 @@ +#ifndef DESERIALIZATION_H_ +#define DESERIALIZATION_H_ + +#include +#include +#include +#include +#include +#include "mysql/impl/basic_types.hpp" +#include "mysql/error.hpp" + +namespace mysql +{ +namespace detail +{ + +class DeserializationContext +{ + ReadIterator first_; + ReadIterator last_; + const std::uint32_t capabilities_; +public: + DeserializationContext(ReadIterator first, ReadIterator last, std::uint32_t capabilities) noexcept: + first_(first), last_(last), capabilities_(capabilities) { assert(last_ >= first_); }; + ReadIterator first() const noexcept { return first_; } + ReadIterator last() const noexcept { return last_; } + void set_first(ReadIterator new_first) noexcept { first_ = new_first; assert(last_ >= first_); } + void advance(std::size_t sz) noexcept { first_ += sz; assert(last_ >= first_); } + std::size_t size() const noexcept { return last_ - first_; } + bool enough_size(std::size_t required_size) const noexcept { return size() >= required_size; } + std::uint32_t capabilities() const noexcept { return capabilities_; } +}; + +class SerializationContext +{ + WriteIterator first_; + const std::uint32_t capabilities_; +public: + SerializationContext(std::uint32_t capabilities, WriteIterator first = nullptr) noexcept: + first_(first), capabilities_(capabilities) {}; + WriteIterator first() const noexcept { return first_; } + void set_first(WriteIterator new_first) noexcept { first_ = new_first; } + void advance(std::size_t size) noexcept { first_ += size; } + std::uint32_t capabilities() const noexcept { return capabilities_; } + void write(const void* buffer, std::size_t size) noexcept { memcpy(first_, buffer, size); advance(size); } + void write(std::uint8_t elm) noexcept { *first_ = elm; ++first_; } +}; + +/** + * Base forms: + * Error deserialize(T& output, DeserializationContext&) noexcept + * void serialize(const T& input, SerializationContext&) noexcept + * std::size_t get_size(const T& input, const SerializationContext&) noexcept + */ + +// Fixed-size types +struct get_value_type_helper +{ + struct no_value_type {}; + + template + static constexpr typename T::value_type get(typename T::value_type*); + + template + static constexpr no_value_type get(...); +}; + +template +struct get_value_type +{ + using type = decltype(get_value_type_helper().get(nullptr)); + using no_value_type = get_value_type_helper::no_value_type; +}; + +template +struct is_fixed_size +{ +private: + using value_type = typename get_value_type::type; +public: + static constexpr bool value = + std::is_integral_v && + std::is_base_of_v, T>; +}; + + +template <> struct is_fixed_size : std::false_type {}; +template struct is_fixed_size>: std::true_type {}; + +template constexpr bool is_fixed_size_v = is_fixed_size::value; + +template +struct get_fixed_size +{ + static_assert(is_fixed_size_v); + static constexpr std::size_t value = sizeof(T::value); +}; + +template <> struct get_fixed_size { static constexpr std::size_t value = 3; }; +template <> struct get_fixed_size { static constexpr std::size_t value = 6; }; +template struct get_fixed_size> { static constexpr std::size_t value = N; }; + +template void little_to_native_inplace(ValueHolder& value) noexcept { boost::endian::little_to_native_inplace(value.value); } +template void little_to_native_inplace(string_fixed&) noexcept {} + +template void native_to_little_inplace(ValueHolder& value) noexcept { boost::endian::native_to_little_inplace(value.value); } +template void native_to_little_inplace(string_fixed&) noexcept {} + + +template +std::enable_if_t, Error> +deserialize(T& output, DeserializationContext& ctx) noexcept +{ + static_assert(std::is_standard_layout_v); + + constexpr auto size = get_fixed_size::value; + if (!ctx.enough_size(size)) + { + return Error::incomplete_message; + } + + memset(&output.value, 0, sizeof(output.value)); + memcpy(&output.value, ctx.first(), size); + little_to_native_inplace(output); + ctx.advance(size); + + return Error::ok; +} + +template +std::enable_if_t> +serialize(T input, SerializationContext& ctx) noexcept +{ + native_to_little_inplace(input); + ctx.write(&input.value, get_fixed_size::value); +} + +template +constexpr std::enable_if_t, std::size_t> +get_size(T input, const SerializationContext&) noexcept +{ + return get_fixed_size::value; +} + +// int_lenenc +inline Error deserialize(int_lenenc& output, DeserializationContext& ctx) noexcept +{ + int1 first_byte; + Error err = deserialize(first_byte, ctx); + if (err != Error::ok) + { + return err; + } + + if (first_byte.value == 0xFC) + { + int2 value; + err = deserialize(value, ctx); + output.value = value.value; + } + else if (first_byte.value == 0xFD) + { + int3 value; + err = deserialize(value, ctx); + output.value = value.value; + } + else if (first_byte.value == 0xFE) + { + int8 value; + err = deserialize(value, ctx); + output.value = value.value; + } + else + { + err = Error::ok; + output.value = first_byte.value; + } + return err; +} +inline void serialize(int_lenenc input, SerializationContext& ctx) noexcept +{ + if (input.value < 251) + { + serialize(int1{static_cast(input.value)}, ctx); + } + else if (input.value < 0x10000) + { + ctx.write(0xfc); + serialize(int2{static_cast(input.value)}, ctx); + } + else if (input.value < 0x1000000) + { + ctx.write(0xfd); + serialize(int3{static_cast(input.value)}, ctx); + } + else + { + ctx.write(0xfe); + serialize(int8{static_cast(input.value)}, ctx); + } +} +inline std::size_t get_size(int_lenenc input, const SerializationContext&) noexcept +{ + if (input.value < 251) return 1; + else if (input.value < 0x10000) return 3; + else if (input.value < 0x1000000) return 4; + else return 9; +} + +// Helper for strings +inline std::string_view get_string(ReadIterator from, std::size_t size) +{ + return std::string_view (reinterpret_cast(from), size); +} + +// string_null +inline Error deserialize(string_null& output, DeserializationContext& ctx) noexcept +{ + ReadIterator string_end = std::find(ctx.first(), ctx.last(), 0); + if (string_end == ctx.last()) + { + return Error::incomplete_message; + } + output.value = get_string(ctx.first(), string_end-ctx.first()); + ctx.set_first(string_end + 1); // skip the null terminator + return Error::ok; +} +inline void serialize(string_null input, SerializationContext& ctx) noexcept +{ + ctx.write(input.value.data(), input.value.size()); + ctx.write(0); // null terminator +} +inline std::size_t get_size(string_null input, const SerializationContext&) noexcept +{ + return input.value.size() + 1; +} + +// string_eof +inline Error deserialize(string_eof& output, DeserializationContext& ctx) noexcept +{ + output.value = get_string(ctx.first(), ctx.last()-ctx.first()); + ctx.set_first(ctx.last()); + return Error::ok; +} +inline void serialize(string_eof input, SerializationContext& ctx) noexcept +{ + ctx.write(input.value.data(), input.value.size()); +} +inline std::size_t get_size(string_eof input, const SerializationContext&) noexcept +{ + return input.value.size(); +} + +// string_lenenc +inline Error deserialize(string_lenenc& output, DeserializationContext& ctx) noexcept +{ + int_lenenc length; + Error err = deserialize(length, ctx); + if (err != Error::ok) + { + return err; + } + if (!ctx.enough_size(length.value)) + { + return Error::incomplete_message; + } + + output.value = get_string(ctx.first(), length.value); + ctx.advance(length.value); + return Error::ok; +} +inline void serialize(string_lenenc input, SerializationContext& ctx) noexcept +{ + int_lenenc length; + length.value = input.value.size(); + serialize(length, ctx); + ctx.write(input.value.data(), input.value.size()); +} +inline std::size_t get_size(string_lenenc input, const SerializationContext& ctx) noexcept +{ + int_lenenc length; + length.value = input.value.size(); + return get_size(length, ctx) + input.value.size(); +} + +// Enums +template >> +Error deserialize(T& output, DeserializationContext& ctx) noexcept +{ + ValueHolder> value; + Error err = deserialize(value, ctx); + if (err != Error::ok) + { + return err; + } + output = static_cast(value.value); + return Error::ok; +} + +template >> +void serialize(T input, SerializationContext& ctx) noexcept +{ + ValueHolder> value {static_cast>(input)}; + serialize(value, ctx); +} + +template >> +std::size_t get_size(T input, const SerializationContext&) noexcept +{ + return get_fixed_size>::value; +} + +// Tuple-like messages +template +Error deserialize_tuple(std::tuple& output, DeserializationContext& ctx) noexcept +{ + if constexpr (index == std::tuple_size_v) + { + return Error::ok; + } + else + { + Error err = deserialize(std::get(output), ctx); + if (err != Error::ok) + { + return err; + } + else + { + return deserialize_tuple(output, ctx); + } + } +} + +template +Error deserialize(std::tuple& output, DeserializationContext& ctx) noexcept +{ + return deserialize_tuple<0>(output, ctx); +} + +template +void serialize_tuple(const std::tuple& input, SerializationContext& ctx) noexcept +{ + if constexpr (index < std::tuple_size_v) + { + serialize(std::get(input), ctx); + serialize_tuple(input, ctx); + } +} + +template +void serialize(const std::tuple& input, SerializationContext& ctx) noexcept +{ + serialize_tuple<0>(input, ctx); +} + +template +std::size_t get_size_tuple(const std::tuple& input, const SerializationContext& ctx) noexcept +{ + if constexpr (index == std::tuple_size_v) + { + return 0; + } + else + { + return get_size_tuple(input, ctx) + + get_size(std::get(input), ctx); + } +} + +template +std::size_t get_size(const std::tuple& input, const SerializationContext& ctx) noexcept +{ + return get_size_tuple<0>(input, ctx); +} + +} +} + + +#endif /* DESERIALIZATION_H_ */ diff --git a/include/mysql/impl/basic_types.hpp b/include/mysql/impl/basic_types.hpp new file mode 100644 index 00000000..b49e8eee --- /dev/null +++ b/include/mysql/impl/basic_types.hpp @@ -0,0 +1,44 @@ +#ifndef BASIC_TYPES_H_ +#define BASIC_TYPES_H_ + +#include +#include +#include + +namespace mysql +{ +namespace detail +{ + +using ReadIterator = const std::uint8_t*; +using WriteIterator = std::uint8_t*; + +template +struct ValueHolder +{ + using value_type = T; + + value_type value; +}; + +struct int1 : ValueHolder {}; +struct int2 : ValueHolder {}; +struct int3 : ValueHolder {}; +struct int4 : ValueHolder {}; +struct int6 : ValueHolder {}; +struct int8 : ValueHolder {}; +struct int1_signed : ValueHolder {}; +struct int2_signed : ValueHolder {}; +struct int4_signed : ValueHolder {}; +struct int8_signed : ValueHolder {}; +struct int_lenenc : ValueHolder {}; +template struct string_fixed : ValueHolder> {}; +struct string_null : ValueHolder {}; +struct string_eof : ValueHolder {}; +struct string_lenenc : ValueHolder {}; + +} +} + + +#endif /* BASIC_TYPES_H_ */ diff --git a/include/mysql/impl/constants.hpp b/include/mysql/impl/constants.hpp new file mode 100644 index 00000000..028dca27 --- /dev/null +++ b/include/mysql/impl/constants.hpp @@ -0,0 +1,133 @@ +#ifndef INCLUDE_MYSQL_IMPL_CONSTANTS_HPP_ +#define INCLUDE_MYSQL_IMPL_CONSTANTS_HPP_ + +#include "mysql/impl/basic_types.hpp" + +namespace mysql +{ +namespace detail +{ + +// Server/client capabilities +constexpr std::uint32_t CLIENT_LONG_PASSWORD = 1; // Use the improved version of Old Password Authentication +constexpr std::uint32_t CLIENT_FOUND_ROWS = 2; // Send found rows instead of affected rows in EOF_Packet +constexpr std::uint32_t CLIENT_LONG_FLAG = 4; // Get all column flags +constexpr std::uint32_t CLIENT_CONNECT_WITH_DB = 8; // Database (schema) name can be specified on connect in Handshake Response Packet +constexpr std::uint32_t CLIENT_NO_SCHEMA = 16; // Don't allow database.table.column +constexpr std::uint32_t CLIENT_COMPRESS = 32; // Compression protocol supported +constexpr std::uint32_t CLIENT_ODBC = 64; // Special handling of ODBC behavior +constexpr std::uint32_t CLIENT_LOCAL_FILES = 128; // Can use LOAD DATA LOCAL +constexpr std::uint32_t CLIENT_IGNORE_SPACE = 256; // Ignore spaces before '(' +constexpr std::uint32_t CLIENT_PROTOCOL_41 = 512; // New 4.1 protocol +constexpr std::uint32_t CLIENT_INTERACTIVE = 1024; // This is an interactive client +constexpr std::uint32_t CLIENT_SSL = 2048; // Use SSL encryption for the session +constexpr std::uint32_t CLIENT_IGNORE_SIGPIPE = 4096; // Client only flag +constexpr std::uint32_t CLIENT_TRANSACTIONS = 8192; // Client knows about transactions +constexpr std::uint32_t CLIENT_RESERVED = 16384; // DEPRECATED: Old flag for 4.1 protocol +constexpr std::uint32_t CLIENT_RESERVED2 = 32768; // DEPRECATED: Old flag for 4.1 authentication \ CLIENT_SECURE_CONNECTION +constexpr std::uint32_t CLIENT_MULTI_STATEMENTS = (1UL << 16); // Enable/disable multi-stmt support +constexpr std::uint32_t CLIENT_MULTI_RESULTS = (1UL << 17); // Enable/disable multi-results +constexpr std::uint32_t CLIENT_PS_MULTI_RESULTS = (1UL << 18); // Multi-results and OUT parameters in PS-protocol +constexpr std::uint32_t CLIENT_PLUGIN_AUTH = (1UL << 19); // Client supports plugin authentication +constexpr std::uint32_t CLIENT_CONNECT_ATTRS = (1UL << 20); // Client supports connection attributes +constexpr std::uint32_t CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA = (1UL << 21); // Enable authentication response packet to be larger than 255 bytes +constexpr std::uint32_t CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS = (1UL << 22); // Don't close the connection for a user account with expired password +constexpr std::uint32_t CLIENT_SESSION_TRACK = (1UL << 23); // Capable of handling server state change information +constexpr std::uint32_t CLIENT_DEPRECATE_EOF = (1UL << 24); // Client no longer needs EOF_Packet and will use OK_Packet instead +constexpr std::uint32_t CLIENT_SSL_VERIFY_SERVER_CERT = (1UL << 30); // Verify server certificate +constexpr std::uint32_t CLIENT_OPTIONAL_RESULTSET_METADATA = (1UL << 25); // The client can handle optional metadata information in the resultset +constexpr std::uint32_t CLIENT_REMEMBER_OPTIONS = (1UL << 31); // Don't reset the options after an unsuccessful connect + +// Server status flags +constexpr std::uint32_t SERVER_STATUS_IN_TRANS = 1; +constexpr std::uint32_t SERVER_STATUS_AUTOCOMMIT = 2; +constexpr std::uint32_t SERVER_MORE_RESULTS_EXISTS = 8; +constexpr std::uint32_t SERVER_QUERY_NO_GOOD_INDEX_USED = 16; +constexpr std::uint32_t SERVER_QUERY_NO_INDEX_USED = 32; +constexpr std::uint32_t SERVER_STATUS_CURSOR_EXISTS = 64; +constexpr std::uint32_t SERVER_STATUS_LAST_ROW_SENT = 128; +constexpr std::uint32_t SERVER_STATUS_DB_DROPPED = 256; +constexpr std::uint32_t SERVER_STATUS_NO_BACKSLASH_ESCAPES = 512; +constexpr std::uint32_t SERVER_STATUS_METADATA_CHANGED = 1024; +constexpr std::uint32_t SERVER_QUERY_WAS_SLOW = 2048; +constexpr std::uint32_t SERVER_PS_OUT_PARAMS = 4096; +constexpr std::uint32_t SERVER_STATUS_IN_TRANS_READONLY = 8192; +constexpr std::uint32_t SERVER_SESSION_STATE_CHANGED = (1UL << 14) ; + +enum class CharacterSetLowerByte : std::uint8_t +{ + latin1_swedish_ci = 0x08, + utf8_general_ci = 0x21, + binary = 0x3f +}; + +// Packet type constants +constexpr std::uint8_t handshake_protocol_version_9 = 9; +constexpr std::uint8_t handshake_protocol_version_10 = 10; +constexpr std::uint8_t error_packet_header = 0xff; +constexpr std::uint8_t ok_packet_header = 0x00; +constexpr std::uint8_t eof_packet_header = 0xfe; + +enum class Command : std::uint8_t +{ + COM_QUIT = 1, + COM_INIT_DB = 2, + COM_QUERY = 3, + COM_STATISTICS = 8, + COM_DEBUG = 0x0d, + COM_PING = 0x0e, + COM_CHANGE_USER = 0x11, + COM_BINLOG_DUMP = 0x12, + COM_STMT_PREPARE = 0x16, + COM_STMT_EXECUTE = 0x17, + COM_STMT_SEND_LONG_DATA = 0x18, + COM_STMT_CLOSE = 0x19, + COM_STMT_RESET = 0x1a, + COM_SET_OPTION = 0x1b, + COM_STMT_FETCH = 0x1c, + COM_RESET_CONNECTION = 0x1f +}; + +// Column definitions +enum class FieldType : std::uint8_t +{ + DECIMAL = 0x00, + TINY = 0x01, + SHORT = 0x02, + LONG = 0x03, + FLOAT = 0x04, + DOUBLE = 0x05, + NULL_ = 0x06, + TIMESTAMP = 0x07, + LONGLONG = 0x08, + INT24 = 0x09, + DATE = 0x0a, + TIME = 0x0b, + DATETIME = 0x0c, + YEAR = 0x0d, + VARCHAR = 0x0f, + BIT = 0x10, + NEWDECIMAL = 0xf6, + ENUM = 0xf7, + SET = 0xf8, + TINY_BLOB = 0xf9, + MEDIUM_BLOB = 0xfa, + LONG_BLOB = 0xfb, + BLOB = 0xfc, + VAR_STRING = 0xfd, + STRING = 0xfe, + GEOMETRY = 0xff +}; + +// Prepared statements +constexpr std::uint8_t CURSOR_TYPE_NO_CURSOR = 0; +constexpr std::uint8_t CURSOR_TYPE_READ_ONLY = 1; +constexpr std::uint8_t CURSOR_TYPE_FOR_UPDATE = 2; +constexpr std::uint8_t CURSOR_TYPE_SCROLLABLE = 4; + +} +} + + + +#endif /* INCLUDE_MYSQL_IMPL_CONSTANTS_HPP_ */ diff --git a/include/mysql/impl/error_impl.hpp b/include/mysql/impl/error_impl.hpp new file mode 100644 index 00000000..be2f505f --- /dev/null +++ b/include/mysql/impl/error_impl.hpp @@ -0,0 +1,66 @@ +#ifndef INCLUDE_IMPL_ERROR_IMPL_HPP_ +#define INCLUDE_IMPL_ERROR_IMPL_HPP_ + +#include + +namespace boost +{ +namespace system +{ + +template <> +struct is_error_code_enum +{ + static constexpr bool value = true; +}; + +} // system +} // boost + + +namespace mysql +{ +namespace detail +{ + +inline const char* error_to_string(Error error) noexcept +{ + switch (error) + { + case Error::ok: return "no error"; + case Error::incomplete_message: return "The message read was incomplete (not enough bytes to fully decode it)"; + default: return ""; + } +} + +class MysqlErrorCategory : public boost::system::error_category +{ +public: + const char* name() const noexcept final override { return "mysql"; } + std::string message(int ev) const final override + { + return error_to_string(static_cast(ev)); + } +}; + +inline const boost::system::error_category& get_mysql_error_category() +{ + static MysqlErrorCategory res; + return res; +} + + +inline boost::system::error_code make_error_code(Error error) +{ + return boost::system::error_code( + static_cast(error), get_mysql_error_category() + ); +} + +} +} + + + + +#endif /* INCLUDE_IMPL_ERROR_IMPL_HPP_ */ diff --git a/include/mysql/impl/messages.hpp b/include/mysql/impl/messages.hpp new file mode 100644 index 00000000..4102d181 --- /dev/null +++ b/include/mysql/impl/messages.hpp @@ -0,0 +1,170 @@ +#ifndef MESSAGES_H_ +#define MESSAGES_H_ + +#include +#include +#include +#include +#include "mysql/impl/basic_types.hpp" +#include "mysql/impl/constants.hpp" + +namespace mysql +{ +namespace detail +{ + +// Fields +namespace fields +{ +struct packet_size : int3 {}; +struct sequence_number: int1 {}; +struct message_header : int1 {}; +struct error_code : int2 {}; +struct sql_state_marker : string_fixed<1> {}; +struct sql_state : string_fixed<5> {}; +struct error_message : string_eof {}; +} + +using packet_header = std::tuple< + fields::packet_size, + fields::sequence_number +>; + +struct OkPacket +{ + // header: int<1> header 0x00 or 0xFE the OK packet header + int_lenenc affected_rows; + int_lenenc last_insert_id; + int2 status_flags; // server_status_flags + int2 warnings; + // TODO: CLIENT_SESSION_TRACK + string_eof info; +}; + +using err_packet = std::tuple< + fields::message_header, // int<1> header 0xFF ERR packet header + fields::error_code, + fields::sql_state_marker, + fields::sql_state, + fields::error_message +>; + + +struct Handshake +{ + // int<1> protocol version Always 10 + string_null server_version; + int4 connection_id; + std::string auth_plugin_data; // merge of the two parts - not an actual field + int4 capability_falgs; // merge of the two parts - not an actual field + // string[8] auth-plugin-data-part-1 first 8 bytes of the plugin provided data (scramble) + // int<1> filler 0x00 byte, terminating the first part of a scramble + // int<2> capability_flags_1 The lower 2 bytes of the Capabilities Flags + CharacterSetLowerByte character_set; // default server a_protocol_character_set, only the lower 8-bits + int2 status_flags; // server_status_flags + // int<2> capability_flags_2 The upper 2 bytes of the Capabilities Flags + // int<1> auth_plugin_data_len + // string[10] reserved reserved. All 0s. + // $length auth-plugin-data-part-2 + string_null auth_plugin_name; +}; + +struct HandshakeResponse +{ + int4 client_flag; // capabilities + int4 max_packet_size; + CharacterSetLowerByte character_set; + // string[23] filler filler to the size of the handhshake response packet. All 0s. + string_null username; + string_lenenc auth_response; // we should set CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA + string_null database; // we should set CLIENT_CONNECT_WITH_DB + string_null client_plugin_name; // we should set CLIENT_PLUGIN_AUTH + // TODO: CLIENT_CONNECT_ATTRS +}; + + + +struct ColumnDefinition +{ + string_lenenc catalog; // always "def" + string_lenenc schema; + string_lenenc table; // virtual table + string_lenenc org_table; // physical table + string_lenenc name; // virtual column name + string_lenenc org_name; // physical column name + // int length of fixed length fields [0x0c] + int2 character_set; // TODO: enum-erize this + int4 column_length; // maximum length of the field + FieldType type; // type of the column as defined in enum_field_types + int2 flags; // Flags as defined in Column Definition Flags + int1 decimals; // max shown decimal digits. 0x00 for int/static strings; 0x1f for dynamic strings, double, float +}; + + + +struct StmtPrepare +{ + string_eof statement; +}; + +struct StmtPrepareResponseHeader +{ + // int1 status: must be 0 + int4 statement_id; + int2 num_columns; + int2 num_params; + // int1 reserved_1: must be 0 + int2 warning_count; // only if (packet_length > 12) + // TODO: int1 metadata_follows when CLIENT_OPTIONAL_RESULTSET_METADATA +}; + +using BinaryValue = std::variant< + std::int8_t, + std::int16_t, + std::int32_t, + std::int64_t, + std::uint8_t, + std::uint16_t, + std::uint32_t, + std::uint64_t, + string_lenenc, + std::nullptr_t // NULL + // TODO: double, float, dates/times +>; + +struct StmtExecute +{ + //int1 message_type: COM_STMT_EXECUTE + int4 statement_id; + int1 flags; + // int4 iteration_count: always 1 + int1 num_params; + int1 new_params_bind_flag; + std::vector param_values; // empty if !new_params_bind_flag +}; + +struct StmtExecuteResponseHeader +{ + int1 num_fields; +}; + +struct StmtFetch +{ + // int1 message_type: COM_STMT_FETCH + int4 statement_id; + int4 rows_to_fetch; +}; + +struct StmtClose +{ + int4 statement_id; +}; + + +} // detail +} // mysql + + + + +#endif /* MESSAGES_H_ */ diff --git a/include/impl/mysql_stream_impl.hpp b/include/mysql/impl/mysql_stream_impl.hpp similarity index 100% rename from include/impl/mysql_stream_impl.hpp rename to include/mysql/impl/mysql_stream_impl.hpp diff --git a/include/impl/prepared_statement_impl.hpp b/include/mysql/impl/prepared_statement_impl.hpp similarity index 100% rename from include/impl/prepared_statement_impl.hpp rename to include/mysql/impl/prepared_statement_impl.hpp diff --git a/include/message_serialization.hpp b/include/mysql/message_serialization.hpp similarity index 100% rename from include/message_serialization.hpp rename to include/mysql/message_serialization.hpp diff --git a/include/mysql_stream.hpp b/include/mysql/mysql_stream.hpp similarity index 100% rename from include/mysql_stream.hpp rename to include/mysql/mysql_stream.hpp diff --git a/include/null_bitmap.hpp b/include/mysql/null_bitmap.hpp similarity index 100% rename from include/null_bitmap.hpp rename to include/mysql/null_bitmap.hpp diff --git a/include/prepared_statement.hpp b/include/mysql/prepared_statement.hpp similarity index 100% rename from include/prepared_statement.hpp rename to include/mysql/prepared_statement.hpp diff --git a/main.cpp b/main.cpp index b9482bc3..a83c3b6f 100644 --- a/main.cpp +++ b/main.cpp @@ -1,8 +1,8 @@ #include #include #include -#include "mysql_stream.hpp" -#include "prepared_statement.hpp" +#include "mysql/mysql_stream.hpp" +#include "mysql/prepared_statement.hpp" using namespace std; using namespace boost::asio; diff --git a/src/basic_serialization.cpp b/src/basic_serialization.cpp deleted file mode 100644 index c035e148..00000000 --- a/src/basic_serialization.cpp +++ /dev/null @@ -1,80 +0,0 @@ -/* - * deserialization.cpp - * - * Created on: Jun 30, 2019 - * Author: ruben - */ - -#include -#include -#include - -using namespace std; - -void mysql::check_size(ReadIterator from, ReadIterator last, std::size_t sz) -{ - if ((last - from) < sz) - throw std::out_of_range {"Overflow"}; -} - -mysql::ReadIterator mysql::deserialize(ReadIterator from, ReadIterator last, int_lenenc& output) -{ - std::uint8_t first_byte; - from = deserialize(from, last, first_byte); - if (first_byte == 0xFC) - { - int2 value; - from = deserialize(from, last, value); - output.value = value; - } - else if (first_byte == 0xFD) - { - int3 value; - from = deserialize(from, last, value); - output.value = value.value; - } - else if (first_byte == 0xFE) - { - from = deserialize(from, last, output.value); - } - else - { - output.value = first_byte; - } - return from; -} - -mysql::ReadIterator mysql::deserialize(ReadIterator from, ReadIterator last, string_null& output) -{ - ReadIterator string_end = std::find(from, last, 0); - if (string_end == last) - throw std::out_of_range {"Overflow (null-terminated string)"}; - output.value = get_string(from, string_end-from); - return string_end + 1; // skip the null terminator -} - -void mysql::serialize(DynamicBuffer& buffer, int_lenenc value) -{ - if (value.value < 251) - { - serialize(buffer, static_cast(value.value)); - } - else if (value.value < 0x10000) - { - serialize(buffer, int1(0xfc)); - serialize(buffer, static_cast(value.value)); - } - else if (value.value < 0x1000000) - { - serialize(buffer, int1(0xfd)); - serialize(buffer, int3 {static_cast(value.value)}); - } - else - { - serialize(buffer, int1(0xfe)); - serialize(buffer, static_cast(value.value)); - } -} - - - diff --git a/test/basic_serialization.cpp b/test/basic_serialization.cpp index e1ec1c5e..c30a00e0 100644 --- a/test/basic_serialization.cpp +++ b/test/basic_serialization.cpp @@ -5,103 +5,143 @@ * Author: ruben */ -#include "basic_serialization.hpp" +#include "mysql/impl/basic_serialization.hpp" #include #include +using namespace testing; using namespace std; using namespace mysql; +using namespace mysql::detail; + +namespace +{ // Fixed size integers - -template constexpr std::size_t int_size = sizeof(T); +template constexpr std::size_t int_size = sizeof(T::value); template <> constexpr std::size_t int_size = 3; template <> constexpr std::size_t int_size = 6; -template constexpr T expected_int_value; -template <> constexpr int1 expected_int_value { 0xff }; -template <> constexpr int2 expected_int_value { 0xfeff }; -template <> constexpr int3 expected_int_value { 0xfdfeff }; -template <> constexpr int4 expected_int_value { 0xfcfdfeff }; -template <> constexpr int6 expected_int_value { 0xfafbfcfdfeff }; -template <> constexpr int8 expected_int_value { 0xf8f9fafbfcfdfeff }; +template constexpr T expected_int_value(); +template <> constexpr int1 expected_int_value() { return int1{0xff}; }; +template <> constexpr int2 expected_int_value() { return int2{0xfeff}; }; +template <> constexpr int3 expected_int_value() { return int3{0xfdfeff}; }; +template <> constexpr int4 expected_int_value() { return int4{0xfcfdfeff}; }; +template <> constexpr int6 expected_int_value() { return int6{0xfafbfcfdfeff}; }; +template <> constexpr int8 expected_int_value() { return int8{0xf8f9fafbfcfdfeff}; }; +template <> constexpr int1_signed expected_int_value() { return int1_signed{-1}; }; +template <> constexpr int2_signed expected_int_value() { return int2_signed{-0x101}; }; +template <> constexpr int4_signed expected_int_value() { return int4_signed{-0x3020101}; }; +template <> constexpr int8_signed expected_int_value() { return int8_signed{-0x0706050403020101}; }; -template constexpr auto get_int_underlying_value(T from) { return from; } -constexpr uint32_t get_int_underlying_value(int3 from) { return from.value; } -constexpr uint64_t get_int_underlying_value(int6 from) { return from.value; } +// TODO: signed integers template struct DeserializeFixedSizeInt : public ::testing::Test { uint8_t buffer [16]; + T value; + DeserializeFixedSizeInt(): buffer { 0xff, 0xfe, 0xfd, 0xfc, 0xfb, 0xfa, 0xf9, 0xf8, 0xf7 } - {}; + { + memset(&value, 1, sizeof(value)); // catch unititialized memory errors + }; }; -using FixedSizeIntTypes = ::testing::Types; +using FixedSizeIntTypes = ::testing::Types< + int1, + int2, + int3, + int4, + int6, + int8, + int1_signed, + int2_signed, + int4_signed, + int8_signed +>; TYPED_TEST_SUITE(DeserializeFixedSizeInt, FixedSizeIntTypes); TYPED_TEST(DeserializeFixedSizeInt, ExactSize_GetsValueIncrementsIterator) { - TypeParam value; - auto res = deserialize(this->buffer, this->buffer + int_size, value); - EXPECT_EQ(res, this->buffer+int_size); - EXPECT_EQ(get_int_underlying_value(value), get_int_underlying_value(expected_int_value)); + DeserializationContext ctx (this->buffer, this->buffer + int_size, 0); + + auto err = deserialize(this->value, ctx); + + EXPECT_EQ(ctx.first(), this->buffer+int_size); + EXPECT_EQ(this->value.value, expected_int_value().value); + EXPECT_EQ(err, Error::ok); } TYPED_TEST(DeserializeFixedSizeInt, ExtraSize_GetsValueIncrementsIterator) { - TypeParam value; - auto res = deserialize(this->buffer, this->buffer + int_size + 1, value); - EXPECT_EQ(res, this->buffer+int_size); - EXPECT_EQ(get_int_underlying_value(value), get_int_underlying_value(expected_int_value)); + DeserializationContext ctx (this->buffer, this->buffer + 1 + int_size, 0); + + auto err = deserialize(this->value, ctx); + + EXPECT_EQ(ctx.first(), this->buffer+int_size); + EXPECT_EQ(this->value.value, expected_int_value().value); + EXPECT_EQ(err, Error::ok); } -TYPED_TEST(DeserializeFixedSizeInt, Overflow_ThrowsOutOfRange) +TYPED_TEST(DeserializeFixedSizeInt, Overflow_ReturnsError) { - TypeParam value; - EXPECT_THROW(deserialize(this->buffer, this->buffer + int_size - 1, value), out_of_range); + DeserializationContext ctx (this->buffer, this->buffer - 1 + int_size, 0); + auto err = deserialize(this->value, ctx); + EXPECT_EQ(err, Error::incomplete_message); } // Length-encoded integer -struct LengthEncodedIntTestParams +struct DeserializeLengthEncodedIntParams { uint8_t first_byte; uint64_t expected; size_t buffer_size; }; -struct DeserializeLengthEncodedInt : public ::testing::TestWithParam {}; +struct DeserializeLengthEncodedInt : public ::testing::TestWithParam +{ + uint8_t buffer [10]; + int_lenenc value; + int_lenenc initial_value; + + DeserializeLengthEncodedInt(): + buffer { GetParam().first_byte, 0xff, 0xfe, 0xfd, 0xfc, 0xfb, 0xfa, 0xf9, 0xf8 } + { + memset(&value, 1, sizeof(value)); + initial_value = value; + } +}; TEST_P(DeserializeLengthEncodedInt, ExactSize_GetsValueIncrementsIterator) { - uint8_t buffer [10] = { GetParam().first_byte, 0xff, 0xfe, 0xfd, 0xfc, 0xfb, 0xfa, 0xf9, 0xf8 }; - int_lenenc value; - auto it = deserialize(buffer, buffer + GetParam().buffer_size , value); - EXPECT_EQ(it, buffer + GetParam().buffer_size); + DeserializationContext ctx (buffer, buffer + GetParam().buffer_size, 0); + auto err = deserialize(value, ctx); + EXPECT_EQ(ctx.first(), buffer + GetParam().buffer_size); EXPECT_EQ(value.value, GetParam().expected); + EXPECT_EQ(err, Error::ok); } TEST_P(DeserializeLengthEncodedInt, ExtraSize_GetsValueIncrementsIterator) { - uint8_t buffer [10] = { GetParam().first_byte, 0xff, 0xfe, 0xfd, 0xfc, 0xfb, 0xfa, 0xf9, 0xf8, 0xf7 }; - int_lenenc value; - auto it = deserialize(buffer, end(buffer), value); - EXPECT_EQ(it, buffer + GetParam().buffer_size); + DeserializationContext ctx (buffer, end(buffer), 0); + auto err = deserialize(value, ctx); + EXPECT_EQ(ctx.first(), buffer + GetParam().buffer_size); EXPECT_EQ(value.value, GetParam().expected); + EXPECT_EQ(err, Error::ok); } -TEST_P(DeserializeLengthEncodedInt, Overflow_ThrowsOutOfRange) +TEST_P(DeserializeLengthEncodedInt, Overflow_ReturnsError) { - uint8_t buffer [10] = { GetParam().first_byte, 0xff, 0xfe, 0xfd, 0xfc, 0xfb, 0xfa, 0xf9, 0xf8, 0xf7 }; - int_lenenc value; - EXPECT_THROW(deserialize(buffer, buffer + GetParam().buffer_size - 1, value), out_of_range); + DeserializationContext ctx (buffer, buffer + GetParam().buffer_size - 1, 0); + auto err = deserialize(value, ctx); + EXPECT_EQ(err, Error::incomplete_message); } INSTANTIATE_TEST_SUITE_P(Default, DeserializeLengthEncodedInt, ::testing::Values( - LengthEncodedIntTestParams{0x0a, 0x0a, 1}, - LengthEncodedIntTestParams{0xfc, 0xfeff, 3}, - LengthEncodedIntTestParams{0xfd, 0xfdfeff, 4}, - LengthEncodedIntTestParams{0xfe, 0xf8f9fafbfcfdfeff, 9} + DeserializeLengthEncodedIntParams{0x0a, 0x0a, 1}, + DeserializeLengthEncodedIntParams{0xfc, 0xfeff, 3}, + DeserializeLengthEncodedIntParams{0xfd, 0xfdfeff, 4}, + DeserializeLengthEncodedIntParams{0xfe, 0xf8f9fafbfcfdfeff, 9} ), [](const auto& v) { return "first_byte_" + to_string(v.param.first_byte); }); // Fixed size string @@ -109,29 +149,42 @@ struct DeserializeFixedSizeString : public testing::Test { uint8_t buffer [6] { 'a', 'b', '\0', 'd', 'e', 'f' }; string_fixed<5> value; + + DeserializeFixedSizeString() + { + memset(value.value.data(), 1, value.value.size()); + } + + string_view value_as_view() const { return string_view(value.value.data(), value.value.size()); } }; TEST_F(DeserializeFixedSizeString, ExactSize_CopiesValueIncrementsIterator) { - ReadIterator res = deserialize(begin(buffer), begin(buffer) + 5, value); - EXPECT_EQ(value, string_view {"ab\0de"}); - EXPECT_EQ(res, begin(buffer) + 5); + DeserializationContext ctx (begin(buffer), begin(buffer) + 5, 0); + auto err = deserialize(value, ctx); + EXPECT_EQ(ctx.first(), begin(buffer) + 5); + EXPECT_EQ(value_as_view(), string_view("ab\0de", 5)); + EXPECT_EQ(err, Error::ok); } TEST_F(DeserializeFixedSizeString, ExtraSize_CopiesValueIncrementsIterator) { - ReadIterator res = deserialize(begin(buffer), end(buffer), value); - EXPECT_EQ(value, string_view {"ab\0de"}); - EXPECT_EQ(res, begin(buffer) + 5); + DeserializationContext ctx (begin(buffer), end(buffer), 0); + auto err = deserialize(value, ctx); + EXPECT_EQ(ctx.first(), begin(buffer) + 5); + EXPECT_EQ(value_as_view(), string_view("ab\0de", 5)); + EXPECT_EQ(err, Error::ok); } -TEST_F(DeserializeFixedSizeString, Overflow_ThrowsOutOfRange) +TEST_F(DeserializeFixedSizeString, Overflow_ReturnsError) { - EXPECT_THROW(deserialize(begin(buffer), begin(buffer) + 4, value), out_of_range); + DeserializationContext ctx (begin(buffer), begin(buffer) + 4, 0); + auto err = deserialize(value, ctx); + EXPECT_EQ(err, Error::incomplete_message); } // Null-terminated string -struct DeserializeNullTerminatedString : public testing::Test +/*struct DeserializeNullTerminatedString : public testing::Test { uint8_t buffer [4] { 'a', 'b', '\0', 'd' }; string_null value; @@ -264,4 +317,7 @@ TEST_F(DeserializeEnum, ExtraSize_GetsValueIncrementsIterator) TEST_F(DeserializeEnum, Overflow_ThrowsOutOfRange) { EXPECT_THROW(deserialize(begin(buffer), begin(buffer) + 1, value), out_of_range); -} +}*/ + + +} // anon namespace