diff --git a/CMakeLists.txt b/CMakeLists.txt index 605edeaa..74672eb9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,7 +11,8 @@ find_package(OpenSSL REQUIRED) add_library( mysql_asio SHARED - src/deserialization.cpp + src/basic_serialization.cpp + src/message_serialization.cpp src/auth.cpp ) target_link_libraries( @@ -45,7 +46,7 @@ find_package(GTest REQUIRED) enable_testing() add_executable( unittests - test/deserialization.cpp + test/basic_serialization.cpp ) target_link_libraries( unittests diff --git a/include/deserialization.hpp b/include/basic_serialization.hpp similarity index 84% rename from include/deserialization.hpp rename to include/basic_serialization.hpp index 474daddc..1cb46d00 100644 --- a/include/deserialization.hpp +++ b/include/basic_serialization.hpp @@ -5,7 +5,6 @@ #include #include #include "basic_types.hpp" -#include "messages.hpp" namespace mysql { @@ -32,8 +31,8 @@ template constexpr std::size_t get_size_v = get_size::value; template constexpr bool is_fixed_size_v = get_size_v != std::size_t(-1); template void little_to_native(T& value) { boost::endian::little_to_native_inplace(value); } -template <> void little_to_native(int3& value) { boost::endian::little_to_native_inplace(value.value); } -template <> void little_to_native(int6& value) { boost::endian::little_to_native_inplace(value.value); } +template <> inline void little_to_native(int3& value) { boost::endian::little_to_native_inplace(value.value); } +template <> inline void little_to_native(int6& value) { boost::endian::little_to_native_inplace(value.value); } template void little_to_native(string_fixed&) {} // Deserialization functions @@ -113,8 +112,8 @@ public: }; template void native_to_little(T& value) { boost::endian::native_to_little_inplace(value); } -template <> void native_to_little(int3& value) { boost::endian::native_to_little_inplace(value.value); } -template <> void native_to_little(int6& value) { boost::endian::native_to_little_inplace(value.value); } +template <> inline void native_to_little(int3& value) { boost::endian::native_to_little_inplace(value.value); } +template <> inline void native_to_little(int6& value) { boost::endian::native_to_little_inplace(value.value); } @@ -160,12 +159,7 @@ inline void serialize(DynamicBuffer& buffer, const string_lenenc& value) } -// Packet serialization and deserialization -ReadIterator deserialize(ReadIterator from, ReadIterator last, PacketHeader& output); -ReadIterator deserialize(ReadIterator from, ReadIterator last, OkPacket& output); -ReadIterator deserialize(ReadIterator from, ReadIterator last, ErrPacket& output); -ReadIterator deserialize(ReadIterator from, ReadIterator last, Handshake& output); -void serialize(DynamicBuffer& buffer, const HandshakeResponse& value); + } diff --git a/include/message_serialization.hpp b/include/message_serialization.hpp new file mode 100644 index 00000000..da1f6894 --- /dev/null +++ b/include/message_serialization.hpp @@ -0,0 +1,28 @@ +#ifndef INCLUDE_MESSAGE_SERIALIZATION_HPP_ +#define INCLUDE_MESSAGE_SERIALIZATION_HPP_ + +#include "basic_types.hpp" +#include "messages.hpp" +#include "basic_serialization.hpp" + +namespace mysql +{ + +// general +ReadIterator deserialize(ReadIterator from, ReadIterator last, PacketHeader& output); +ReadIterator deserialize(ReadIterator from, ReadIterator last, OkPacket& output); +ReadIterator deserialize(ReadIterator from, ReadIterator last, ErrPacket& output); + +// Connection phase +ReadIterator deserialize(ReadIterator from, ReadIterator last, Handshake& output); +void serialize(DynamicBuffer& buffer, const HandshakeResponse& value); + +// Prepared statements +void serialize(DynamicBuffer& buffer, const StmtPreparePacket& value); +ReadIterator deserialize(ReadIterator from, ReadIterator last, StmtPrepareResponsePacket& output); + +} + + + +#endif /* INCLUDE_MESSAGE_SERIALIZATION_HPP_ */ diff --git a/include/messages.hpp b/include/messages.hpp index cd6b2c6d..a7673c8f 100644 --- a/include/messages.hpp +++ b/include/messages.hpp @@ -2,6 +2,7 @@ #define MESSAGES_H_ #include +#include #include "basic_types.hpp" namespace mysql @@ -124,6 +125,93 @@ struct HandshakeResponse // TODO: CLIENT_CONNECT_ATTRS }; +enum class Command +{ + COM_QUIT = 1, + COM_INIT_DB = 2, + COM_QUERY = 3, + COM_STATISTICS = 8, + COM_DEBUG = 0x0d, + COM_PING = 0x0e, + COM_CHANGE_USER = 0x11, + COM_BINLOG_DUMP = 0x12, + COM_STMT_PREPARE = 0x16, + COM_STMT_EXECUTE = 0x17, + COM_STMT_SEND_LONG_DATA = 0x18, + COM_STMT_CLOSE = 0x19, + COM_STMT_RESET = 0x1a, + COM_SET_OPTION = 0x1b, + COM_STMT_FETCH = 0x1c, + COM_RESET_CONNECTION = 0x1f +}; + +// Column definitions +enum class FieldType : int1 +{ + DECIMAL = 0x00, + TINY = 0x01, + SHORT = 0x02, + LONG = 0x03, + FLOAT = 0x04, + DOUBLE = 0x05, + NULL_ = 0x06, + TIMESTAMP = 0x07, + LONGLONG = 0x08, + INT24 = 0x09, + DATE = 0x0a, + TIME = 0x0b, + DATETIME = 0x0c, + YEAR = 0x0d, + VARCHAR = 0x0f, + BIT = 0x10, + NEWDECIMAL = 0xf6, + ENUM = 0xf7, + SET = 0xf8, + TINY_BLOB = 0xf9, + MEDIUM_BLOB = 0xfa, + LONG_BLOB = 0xfb, + BLOB = 0xfc, + VAR_STRING = 0xfd, + STRING = 0xfe, + GEOMETRY = 0xff +}; + +struct ColumnDefinition +{ + string_lenenc catalog; // always "def" + string_lenenc schema; + string_lenenc table; // virtual table + string_lenenc org_table; // physical table + string_lenenc name; // virtual column name + string_lenenc org_name; // physical column name + // int length of fixed length fields [0x0c] + int2 character_set; // TODO: enum-erize this + int4 column_length; // maximum length of the field + FieldType type; // type of the column as defined in enum_field_types + int2 flags; // Flags as defined in Column Definition Flags + int1 decimals; // max shown decimal digits. 0x00 for int/static strings; 0x1f for dynamic strings, double, float +}; + +// Prepared statements +struct StmtPreparePacket +{ + string_eof statement; +}; + +struct StmtPrepareResponsePacket +{ + // int1 status: must be 0 + int4 statement_id; + int2 num_columns; + int2 num_params; + // int1 reserved_1: must be 0 + int2 warning_count; // only if (packet_length > 12) + // TODO: int1 metadata_follows when CLIENT_OPTIONAL_RESULTSET_METADATA + std::vector params; +}; + + + } diff --git a/main.cpp b/main.cpp index ab09c58f..dbb91da9 100644 --- a/main.cpp +++ b/main.cpp @@ -11,8 +11,8 @@ #include #include #include "basic_types.hpp" -#include "deserialization.hpp" #include "auth.hpp" +#include "message_serialization.hpp" using namespace std; using namespace boost::asio; diff --git a/src/basic_serialization.cpp b/src/basic_serialization.cpp new file mode 100644 index 00000000..c035e148 --- /dev/null +++ b/src/basic_serialization.cpp @@ -0,0 +1,80 @@ +/* + * deserialization.cpp + * + * Created on: Jun 30, 2019 + * Author: ruben + */ + +#include +#include +#include + +using namespace std; + +void mysql::check_size(ReadIterator from, ReadIterator last, std::size_t sz) +{ + if ((last - from) < sz) + throw std::out_of_range {"Overflow"}; +} + +mysql::ReadIterator mysql::deserialize(ReadIterator from, ReadIterator last, int_lenenc& output) +{ + std::uint8_t first_byte; + from = deserialize(from, last, first_byte); + if (first_byte == 0xFC) + { + int2 value; + from = deserialize(from, last, value); + output.value = value; + } + else if (first_byte == 0xFD) + { + int3 value; + from = deserialize(from, last, value); + output.value = value.value; + } + else if (first_byte == 0xFE) + { + from = deserialize(from, last, output.value); + } + else + { + output.value = first_byte; + } + return from; +} + +mysql::ReadIterator mysql::deserialize(ReadIterator from, ReadIterator last, string_null& output) +{ + ReadIterator string_end = std::find(from, last, 0); + if (string_end == last) + throw std::out_of_range {"Overflow (null-terminated string)"}; + output.value = get_string(from, string_end-from); + return string_end + 1; // skip the null terminator +} + +void mysql::serialize(DynamicBuffer& buffer, int_lenenc value) +{ + if (value.value < 251) + { + serialize(buffer, static_cast(value.value)); + } + else if (value.value < 0x10000) + { + serialize(buffer, int1(0xfc)); + serialize(buffer, static_cast(value.value)); + } + else if (value.value < 0x1000000) + { + serialize(buffer, int1(0xfd)); + serialize(buffer, int3 {static_cast(value.value)}); + } + else + { + serialize(buffer, int1(0xfe)); + serialize(buffer, static_cast(value.value)); + } +} + + + diff --git a/src/deserialization.cpp b/src/message_serialization.cpp similarity index 64% rename from src/deserialization.cpp rename to src/message_serialization.cpp index e69344cf..d14f78ad 100644 --- a/src/deserialization.cpp +++ b/src/message_serialization.cpp @@ -1,81 +1,14 @@ /* - * deserialization.cpp + * message_serialization.cpp * - * Created on: Jun 30, 2019 + * Created on: Jul 7, 2019 * Author: ruben */ -#include "deserialization.hpp" -#include -#include +#include "message_serialization.hpp" using namespace std; -void mysql::check_size(ReadIterator from, ReadIterator last, std::size_t sz) -{ - if ((last - from) < sz) - throw std::out_of_range {"Overflow"}; -} - -mysql::ReadIterator mysql::deserialize(ReadIterator from, ReadIterator last, int_lenenc& output) -{ - std::uint8_t first_byte; - from = deserialize(from, last, first_byte); - if (first_byte == 0xFC) - { - int2 value; - from = deserialize(from, last, value); - output.value = value; - } - else if (first_byte == 0xFD) - { - int3 value; - from = deserialize(from, last, value); - output.value = value.value; - } - else if (first_byte == 0xFE) - { - from = deserialize(from, last, output.value); - } - else - { - output.value = first_byte; - } - return from; -} - -mysql::ReadIterator mysql::deserialize(ReadIterator from, ReadIterator last, string_null& output) -{ - ReadIterator string_end = std::find(from, last, 0); - if (string_end == last) - throw std::out_of_range {"Overflow (null-terminated string)"}; - output.value = get_string(from, string_end-from); - return string_end + 1; // skip the null terminator -} - -void mysql::serialize(DynamicBuffer& buffer, int_lenenc value) -{ - if (value.value < 251) - { - serialize(buffer, static_cast(value.value)); - } - else if (value.value < 0x10000) - { - serialize(buffer, int1(0xfc)); - serialize(buffer, static_cast(value.value)); - } - else if (value.value < 0x1000000) - { - serialize(buffer, int1(0xfd)); - serialize(buffer, int3 {static_cast(value.value)}); - } - else - { - serialize(buffer, int1(0xfe)); - serialize(buffer, static_cast(value.value)); - } -} - // Packet serialization and deserialization mysql::ReadIterator mysql::deserialize(ReadIterator from, ReadIterator last, PacketHeader& output) { @@ -146,3 +79,4 @@ void mysql::serialize(DynamicBuffer& buffer, const HandshakeResponse& value) } + diff --git a/test/deserialization.cpp b/test/basic_serialization.cpp similarity index 99% rename from test/deserialization.cpp rename to test/basic_serialization.cpp index e0ef343c..e1ec1c5e 100644 --- a/test/deserialization.cpp +++ b/test/basic_serialization.cpp @@ -5,7 +5,7 @@ * Author: ruben */ -#include "deserialization.hpp" +#include "basic_serialization.hpp" #include #include