mirror of
https://github.com/boostorg/mysql.git
synced 2026-02-20 02:42:26 +00:00
Made MysqlStream, PreparedStatement be templates
Made MysqlStream::write a template accepting any ConstBufferSequence
This commit is contained in:
192
include/impl/mysql_stream_impl.hpp
Normal file
192
include/impl/mysql_stream_impl.hpp
Normal file
@@ -0,0 +1,192 @@
|
||||
#ifndef INCLUDE_IMPL_MYSQL_STREAM_IMPL_HPP_
|
||||
#define INCLUDE_IMPL_MYSQL_STREAM_IMPL_HPP_
|
||||
|
||||
#include "message_serialization.hpp"
|
||||
#include "auth.hpp"
|
||||
#include <boost/asio/read.hpp>
|
||||
#include <boost/asio/write.hpp>
|
||||
#include <iostream>
|
||||
#include <cassert>
|
||||
#include <sstream>
|
||||
|
||||
namespace mysql
|
||||
{
|
||||
namespace detail
|
||||
{
|
||||
|
||||
template <typename T1, typename... Flags>
|
||||
constexpr bool all_set(T1 input, Flags... flags)
|
||||
{
|
||||
return ((input & flags) && ...);
|
||||
}
|
||||
|
||||
inline void check_capabilities(mysql::int4 server_capabilities)
|
||||
{
|
||||
bool ok = all_set(server_capabilities,
|
||||
mysql::CLIENT_CONNECT_WITH_DB,
|
||||
mysql::CLIENT_PROTOCOL_41,
|
||||
mysql::CLIENT_PLUGIN_AUTH,
|
||||
mysql::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA,
|
||||
mysql::CLIENT_DEPRECATE_EOF
|
||||
);
|
||||
if (!ok)
|
||||
throw std::runtime_error { "Missing server capabilities, server not supported" };
|
||||
}
|
||||
|
||||
inline void check_authentication_method(const mysql::Handshake& handshake)
|
||||
{
|
||||
if (handshake.auth_plugin_name.value != "mysql_native_password")
|
||||
throw std::runtime_error { "Unsupported authentication method" }; // TODO: we should be responding with our method
|
||||
if (handshake.auth_plugin_data.size() != mysql::mysql_native_password::challenge_length)
|
||||
throw std::runtime_error { "Bad authentication data length" };
|
||||
}
|
||||
|
||||
constexpr mysql::int4 BASIC_CAPABILITIES_FLAGS =
|
||||
mysql::CLIENT_PROTOCOL_41 |
|
||||
mysql::CLIENT_PLUGIN_AUTH |
|
||||
mysql::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA |
|
||||
mysql::CLIENT_DEPRECATE_EOF |
|
||||
mysql::CLIENT_CONNECT_WITH_DB;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
inline mysql::int1 mysql::get_message_type(const std::vector<std::uint8_t>& buffer, bool check_err)
|
||||
{
|
||||
mysql::int1 res;
|
||||
ReadIterator current = mysql::deserialize(buffer, res);
|
||||
if (check_err && res == error_packet_header)
|
||||
{
|
||||
ErrPacket error_packet;
|
||||
deserialize(current, buffer.data() + buffer.size(), error_packet);
|
||||
std::ostringstream ss;
|
||||
ss << "SQL error: " << error_packet.error_message.value
|
||||
<< " (" << error_packet.error_code << ")";
|
||||
throw std::runtime_error { ss.str() };
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename AsyncStream>
|
||||
void mysql::MysqlStream<AsyncStream>::process_sequence_number(int1 got)
|
||||
{
|
||||
if (got != sequence_number_++)
|
||||
throw std::runtime_error { "Mismatched sequence number" };
|
||||
}
|
||||
|
||||
template <typename AsyncStream>
|
||||
void mysql::MysqlStream<AsyncStream>::read(std::vector<std::uint8_t>& buffer)
|
||||
{
|
||||
PacketHeader header;
|
||||
uint8_t header_buffer [4];
|
||||
std::size_t current_size = 0;
|
||||
do
|
||||
{
|
||||
boost::asio::read(next_layer_, boost::asio::buffer(header_buffer));
|
||||
deserialize(std::begin(header_buffer), std::end(header_buffer), header);
|
||||
process_sequence_number(header.sequence_number);
|
||||
buffer.resize(current_size + header.packet_size.value);
|
||||
boost::asio::read(next_layer_, boost::asio::buffer(buffer.data() + current_size, header.packet_size.value));
|
||||
current_size += header.packet_size.value;
|
||||
} while (header.packet_size.value == 0xffffff);
|
||||
}
|
||||
|
||||
template <typename AsyncStream>
|
||||
void mysql::MysqlStream<AsyncStream>::write(const std::vector<std::uint8_t>& buffer)
|
||||
{
|
||||
write(boost::asio::buffer(buffer.data(), buffer.size()));
|
||||
}
|
||||
|
||||
template <typename AsyncStream>
|
||||
template <typename ConstBufferSequence>
|
||||
void mysql::MysqlStream<AsyncStream>::write(ConstBufferSequence&& buffers)
|
||||
{
|
||||
PacketHeader header;
|
||||
DynamicBuffer header_buffer; // TODO: change to a plain uint8_t when we generalize serialization
|
||||
std::size_t current_size = 0;
|
||||
constexpr std::size_t MAX_PACKET_SIZE = 0xffffff;
|
||||
|
||||
auto first = boost::asio::buffer_sequence_begin(buffers);
|
||||
auto last = boost::asio::buffer_sequence_end(buffers);
|
||||
|
||||
// TODO: we can do this better - for a multi-element
|
||||
// buffer sequence, we could merge some of the data into
|
||||
// a single packet
|
||||
for (auto it = first; it != last; ++it)
|
||||
{
|
||||
current_size = 0;
|
||||
auto bufsize = it->size();
|
||||
while (current_size < bufsize)
|
||||
{
|
||||
auto size_to_write = static_cast<std::uint32_t>(std::min(
|
||||
MAX_PACKET_SIZE,
|
||||
bufsize - current_size
|
||||
));
|
||||
header.packet_size.value = size_to_write;
|
||||
header.sequence_number = sequence_number_++;
|
||||
header_buffer.clear();
|
||||
serialize(header_buffer, header);
|
||||
// TODO: we could use a buffer sequence to write these two
|
||||
boost::asio::write(next_layer_, boost::asio::buffer(header_buffer.data(), header_buffer.size()));
|
||||
boost::asio::write(next_layer_, boost::asio::buffer(*it + current_size, size_to_write));
|
||||
current_size += size_to_write;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename AsyncStream>
|
||||
void mysql::MysqlStream<AsyncStream>::handshake(const HandshakeParams& params)
|
||||
{
|
||||
std::vector<std::uint8_t> read_buffer;
|
||||
DynamicBuffer write_buffer;
|
||||
|
||||
// Read handshake
|
||||
read(read_buffer);
|
||||
auto msg_type = get_message_type(read_buffer);
|
||||
if (msg_type != handshake_protocol_version_10)
|
||||
{
|
||||
const char* reason = msg_type == handshake_protocol_version_9 ?
|
||||
"Unsupported protocol version 9" :
|
||||
"Unknown message type";
|
||||
throw std::runtime_error {reason};
|
||||
}
|
||||
mysql::Handshake handshake;
|
||||
deserialize(read_buffer.data()+1, read_buffer.data() + read_buffer.size(), handshake);
|
||||
|
||||
// Process the handshake
|
||||
detail::check_capabilities(handshake.capability_falgs);
|
||||
detail::check_authentication_method(handshake);
|
||||
std::cout << handshake << "\n\n";
|
||||
|
||||
// Response
|
||||
mysql::HandshakeResponse handshake_response;
|
||||
mysql_native_password::response_buffer auth_response;
|
||||
mysql_native_password::compute_auth_string(params.password, handshake.auth_plugin_data.data(), auth_response);
|
||||
handshake_response.client_flag = detail::BASIC_CAPABILITIES_FLAGS;
|
||||
handshake_response.max_packet_size = 0xffff;
|
||||
handshake_response.character_set = params.character_set;
|
||||
handshake_response.username.value = params.username;
|
||||
handshake_response.auth_response.value = std::string_view {(const char*)auth_response, sizeof(auth_response)};
|
||||
handshake_response.client_plugin_name.value = "mysql_native_password";
|
||||
handshake_response.database.value = params.database;
|
||||
std::cout << handshake_response << "\n\n";
|
||||
|
||||
// Serialize and send
|
||||
serialize(write_buffer, handshake_response);
|
||||
write(write_buffer.get());
|
||||
|
||||
// TODO: support auth mismatch
|
||||
// TODO: support SSL
|
||||
|
||||
// Read the OK/ERR
|
||||
read(read_buffer);
|
||||
msg_type = get_message_type(read_buffer);
|
||||
if (msg_type != ok_packet_header && msg_type != eof_packet_header)
|
||||
{
|
||||
throw std::runtime_error { "Unknown message type" };
|
||||
}
|
||||
std::cout << "Connected to server\n";
|
||||
}
|
||||
|
||||
#endif
|
||||
327
include/impl/prepared_statement_impl.hpp
Normal file
327
include/impl/prepared_statement_impl.hpp
Normal file
@@ -0,0 +1,327 @@
|
||||
#ifndef INCLUDE_IMPL_PREPARED_STATEMENT_IMPL_HPP_
|
||||
#define INCLUDE_IMPL_PREPARED_STATEMENT_IMPL_HPP_
|
||||
|
||||
#include "message_serialization.hpp"
|
||||
#include "mysql_stream.hpp"
|
||||
#include "null_bitmap.hpp"
|
||||
#include <cassert>
|
||||
|
||||
namespace mysql
|
||||
{
|
||||
namespace detail
|
||||
{
|
||||
|
||||
inline void fill_execute_msg_impl(std::vector<BinaryValue>::iterator) {}
|
||||
|
||||
template <typename Param0, typename... Params>
|
||||
void fill_execute_msg_impl(
|
||||
std::vector<BinaryValue>::iterator param_output,
|
||||
Param0&& param,
|
||||
Params&&... tail
|
||||
)
|
||||
{
|
||||
*param_output = std::forward<Param0>(param);
|
||||
fill_execute_msg_impl(std::next(param_output), std::forward<Params>(tail)...);
|
||||
}
|
||||
|
||||
|
||||
template <typename... Args>
|
||||
void fill_execute_msg(StmtExecute& output, std::size_t num_params, Args&&... args)
|
||||
{
|
||||
if (sizeof...(args) != num_params)
|
||||
{
|
||||
throw std::out_of_range {"Wrong number of parameters passed to prepared statement"};
|
||||
}
|
||||
output.num_params = static_cast<int1>(num_params);
|
||||
output.new_params_bind_flag = 1;
|
||||
output.param_values.resize(num_params);
|
||||
fill_execute_msg_impl(output.param_values.begin(), std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template <typename AsyncStream>
|
||||
std::vector<ParamDefinition> read_fields(
|
||||
MysqlStream<AsyncStream>& stream,
|
||||
std::size_t quantity
|
||||
)
|
||||
{
|
||||
std::vector<ParamDefinition> res (quantity);
|
||||
for (auto& elm: res)
|
||||
{
|
||||
stream.read(elm.packet);
|
||||
deserialize(elm.packet, elm.value);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
BinaryValue deserialize_field_impl(
|
||||
ReadIterator& first,
|
||||
ReadIterator last
|
||||
)
|
||||
{
|
||||
T value;
|
||||
first = deserialize(first, last, value);
|
||||
return BinaryValue {value};
|
||||
}
|
||||
|
||||
inline BinaryValue not_implemented()
|
||||
{
|
||||
throw std::runtime_error{"Not implemented"};
|
||||
}
|
||||
|
||||
inline BinaryValue deserialize_field(
|
||||
FieldType type,
|
||||
ReadIterator& first,
|
||||
ReadIterator last
|
||||
)
|
||||
{
|
||||
switch (type)
|
||||
{
|
||||
case FieldType::DECIMAL:
|
||||
case FieldType::VARCHAR:
|
||||
case FieldType::BIT:
|
||||
case FieldType::NEWDECIMAL:
|
||||
case FieldType::ENUM:
|
||||
case FieldType::SET:
|
||||
case FieldType::TINY_BLOB:
|
||||
case FieldType::MEDIUM_BLOB:
|
||||
case FieldType::LONG_BLOB:
|
||||
case FieldType::BLOB:
|
||||
case FieldType::VAR_STRING:
|
||||
case FieldType::STRING:
|
||||
case FieldType::GEOMETRY:
|
||||
return deserialize_field_impl<string_lenenc>(first, last);
|
||||
case FieldType::TINY:
|
||||
return deserialize_field_impl<int1>(first, last);
|
||||
case FieldType::SHORT:
|
||||
return deserialize_field_impl<int2>(first, last);
|
||||
case FieldType::INT24:
|
||||
case FieldType::LONG:
|
||||
return deserialize_field_impl<int4>(first, last);
|
||||
case FieldType::LONGLONG:
|
||||
return deserialize_field_impl<int8>(first, last);
|
||||
case FieldType::FLOAT:
|
||||
return not_implemented();
|
||||
case FieldType::DOUBLE:
|
||||
return not_implemented();
|
||||
case FieldType::NULL_:
|
||||
return deserialize_field_impl<nullptr_t>(first, last);
|
||||
case FieldType::TIMESTAMP:
|
||||
case FieldType::DATE:
|
||||
case FieldType::TIME:
|
||||
case FieldType::DATETIME:
|
||||
case FieldType::YEAR:
|
||||
default:
|
||||
return not_implemented();
|
||||
}
|
||||
}
|
||||
|
||||
inline void deserialize_binary_row(
|
||||
const std::vector<std::uint8_t>& packet,
|
||||
const std::vector<ParamDefinition>& fields,
|
||||
std::vector<BinaryValue>& output
|
||||
)
|
||||
{
|
||||
output.clear();
|
||||
output.reserve(fields.size());
|
||||
ResultsetRowNullBitmapTraits traits {fields.size()};
|
||||
ReadIterator null_bitmap_first = packet.data() + 1; // skip header
|
||||
ReadIterator current = null_bitmap_first + traits.byte_count();
|
||||
ReadIterator last = packet.data() + packet.size();
|
||||
|
||||
for (std::size_t i = 0; i < fields.size(); ++i)
|
||||
{
|
||||
if (traits.is_null(null_bitmap_first, i))
|
||||
output.emplace_back(nullptr);
|
||||
else
|
||||
output.push_back(deserialize_field(fields[i].value.type, current, last));
|
||||
}
|
||||
if (current != last)
|
||||
{
|
||||
throw std::out_of_range {"Leftover data after binary row"};
|
||||
}
|
||||
}
|
||||
|
||||
} // detail
|
||||
} // mysql
|
||||
|
||||
|
||||
template <typename AsyncStream>
|
||||
mysql::PreparedStatement<AsyncStream>::PreparedStatement(
|
||||
MysqlStream<AsyncStream>& stream,
|
||||
int4 statement_id,
|
||||
std::vector<ParamDefinition>&& params,
|
||||
std::vector<ParamDefinition>&& columns
|
||||
) :
|
||||
stream_ {&stream},
|
||||
statement_id_ {statement_id},
|
||||
params_ {std::move(params)},
|
||||
columns_ {std::move(columns)}
|
||||
{
|
||||
};
|
||||
|
||||
template <typename AsyncStream>
|
||||
mysql::PreparedStatement<AsyncStream> mysql::PreparedStatement<AsyncStream>::prepare(
|
||||
MysqlStream<AsyncStream>& stream,
|
||||
std::string_view query
|
||||
)
|
||||
{
|
||||
// Write the prepare request
|
||||
StmtPrepare packet {{query}};
|
||||
DynamicBuffer write_buffer;
|
||||
serialize(write_buffer, packet);
|
||||
stream.reset_sequence_number();
|
||||
stream.write(write_buffer.get());
|
||||
|
||||
// Get the prepare response
|
||||
std::vector<std::uint8_t> read_buffer;
|
||||
stream.read(read_buffer);
|
||||
int1 status = get_message_type(read_buffer);
|
||||
if (status != 0)
|
||||
throw std::runtime_error {"Error preparing statement" + std::string{query}};
|
||||
StmtPrepareResponseHeader response;
|
||||
deserialize(read_buffer.data() + 1, read_buffer.data() + read_buffer.size(), response);
|
||||
|
||||
// Read the parameters and columns if any
|
||||
auto params = detail::read_fields(stream, response.num_params);
|
||||
auto fields = detail::read_fields(stream, response.num_columns);
|
||||
|
||||
return PreparedStatement<AsyncStream> {
|
||||
stream,
|
||||
response.statement_id,
|
||||
std::move(params),
|
||||
std::move(fields)
|
||||
};
|
||||
}
|
||||
|
||||
template <typename AsyncStream>
|
||||
template <typename... Params>
|
||||
mysql::BinaryResultset<AsyncStream> mysql::PreparedStatement<AsyncStream>::execute_with_cursor(
|
||||
int4 fetch_count,
|
||||
Params&&... actual_params
|
||||
)
|
||||
{
|
||||
int1 flags = fetch_count == MAX_FETCH_COUNT ? CURSOR_TYPE_NO_CURSOR : CURSOR_TYPE_READ_ONLY;
|
||||
StmtExecute message
|
||||
{
|
||||
statement_id_,
|
||||
flags
|
||||
};
|
||||
detail::fill_execute_msg(message, params_.size(), std::forward<Params>(actual_params)...);
|
||||
return do_execute(message, fetch_count);
|
||||
}
|
||||
|
||||
template <typename AsyncStream>
|
||||
void mysql::BinaryResultset<AsyncStream>::read_metadata()
|
||||
{
|
||||
stream_->read(current_packet_);
|
||||
if (get_message_type(current_packet_) == ok_packet_header) // Implicitly checks for errors
|
||||
{
|
||||
process_ok();
|
||||
}
|
||||
else
|
||||
{
|
||||
// Header containing number of fields
|
||||
StmtExecuteResponseHeader response_header;
|
||||
deserialize(current_packet_, response_header);
|
||||
|
||||
// Fields
|
||||
fields_ = detail::read_fields(*stream_, response_header.num_fields);
|
||||
|
||||
// First row
|
||||
retrieve_next();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename AsyncStream>
|
||||
void mysql::BinaryResultset<AsyncStream>::process_ok()
|
||||
{
|
||||
deserialize(current_packet_.data() + 1,
|
||||
current_packet_.data() + current_packet_.size(),
|
||||
ok_packet_);
|
||||
if (cursor_exists() &&
|
||||
!(ok_packet_.status_flags & SERVER_STATUS_LAST_ROW_SENT))
|
||||
{
|
||||
send_fetch();
|
||||
retrieve_next();
|
||||
}
|
||||
else
|
||||
{
|
||||
state_ = State::exhausted;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename AsyncStream>
|
||||
void mysql::BinaryResultset<AsyncStream>::send_fetch()
|
||||
{
|
||||
mysql::StmtFetch msg { statement_id_, fetch_count_ };
|
||||
DynamicBuffer buffer;
|
||||
serialize(buffer, msg);
|
||||
stream_->reset_sequence_number();
|
||||
stream_->write(buffer.get());
|
||||
}
|
||||
|
||||
template <typename AsyncStream>
|
||||
bool mysql::BinaryResultset<AsyncStream>::retrieve_next()
|
||||
{
|
||||
if (state_ == State::exhausted)
|
||||
return false;
|
||||
|
||||
stream_->read(current_packet_);
|
||||
auto msg_type = get_message_type(current_packet_);
|
||||
if (msg_type == eof_packet_header)
|
||||
{
|
||||
process_ok();
|
||||
}
|
||||
else
|
||||
{
|
||||
detail::deserialize_binary_row(current_packet_, fields_, current_values_);
|
||||
state_ = State::data_available;
|
||||
}
|
||||
return more_data();
|
||||
}
|
||||
|
||||
template <typename AsyncStream>
|
||||
const mysql::OkPacket& mysql::BinaryResultset<AsyncStream>::ok_packet() const
|
||||
{
|
||||
assert(state_ == State::exhausted ||
|
||||
(state_ == State::data_available && cursor_exists()));
|
||||
return ok_packet_;
|
||||
}
|
||||
|
||||
template <typename AsyncStream>
|
||||
const std::vector<mysql::BinaryValue>& mysql::BinaryResultset<AsyncStream>::values() const
|
||||
{
|
||||
assert(state_ == State::data_available);
|
||||
return current_values_;
|
||||
}
|
||||
|
||||
template <typename AsyncStream>
|
||||
mysql::BinaryResultset<AsyncStream> mysql::PreparedStatement<AsyncStream>::do_execute(
|
||||
const StmtExecute& message,
|
||||
int4 fetch_count
|
||||
)
|
||||
{
|
||||
std::vector<std::uint8_t> read_buffer;
|
||||
|
||||
DynamicBuffer write_buffer;
|
||||
serialize(write_buffer, message);
|
||||
stream_->reset_sequence_number();
|
||||
stream_->write(write_buffer.get());
|
||||
|
||||
return mysql::BinaryResultset {*stream_, statement_id_, fetch_count};
|
||||
}
|
||||
|
||||
template <typename AsyncStream>
|
||||
void mysql::PreparedStatement<AsyncStream>::close()
|
||||
{
|
||||
assert(statement_id_ != 0);
|
||||
StmtClose msg { statement_id_ };
|
||||
|
||||
DynamicBuffer write_buffer;
|
||||
serialize(write_buffer, msg);
|
||||
stream_->reset_sequence_number();
|
||||
stream_->write(write_buffer.get());
|
||||
}
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user