diff --git a/CMakeLists.txt b/CMakeLists.txt index a7c496b6..ba36d143 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -70,6 +70,7 @@ target_link_libraries( Boost::system Threads::Threads OpenSSL::Crypto + OpenSSL::SSL ) target_include_directories( mysql_asio diff --git a/TODO.txt b/TODO.txt index b789e92c..14dcfd03 100644 --- a/TODO.txt +++ b/TODO.txt @@ -1,9 +1,16 @@ +SSL + Isolated connection_params class with SSL options + Add an integ test in handshake to verify non-encrypted connections + Add a new integ test that proves using every feature over a non-encrypted connection + Update README + Make ssl context in channel lazy + Create appropriate tests for SSL request + Consider if handshake_response may use default serialization Multiresultset Text protocol Binary protocol (stored procedures) Handshake sha256_password auth plugin: absence may cause BadUser tests to fail in 8.0 - SSL compression Usability Should make_error_code be public? diff --git a/include/boost/mysql/detail/network_algorithms/handshake.hpp b/include/boost/mysql/detail/network_algorithms/handshake.hpp index c7a0e6ad..48592dbb 100644 --- a/include/boost/mysql/detail/network_algorithms/handshake.hpp +++ b/include/boost/mysql/detail/network_algorithms/handshake.hpp @@ -17,6 +17,7 @@ struct handshake_params std::string_view username; std::string_view password; std::string_view database; + bool use_ssl; // TODO: change this to an enum }; template diff --git a/include/boost/mysql/detail/network_algorithms/impl/handshake.hpp b/include/boost/mysql/detail/network_algorithms/impl/handshake.hpp index be42b3e8..00d52ef0 100644 --- a/include/boost/mysql/detail/network_algorithms/impl/handshake.hpp +++ b/include/boost/mysql/detail/network_algorithms/impl/handshake.hpp @@ -16,6 +16,11 @@ inline std::uint8_t get_collation_first_byte(collation value) return static_cast(value) % 0xff; } +inline capabilities conditional_capability(bool condition, std::uint32_t cap) +{ + return capabilities(condition ? cap : 0); +} + inline error_code deserialize_handshake( boost::asio::const_buffer buffer, handshake_packet& output, @@ -77,17 +82,19 @@ class handshake_processor { handshake_params params_; capabilities negotiated_caps_; + handshake_packet initial_packet_; public: handshake_processor(const handshake_params& params): params_(params) {}; - capabilities negotiated_capabilities() const { return negotiated_caps_; } + capabilities negotiated_capabilities() const noexcept { return negotiated_caps_; } + const handshake_params& params() const noexcept { return params_; } + // Initial greeting processing error_code process_capabilities(const handshake_packet& handshake) { capabilities server_caps (handshake.capability_falgs.value); - capabilities required_caps = - params_.database.empty() ? - mandatory_capabilities : - mandatory_capabilities | capabilities(CLIENT_CONNECT_WITH_DB); + capabilities required_caps = mandatory_capabilities | + conditional_capability(!params_.database.empty(), CLIENT_CONNECT_WITH_DB) | + conditional_capability(params_.use_ssl, CLIENT_SSL); if (!server_caps.has_all(required_caps)) { return make_error_code(errc::server_unsupported); @@ -95,66 +102,60 @@ public: negotiated_caps_ = server_caps & (required_caps | optional_capabilities); return error_code(); } - void compose_handshake_response( - std::string_view auth_response, - handshake_response_packet& output - ) - { - output.client_flag.value = negotiated_caps_.get(); - output.max_packet_size.value = MAX_PACKET_SIZE; - output.character_set.value = get_collation_first_byte(params_.connection_collation); - output.username.value = params_.username; - output.auth_response.value = auth_response; - output.database.value = params_.database; - output.client_plugin_name.value = mysql_native_password::plugin_name; - } - error_code compute_auth_switch_response( - const auth_switch_request_packet& request, - auth_switch_response_packet& output, - auth_response_calculator& calc - ) - { - if (request.plugin_name.value != mysql_native_password::plugin_name) - { - return make_error_code(errc::unknown_auth_plugin); - } - return calc.calculate( - params_.password, - request.auth_plugin_data.value, - output.auth_plugin_data.value - ); - } error_code process_handshake(bytestring& buffer, error_info& info) { // Deserialize server greeting - handshake_packet handshake; - auto err = deserialize_handshake(boost::asio::buffer(buffer), handshake, info); + auto err = deserialize_handshake(boost::asio::buffer(buffer), initial_packet_, info); if (err) return err; // Check capabilities - err = process_capabilities(handshake); - if (err) return err; + return process_capabilities(initial_packet_); + } - // Authentication + // Response to that initial greeting + void compose_ssl_request(bytestring& buffer) + { + ssl_request sslreq { + int4(negotiated_capabilities().get()), + int4(MAX_PACKET_SIZE), + int1(get_collation_first_byte(params_.connection_collation)), + {} + }; + + // Serialize and send + serialize_message(sslreq, negotiated_caps_, buffer); + } + + error_code compose_handshake_response(bytestring& buffer) + { + // Authentication calculation auth_response_calculator calc; std::string_view auth_response; - if (handshake.auth_plugin_name.value == mysql_native_password::plugin_name) + if (initial_packet_.auth_plugin_name.value == mysql_native_password::plugin_name) { - err = calc.calculate(params_.password, handshake.auth_plugin_data.value, auth_response); + auto err = calc.calculate(params_.password, initial_packet_.auth_plugin_data.value, auth_response); if (err) return err; } // Compose response - handshake_response_packet response; - compose_handshake_response(auth_response, response); + handshake_response_packet response { + int4(negotiated_caps_.get()), + int4(MAX_PACKET_SIZE), + int1(get_collation_first_byte(params_.connection_collation)), + string_null(params_.username), + string_lenenc(auth_response), + string_null(params_.database), + string_null(mysql_native_password::plugin_name) + }; // Serialize - serialize_message(response, negotiated_capabilities(), buffer); + serialize_message(response, negotiated_caps_, buffer); return error_code(); } + // Server handshake response error_code process_handshake_server_response( bytestring& buffer, bool& auth_complete, @@ -198,6 +199,24 @@ public: return error_code(); } + // Auth switch + error_code compute_auth_switch_response( + const auth_switch_request_packet& request, + auth_switch_response_packet& output, + auth_response_calculator& calc + ) + { + if (request.plugin_name.value != mysql_native_password::plugin_name) + { + return make_error_code(errc::unknown_auth_plugin); + } + return calc.calculate( + params_.password, + request.auth_plugin_data.value, + output.auth_plugin_data.value + ); + } + error_code process_auth_switch_response( boost::asio::const_buffer buffer, error_info& info @@ -238,10 +257,27 @@ void boost::mysql::detail::hanshake( channel.read(channel.shared_buffer(), err); if (err) return; - // Process server greeting + // Process server greeting (handshake) err = processor.process_handshake(channel.shared_buffer(), info); if (err) return; + // Setup SSL if required + if (params.use_ssl) + { + // Send SSL request + processor.compose_ssl_request(channel.shared_buffer()); + channel.write(boost::asio::buffer(channel.shared_buffer()), err); + if (err) return; + + // SSL handshake + channel.ssl_handshake(err); + if (err) return; + } + + // Handshake response + err = processor.compose_handshake_response(channel.shared_buffer()); + if (err) return; + // Send channel.write(boost::asio::buffer(channel.shared_buffer()), err); if (err) return; @@ -253,12 +289,7 @@ void boost::mysql::detail::hanshake( // Process it bool auth_complete = false; err = processor.process_handshake_server_response(channel.shared_buffer(), auth_complete, info); - if (err) return; - if (auth_complete) - { - err.clear(); - return; - } + if (err || auth_complete) return; // We received an auth switch response and we have the response ready to be sent channel.write(boost::asio::buffer(channel.shared_buffer()), err); @@ -344,6 +375,38 @@ boost::mysql::detail::async_handshake( yield break; } + // SSL + if (processor_.params().use_ssl) + { + // Send SSL request + processor_.compose_ssl_request(channel_.shared_buffer()); + yield channel_.async_write( + boost::asio::buffer(channel_.shared_buffer()), + std::move(*this) + ); + if (err) + { + complete(cont, err); + yield break; + } + + // SSL handshake + yield channel_.async_ssl_handshake(std::move(*this)); + if (err) + { + complete(cont, err); + yield break; + } + } + + // Compose handshake response + err = processor_.compose_handshake_response(channel_.shared_buffer()); + if (err) + { + complete(cont, err); + yield break; + } + // Send yield channel_.async_write(boost::asio::buffer(channel_.shared_buffer()), std::move(*this)); if (err) diff --git a/include/boost/mysql/detail/protocol/channel.hpp b/include/boost/mysql/detail/protocol/channel.hpp index 3367b3f8..5d1c6fe7 100644 --- a/include/boost/mysql/detail/protocol/channel.hpp +++ b/include/boost/mysql/detail/protocol/channel.hpp @@ -6,6 +6,10 @@ #include "boost/mysql/detail/protocol/capabilities.hpp" #include #include +#include +#include +#include +#include #include namespace boost { @@ -17,7 +21,9 @@ class channel { // TODO: static asserts for AsyncStream concept // TODO: actually we also require it to be SyncStream, name misleading - AsyncStream& next_layer_; + boost::asio::ssl::context ssl_ctx_; + boost::asio::ssl::stream next_layer_; + bool ssl_active_ {false}; std::uint8_t sequence_number_ {0}; std::array header_buffer_ {}; // for async ops bytestring shared_buff_; // for async ops @@ -28,8 +34,24 @@ class channel error_code process_header_read(std::uint32_t& size_to_read); // reads from header_buffer_ void process_header_write(std::uint32_t size_to_write); // writes to header_buffer_ + + template + std::size_t read_impl(BufferSeq&& buff, error_code& ec); + + template + std::size_t write_impl(BufferSeq&& buff, error_code& ec); + + template + auto async_read_impl(BufferSeq&& buff, CompletionToken&& token); + + template + auto async_write_impl(BufferSeq&& buff, CompletionToken&& token); public: - channel(AsyncStream& stream): next_layer_ {stream} {}; + channel(AsyncStream& stream): + ssl_ctx_(boost::asio::ssl::context::tls_client), + next_layer_ (stream, ssl_ctx_) + { + }; template void read(basic_bytestring& buffer, error_code& code); @@ -44,11 +66,17 @@ public: BOOST_ASIO_INITFN_RESULT_TYPE(CompletionToken, void(error_code)) async_write(boost::asio::const_buffer buffer, CompletionToken&& token); + void ssl_handshake(error_code& ec); + + template + BOOST_ASIO_INITFN_RESULT_TYPE(CompletionToken, void(error_code)) + async_ssl_handshake(CompletionToken&& token); + void reset_sequence_number(std::uint8_t value = 0) { sequence_number_ = value; } std::uint8_t sequence_number() const { return sequence_number_; } using stream_type = AsyncStream; - stream_type& next_layer() { return next_layer_; } + stream_type& next_layer() { return next_layer_.next_layer(); } capabilities current_capabilities() const noexcept { return current_caps_; } void set_current_capabilities(capabilities value) noexcept { current_caps_ = value; } diff --git a/include/boost/mysql/detail/protocol/handshake_messages.hpp b/include/boost/mysql/detail/protocol/handshake_messages.hpp index b130fbaa..2680d9d7 100644 --- a/include/boost/mysql/detail/protocol/handshake_messages.hpp +++ b/include/boost/mysql/detail/protocol/handshake_messages.hpp @@ -79,6 +79,26 @@ struct serialization_traits filler {}; +}; + +template <> +struct get_struct_fields +{ + static constexpr auto value = std::make_tuple( + &ssl_request::client_flag, + &ssl_request::max_packet_size, + &ssl_request::character_set, + &ssl_request::filler + ); +}; + // auth switch request struct auth_switch_request_packet { diff --git a/include/boost/mysql/detail/protocol/impl/channel.hpp b/include/boost/mysql/detail/protocol/impl/channel.hpp index b16faa31..98820db9 100644 --- a/include/boost/mysql/detail/protocol/impl/channel.hpp +++ b/include/boost/mysql/detail/protocol/impl/channel.hpp @@ -1,8 +1,6 @@ #ifndef MYSQL_ASIO_IMPL_CHANNEL_IPP #define MYSQL_ASIO_IMPL_CHANNEL_IPP -#include -#include #include #include #include "boost/mysql/detail/protocol/common_messages.hpp" @@ -73,6 +71,90 @@ void boost::mysql::detail::channel::process_header_write( serialize(header, ctx); } +template +template +std::size_t boost::mysql::detail::channel::read_impl( + BufferSeq&& buff, + error_code& ec +) +{ + if (ssl_active_) + { + return boost::asio::read(next_layer_, std::forward(buff), ec); + } + else + { + return boost::asio::read(next_layer_.next_layer(), std::forward(buff), ec); + } +} + +template +template +std::size_t boost::mysql::detail::channel::write_impl( + BufferSeq&& buff, + error_code& ec +) +{ + if (ssl_active_) + { + return boost::asio::write(next_layer_, std::forward(buff), ec); + } + else + { + return boost::asio::write(next_layer_.next_layer(), std::forward(buff), ec); + } +} + +template +template +auto boost::mysql::detail::channel::async_read_impl( + BufferSeq&& buff, + CompletionToken&& token +) +{ + if (ssl_active_) + { + return boost::asio::async_read( + next_layer_, + std::forward(buff), + std::forward(token) + ); + } + else + { + return boost::asio::async_read( + next_layer_.next_layer(), + std::forward(buff), + std::forward(token) + ); + } +} + +template +template +auto boost::mysql::detail::channel::async_write_impl( + BufferSeq&& buff, + CompletionToken&& token +) +{ + if (ssl_active_) + { + return boost::asio::async_write( + next_layer_, + std::forward(buff), + std::forward(token) + ); + } + else + { + return boost::asio::async_write( + next_layer_.next_layer(), + std::forward(buff), + std::forward(token) + ); + } +} + template template void boost::mysql::detail::channel::read( @@ -87,20 +169,12 @@ void boost::mysql::detail::channel::read( do { - boost::asio::read( - next_layer_, - boost::asio::buffer(header_buffer_), - code - ); + read_impl(boost::asio::buffer(header_buffer_), code); if (code) return; code = process_header_read(size_to_read); if (code) return; buffer.resize(buffer.size() + size_to_read); - boost::asio::read( - next_layer_, - boost::asio::buffer(buffer.data() + transferred_size, size_to_read), - code - ); + read_impl(boost::asio::buffer(buffer.data() + transferred_size, size_to_read), code); if (code) return; transferred_size += size_to_read; } while (size_to_read == MAX_PACKET_SIZE); @@ -122,8 +196,7 @@ void boost::mysql::detail::channel::write( { auto size_to_write = compute_size_to_write(bufsize, transferred_size); process_header_write(size_to_write); - boost::asio::write( - next_layer_, + write_impl( std::array { boost::asio::buffer(header_buffer_), boost::asio::buffer(first + transferred_size, size_to_write) @@ -179,8 +252,7 @@ boost::mysql::detail::channel::async_read( { do { - yield boost::asio::async_read( - stream_.next_layer_, + yield stream_.async_read_impl( boost::asio::buffer(stream_.header_buffer_), std::move(*this) ); @@ -201,8 +273,7 @@ boost::mysql::detail::channel::async_read( buffer_.resize(buffer_.size() + size_to_read); - yield boost::asio::async_read( - stream_.next_layer_, + yield stream_.async_read_impl( boost::asio::buffer(buffer_.data() + total_transferred_size_, size_to_read), std::move(*this) ); @@ -273,8 +344,7 @@ boost::mysql::detail::channel::async_write( size_to_write = compute_size_to_write(buffer_.size(), total_transferred_size_); stream_.process_header_write(size_to_write); - yield boost::asio::async_write( - stream_.next_layer_, + yield stream_.async_write_impl( std::array { boost::asio::buffer(stream_.header_buffer_), boost::asio::buffer(buffer_ + total_transferred_size_, size_to_write) @@ -302,6 +372,29 @@ boost::mysql::detail::channel::async_write( return initiator.result.get(); } +template +void boost::mysql::detail::channel::ssl_handshake( + error_code& ec +) +{ + ssl_active_ = true; + next_layer_.handshake(boost::asio::ssl::stream_base::client, ec); +} + +template +template +BOOST_ASIO_INITFN_RESULT_TYPE(CompletionToken, void(boost::mysql::error_code)) +boost::mysql::detail::channel::async_ssl_handshake( + CompletionToken&& token +) +{ + ssl_active_ = true; + return next_layer_.async_handshake( + boost::asio::ssl::stream_base::client, + std::forward(token) + ); +} + #include diff --git a/include/boost/mysql/impl/connection.hpp b/include/boost/mysql/impl/connection.hpp index bcc6c4d0..da7f548f 100644 --- a/include/boost/mysql/impl/connection.hpp +++ b/include/boost/mysql/impl/connection.hpp @@ -19,7 +19,8 @@ inline handshake_params to_handshake_params( input.connection_collation, input.username, input.password, - input.database + input.database, + true }; } diff --git a/test/unit/detail/protocol/channel.cpp b/test/unit/detail/protocol/channel.cpp index 1d50b993..339c3947 100644 --- a/test/unit/detail/protocol/channel.cpp +++ b/test/unit/detail/protocol/channel.cpp @@ -96,6 +96,12 @@ public: } return res; } + + using lowest_layer_type = MockStream; + using executor_type = boost::asio::system_executor; + + MockStream& lowest_layer() { return *this; } + executor_type get_executor() { return boost::asio::system_executor(); } }; struct MysqlChannelFixture : public Test