diff --git a/CMakeLists.txt b/CMakeLists.txt index 34039e90..8e8960cb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,6 +15,7 @@ add_library( src/message_serialization.cpp src/mysql_stream.cpp src/auth.cpp + src/prepared_statement.cpp ) target_link_libraries( mysql_asio diff --git a/include/basic_serialization.hpp b/include/basic_serialization.hpp index 565a6816..d0168ec2 100644 --- a/include/basic_serialization.hpp +++ b/include/basic_serialization.hpp @@ -88,6 +88,12 @@ deserialize(ReadIterator from, ReadIterator last, T& to) return res; } +template +ReadIterator deserialize(const std::vector& from, T& to) +{ + return deserialize(from.data(), from.data() + from.size(), to); +} + // SERIALIZATION class DynamicBuffer @@ -105,6 +111,7 @@ public: const void* data() const { return buffer_.data(); } void* data() { return buffer_.data(); } std::size_t size() const { return buffer_.size(); } + const std::vector& get() const { return buffer_; } }; template void native_to_little(T& value) { boost::endian::native_to_little_inplace(value); } diff --git a/include/message_serialization.hpp b/include/message_serialization.hpp index c42eb0d9..a0dfbcec 100644 --- a/include/message_serialization.hpp +++ b/include/message_serialization.hpp @@ -11,6 +11,7 @@ namespace mysql // general ReadIterator deserialize(ReadIterator from, ReadIterator last, PacketHeader& output); +void serialize(DynamicBuffer& buffer, const PacketHeader& value); ReadIterator deserialize(ReadIterator from, ReadIterator last, OkPacket& output); ReadIterator deserialize(ReadIterator from, ReadIterator last, ErrPacket& output); diff --git a/include/mysql_stream.hpp b/include/mysql_stream.hpp index a8090e55..76b8ffee 100644 --- a/include/mysql_stream.hpp +++ b/include/mysql_stream.hpp @@ -3,6 +3,7 @@ #include #include +#include #include "basic_types.hpp" #include "messages.hpp" #include "message_serialization.hpp" @@ -10,6 +11,7 @@ namespace mysql { + struct HandshakeParams { CharacterSetLowerByte character_set; @@ -18,51 +20,7 @@ struct HandshakeParams std::string_view database; }; -class AnyPacketReader -{ - PacketHeader header_; - int1 message_type_; - std::unique_ptr buffer_; -public: - const PacketHeader& header() const { return header_; } - int1 message_type() const { return message_type_; } - ReadIterator begin() const { return buffer_.get() + 1; } // past message type, if any - ReadIterator end() const { return buffer_.get() + header_.packet_size.value; } - bool is_ok() const { return message_type() == mysql::ok_packet_header; } - bool is_error() const { return message_type() == error_packet_header; } - bool is_eof() const { return message_type() == eof_packet_header; } - void check_error() const; - - template - void deserialize_message(MessageType& output) const - { - ReadIterator last = mysql::deserialize(begin(), end(), output); - if (last != end()) - throw std::out_of_range { "Additional data after packet end" }; - check_error(); - } - - void read(boost::asio::ip::tcp::socket& stream); -}; - -class AnyPacketWriter -{ - DynamicBuffer buffer_; -public: - AnyPacketWriter(int1 sequence_number); - DynamicBuffer& buffer() { return buffer_; } - const DynamicBuffer& buffer() const { return buffer_; } - void set_length(); - - template - void serialize_message(const MessageType& value) - { - mysql::serialize(buffer_, value); - set_length(); - } - - void write(boost::asio::ip::tcp::socket& stream); -}; +int1 get_message_type(const std::vector& buffer, bool check_err=true); class MysqlStream { @@ -70,12 +28,14 @@ class MysqlStream int1 sequence_number_ {0}; void process_sequence_number(int1 got); - void read_packet(AnyPacketReader& output); public: MysqlStream(boost::asio::io_context& ctx): next_layer_ {ctx} {}; auto& next_layer() { return next_layer_; } void handshake(const HandshakeParams&); + void read(std::vector& buffer); + void write(const std::vector& buffer); + void reset_sequence_number() { sequence_number_ = 0; } }; } diff --git a/include/prepared_statement.hpp b/include/prepared_statement.hpp new file mode 100644 index 00000000..a44f8b1e --- /dev/null +++ b/include/prepared_statement.hpp @@ -0,0 +1,48 @@ +#ifndef INCLUDE_PREPARED_STATEMENT_HPP_ +#define INCLUDE_PREPARED_STATEMENT_HPP_ + +#include "messages.hpp" +#include "mysql_stream.hpp" + +namespace mysql +{ + +struct ParamDefinition +{ + std::vector packet; + ColumnDefinition value; +}; + +class PreparedStatement +{ + MysqlStream* stream_; + int4 statement_id_; + std::vector params_; + std::vector columns_; +public: + PreparedStatement(MysqlStream& stream, int4 statement_id, + std::vector&& params, std::vector&& columns): + stream_ {&stream}, statement_id_ {statement_id}, params_{std::move(params)}, + columns_ {std::move(columns)} {}; + PreparedStatement(const PreparedStatement&) = delete; + PreparedStatement(PreparedStatement&&) = default; + const PreparedStatement& operator=(const PreparedStatement&) = delete; + PreparedStatement& operator=(PreparedStatement&&) = default; + ~PreparedStatement() = default; + + int4 statement_id() const { return statement_id_; } + const std::vector& params() const { return params_; } + const std::vector& columns() const { return columns_; } + // execute(Something&&... params): StatementResponse + // execute() // uses already-bound params + // close(Connection) + // Destructor should try to auto-close + + static PreparedStatement prepare(MysqlStream& stream, std::string_view query); +}; + +} + + + +#endif /* INCLUDE_PREPARED_STATEMENT_HPP_ */ diff --git a/main.cpp b/main.cpp index 3dc609e4..8a8badd0 100644 --- a/main.cpp +++ b/main.cpp @@ -2,6 +2,7 @@ #include #include #include "mysql_stream.hpp" +#include "prepared_statement.hpp" using namespace std; using namespace boost::asio; @@ -42,5 +43,10 @@ int main() "mysql" }); + // Prepare a statement + + mysql::PreparedStatement stmt { mysql::PreparedStatement::prepare( + stream, "SELECT host FROM user WHERE user = ?") }; + } diff --git a/src/message_serialization.cpp b/src/message_serialization.cpp index afbe57b9..edf36b91 100644 --- a/src/message_serialization.cpp +++ b/src/message_serialization.cpp @@ -19,6 +19,12 @@ mysql::ReadIterator mysql::deserialize(ReadIterator from, ReadIterator last, Pac return from; } +void mysql::serialize(DynamicBuffer& buffer, const PacketHeader& value) +{ + serialize(buffer, value.packet_size); + serialize(buffer, value.sequence_number); +} + mysql::ReadIterator mysql::deserialize(ReadIterator from, ReadIterator last, OkPacket& output) { // TODO: is packet header to be deserialized as part of this? @@ -115,6 +121,7 @@ mysql::ReadIterator mysql::deserialize(ReadIterator from, ReadIterator last, Stm from = deserialize(from, last, output.num_columns); from = deserialize(from, last, output.num_params); from = deserialize(from, last, reserved); + // TODO: warning_count appears to be optional but we are always requiring it from = deserialize(from, last, output.warning_count); return from; } diff --git a/src/mysql_stream.cpp b/src/mysql_stream.cpp index 8fe23ab4..e7c3905f 100644 --- a/src/mysql_stream.cpp +++ b/src/mysql_stream.cpp @@ -50,51 +50,22 @@ constexpr mysql::int4 BASIC_CAPABILITIES_FLAGS = mysql::CLIENT_DEPRECATE_EOF | mysql::CLIENT_CONNECT_WITH_DB; -void mysql::AnyPacketReader::read(boost::asio::ip::tcp::socket& stream) +mysql::int1 mysql::get_message_type(const std::vector& buffer, bool check_err) { - // Connection phase - uint8_t header_buffer [4]; - boost::asio::read(stream, boost::asio::buffer(header_buffer)); - deserialize(header_buffer, std::end(header_buffer), header_); - - // Read the rest of the packet - buffer_.reset(new uint8_t[header_.packet_size.value]); - boost::asio::read(stream, boost::asio::mutable_buffer{buffer_.get(), header_.packet_size.value}); - - deserialize(buffer_.get(), end(), message_type_); -} - -void mysql::AnyPacketReader::check_error() const -{ - if (is_error()) + mysql::int1 res; + ReadIterator current = mysql::deserialize(buffer, res); + if (check_err && res == error_packet_header) { ErrPacket error_packet; - deserialize_message(error_packet); + deserialize(current, buffer.data() + buffer.size(), error_packet); std::ostringstream ss; ss << "SQL error: " << error_packet.error_message.value << " (" << error_packet.error_code << ")"; throw std::runtime_error { ss.str() }; } + return res; } -mysql::AnyPacketWriter::AnyPacketWriter(int1 sequence_number) -{ - std::uint8_t initial_content [] = {0, 0, 0, sequence_number}; - buffer_.add(initial_content, sizeof(initial_content)); -} - -void mysql::AnyPacketWriter::set_length() -{ - assert(buffer_.size() <= 0xffffff); // TODO: handle the case where this does not hold - int3 packet_length { static_cast(buffer_.size() - 4) }; - boost::endian::native_to_little_inplace(packet_length.value); - memcpy(buffer_.data(), &packet_length.value, 3); -} - -void mysql::AnyPacketWriter::write(boost::asio::ip::tcp::socket& stream) -{ - boost::asio::write(stream, boost::asio::buffer(buffer_.data(), buffer_.size())); -} void mysql::MysqlStream::process_sequence_number(int1 got) { @@ -102,28 +73,61 @@ void mysql::MysqlStream::process_sequence_number(int1 got) throw std::runtime_error { "Mismatched sequence number" }; } -void mysql::MysqlStream::read_packet(AnyPacketReader& packet) +void mysql::MysqlStream::read(std::vector& buffer) { - packet.read(next_layer_); - process_sequence_number(packet.header().sequence_number); + PacketHeader header; + uint8_t header_buffer [4]; + std::size_t current_size = 0; + do + { + boost::asio::read(next_layer_, boost::asio::buffer(header_buffer)); + deserialize(std::begin(header_buffer), std::end(header_buffer), header); + process_sequence_number(header.sequence_number); + buffer.resize(current_size + header.packet_size.value); + boost::asio::read(next_layer_, boost::asio::buffer(buffer.data() + current_size, header.packet_size.value)); + current_size += header.packet_size.value; + } while (header.packet_size.value == 0xffffff); +} + +void mysql::MysqlStream::write(const std::vector& buffer) +{ + PacketHeader header; + DynamicBuffer header_buffer; // TODO: change to a plain uint8_t when we generalize this + std::size_t current_size = 0; + while (current_size < buffer.size()) + { + auto size_to_write = static_cast(std::min( + std::vector::size_type(0xffffff), + buffer.size() - current_size + )); + header.packet_size.value = size_to_write; + header.sequence_number = sequence_number_++; + serialize(header_buffer, header); + boost::asio::write(next_layer_, boost::asio::buffer(header_buffer.data(), header_buffer.size())); + boost::asio::write(next_layer_, boost::asio::buffer(buffer.data() + current_size, size_to_write)); + current_size += size_to_write; + } } void mysql::MysqlStream::handshake(const HandshakeParams& params) { - AnyPacketReader reader; - read_packet(reader); + std::vector read_buffer; + DynamicBuffer write_buffer; - if (reader.message_type() != handshake_protocol_version_10) + // Read handshake + read(read_buffer); + auto msg_type = get_message_type(read_buffer); + if (msg_type != handshake_protocol_version_10) { - const char* reason = reader.message_type() == handshake_protocol_version_9 ? + const char* reason = msg_type == handshake_protocol_version_9 ? "Unsupported protocol version 9" : "Unknown message type"; throw std::runtime_error {reason}; } + mysql::Handshake handshake; + deserialize(read_buffer.data()+1, read_buffer.data() + read_buffer.size(), handshake); // Process the handshake - mysql::Handshake handshake; - reader.deserialize_message(handshake); check_capabilities(handshake.capability_falgs); check_authentication_method(handshake); cout << handshake << "\n\n"; @@ -142,16 +146,16 @@ void mysql::MysqlStream::handshake(const HandshakeParams& params) cout << handshake_response << "\n\n"; // Serialize and send - AnyPacketWriter writer { sequence_number_++ }; - writer.serialize_message(handshake_response); - writer.write(next_layer_); + serialize(write_buffer, handshake_response); + write(write_buffer.get()); // TODO: support auth mismatch // TODO: support SSL // Read the OK/ERR - read_packet(reader); - if (!reader.is_ok() && !reader.is_eof()) + read(read_buffer); + msg_type = get_message_type(read_buffer); + if (msg_type != ok_packet_header && msg_type != eof_packet_header) { throw std::runtime_error { "Unknown message type" }; } diff --git a/src/prepared_statement.cpp b/src/prepared_statement.cpp new file mode 100644 index 00000000..3e0eb24e --- /dev/null +++ b/src/prepared_statement.cpp @@ -0,0 +1,40 @@ + +#include "prepared_statement.hpp" +#include "message_serialization.hpp" + +using namespace std; + +static void read_param(mysql::MysqlStream& stream, mysql::ParamDefinition& output) +{ + stream.read(output.packet); + mysql::deserialize(output.packet, output.value); +} + +mysql::PreparedStatement mysql::PreparedStatement::prepare(MysqlStream& stream, std::string_view query) +{ + // Write the prepare request + StmtPrepare packet {{query}}; + DynamicBuffer write_buffer; + serialize(write_buffer, packet); + stream.reset_sequence_number(); + stream.write(write_buffer.get()); + + // Get the prepare response + std::vector read_buffer; + stream.read(read_buffer); + int1 status = get_message_type(read_buffer); + if (status != 0) + throw std::runtime_error {"Error preparing statement" + std::string{query}}; + StmtPrepareResponseHeader response; + deserialize(read_buffer.data() + 1, read_buffer.data() + read_buffer.size(), response); + + // Read the parameters and columns if any + std::vector params (response.num_params); + for (int2 i = 0; i < response.num_params; ++i) + read_param(stream, params[i]); + std::vector columns (response.num_columns); + for (int2 i = 0; i < response.num_columns; ++i) + read_param(stream, columns[i]); + + return PreparedStatement {stream, response.statement_id, move(params), move(columns)}; +}