diff --git a/include/mysql/impl/mysql_channel_impl.hpp b/include/mysql/impl/mysql_channel_impl.hpp index e3dfad65..6aa198bb 100644 --- a/include/mysql/impl/mysql_channel_impl.hpp +++ b/include/mysql/impl/mysql_channel_impl.hpp @@ -68,7 +68,7 @@ void mysql::detail::MysqlChannel::process_header_write( { msgs::packet_header header; header.packet_size.value = size_to_write; - header.sequence_number = next_sequence_number(); + header.sequence_number.value = next_sequence_number(); SerializationContext ctx (0, header_buffer_.data()); // capabilities not relevant here serialize(header, ctx); } @@ -114,6 +114,7 @@ void mysql::detail::MysqlChannel::write( { std::size_t transferred_size = 0; auto bufsize = buffer.size(); + auto first = static_cast(buffer.data()); while (transferred_size < bufsize) { @@ -123,7 +124,7 @@ void mysql::detail::MysqlChannel::write( next_layer_, std::array { boost::asio::buffer(header_buffer_), - boost::asio::buffer(buffer + transferred_size, size_to_write) + boost::asio::buffer(first + transferred_size, size_to_write) }, errc ); diff --git a/test/mysql_channel.cpp b/test/mysql_channel.cpp index d125c7e3..85c30c78 100644 --- a/test/mysql_channel.cpp +++ b/test/mysql_channel.cpp @@ -10,6 +10,7 @@ #include #include #include +#include using namespace testing; using namespace mysql; @@ -20,15 +21,38 @@ namespace errc = boost::system::errc; namespace { +void concat(std::vector& lhs, boost::asio::const_buffer rhs) +{ + auto current_size = lhs.size(); + lhs.resize(current_size + rhs.size()); + memcpy(lhs.data() + current_size, rhs.data(), rhs.size()); +} + +void concat(std::vector& lhs, const std::vector& rhs) +{ + concat(lhs, boost::asio::buffer(rhs)); +} + class MockStream { public: MOCK_METHOD2(read_buffer, std::size_t(boost::asio::mutable_buffer, mysql::error_code&)); + MOCK_METHOD2(write_buffer, std::size_t(boost::asio::const_buffer, mysql::error_code&)); + + MockStream() + { + ON_CALL(*this, read_buffer).WillByDefault(SetArgReferee<1>(errc::make_error_code(errc::timed_out))); + ON_CALL(*this, write_buffer).WillByDefault(SetArgReferee<1>(errc::make_error_code(errc::timed_out))); + } template std::size_t read_some(MutableBufferSequence mb, mysql::error_code& ec) { - if (buffer_size(mb) == 0) return 0; + if (buffer_size(mb) == 0) + { + ec.clear(); + return 0; + } size_t res = 0; for (auto it = buffer_sequence_begin(mb); it != buffer_sequence_end(mb); ++it) res += read_buffer(*it, ec); @@ -46,17 +70,52 @@ public: } return res; } + + template + std::size_t write_some(ConstBufferSequence cb, mysql::error_code& ec) + { + if (buffer_size(cb) == 0) + { + ec.clear(); + return 0; + } + size_t res = 0; + for (auto it = buffer_sequence_begin(cb); it != buffer_sequence_end(cb); ++it) + { + auto written = write_buffer(*it, ec); + res += written; + if (written < it->size()) break; + } + return res; + } + + template + std::size_t write_some(ConstBufferSequence mb) + { + error_code ec; + auto res = write_some(mb, ec); + if (res) + { + throw boost::system::system_error(ec); + } + return res; + } }; - -struct MysqlChannelTest : public Test +struct MysqlChannelFixture : public Test { using Channel = MysqlChannel; MockStream stream; Channel channel {stream}; + mysql::error_code errc; + InSequence seq; +}; + + +struct MysqlChannelReadTest : public MysqlChannelFixture +{ std::vector mem; dynamic_vector_buffer> buffer {mem}; - mysql::error_code errc; void verify_buffer(const std::vector& expected) { @@ -82,7 +141,7 @@ struct MysqlChannelTest : public Test } }; -TEST_F(MysqlChannelTest, SyncRead_AllReadsSuccessful_ReadHeaderPopulatesBuffer) +TEST_F(MysqlChannelReadTest, SyncRead_AllReadsSuccessful_ReadHeaderPopulatesBuffer) { EXPECT_CALL(stream, read_buffer) .WillOnce(Invoke(buffer_copier({0x03, 0x00, 0x00, 0x00}))) @@ -92,7 +151,9 @@ TEST_F(MysqlChannelTest, SyncRead_AllReadsSuccessful_ReadHeaderPopulatesBuffer) verify_buffer({0xfe, 0x03, 0x02}); } -TEST_F(MysqlChannelTest, SyncRead_MoreThan16M_JoinsPackets) +// TODO: test with existing seqnum, resetting seqnum, wrapping seqnum + +TEST_F(MysqlChannelReadTest, SyncRead_MoreThan16M_JoinsPackets) { EXPECT_CALL(stream, read_buffer) .WillOnce(Invoke(buffer_copier({0xff, 0xff, 0xff, 0x00}))) @@ -106,7 +167,7 @@ TEST_F(MysqlChannelTest, SyncRead_MoreThan16M_JoinsPackets) verify_buffer(std::vector(0xffffff * 2 + 4, 0x20)); } -TEST_F(MysqlChannelTest, SyncRead_ShortReads_InvokesReadAgain) +TEST_F(MysqlChannelReadTest, SyncRead_ShortReads_InvokesReadAgain) { EXPECT_CALL(stream, read_buffer) .WillOnce(Invoke(buffer_copier({0x04}))) @@ -118,7 +179,7 @@ TEST_F(MysqlChannelTest, SyncRead_ShortReads_InvokesReadAgain) verify_buffer({0x01, 0x02, 0x03, 0x04}); } -TEST_F(MysqlChannelTest, SyncRead_ReadErrorInHeader_ReturnsFailureErrorCode) +TEST_F(MysqlChannelReadTest, SyncRead_ReadErrorInHeader_ReturnsFailureErrorCode) { auto expected_error = errc::make_error_code(errc::not_supported); EXPECT_CALL(stream, read_buffer) @@ -127,7 +188,7 @@ TEST_F(MysqlChannelTest, SyncRead_ReadErrorInHeader_ReturnsFailureErrorCode) EXPECT_EQ(errc, expected_error); } -TEST_F(MysqlChannelTest, SyncRead_ReadErrorInPacket_ReturnsFailureErrorCode) +TEST_F(MysqlChannelReadTest, SyncRead_ReadErrorInPacket_ReturnsFailureErrorCode) { auto expected_error = errc::make_error_code(errc::not_supported); EXPECT_CALL(stream, read_buffer) @@ -137,7 +198,7 @@ TEST_F(MysqlChannelTest, SyncRead_ReadErrorInPacket_ReturnsFailureErrorCode) EXPECT_EQ(errc, expected_error); } -TEST_F(MysqlChannelTest, SyncRead_SequenceNumberMismatch_ReturnsAppropriateErrorCode) +TEST_F(MysqlChannelReadTest, SyncRead_SequenceNumberMismatch_ReturnsAppropriateErrorCode) { EXPECT_CALL(stream, read_buffer) .WillOnce(Invoke(buffer_copier({0xff, 0xff, 0xff, 0x05}))); @@ -145,7 +206,91 @@ TEST_F(MysqlChannelTest, SyncRead_SequenceNumberMismatch_ReturnsAppropriateError EXPECT_EQ(errc, make_error_code(Error::sequence_number_mismatch)); } +struct MysqlChannelWriteTest : public MysqlChannelFixture +{ + std::vector bytes_written; + void verify_buffer(const std::vector& expected) + { + EXPECT_EQ(bytes_written, expected); + } + + auto make_write_handler(std::size_t max_bytes_written = 0xffffffff) + { + return [this, max_bytes_written](boost::asio::const_buffer buff, error_code& ec) { + auto actual_size = std::min(buff.size(), max_bytes_written); + concat(bytes_written, boost::asio::buffer(buff.data(), actual_size)); + ec.clear(); + return actual_size; + }; + } + + static auto write_failer(errc::errc_t error) + { + return [error](boost::asio::const_buffer buff, error_code& ec) { + ec = errc::make_error_code(error); + return 0; + }; + } +}; + +TEST_F(MysqlChannelWriteTest, SyncWrite_AllWritesSuccessful_WritesHeaderAndBuffer) +{ + ON_CALL(stream, write_buffer) + .WillByDefault(Invoke(make_write_handler())); + channel.write(buffer(std::vector{0xaa, 0xab, 0xac}), errc); + verify_buffer({ + 0x03, 0x00, 0x00, 0x00, // header + 0xaa, 0xab, 0xac // body + }); + EXPECT_EQ(errc, error_code()); +} + +TEST_F(MysqlChannelWriteTest, SyncWrite_MoreThan16M_SplitsInPackets) +{ + ON_CALL(stream, write_buffer) + .WillByDefault(Invoke(make_write_handler())); + channel.write(buffer(std::vector(2*0xffffff + 4, 0xab)), errc); + std::vector expected_buffer {0xff, 0xff, 0xff, 0x00}; + concat(expected_buffer, std::vector(0xffffff, 0xab)); + concat(expected_buffer, {0xff, 0xff, 0xff, 0x01}); + concat(expected_buffer, std::vector(0xffffff, 0xab)); + concat(expected_buffer, {0x04, 0x00, 0x00, 0x02}); + concat(expected_buffer, std::vector(4, 0xab)); + verify_buffer(expected_buffer); + EXPECT_EQ(errc, error_code()); +} + +TEST_F(MysqlChannelWriteTest, SyncWrite_ShortWrites_WritesHeaderAndBuffer) +{ + ON_CALL(stream, write_buffer) + .WillByDefault(Invoke(make_write_handler(2))); + channel.write(buffer(std::vector{0xaa, 0xab, 0xac}), errc); + verify_buffer({ + 0x03, 0x00, 0x00, 0x00, // header + 0xaa, 0xab, 0xac // body + }); + EXPECT_EQ(errc, error_code()); +} + +TEST_F(MysqlChannelWriteTest, SyncWrite_WriteErrorInHeader_ReturnsErrorCode) +{ + ON_CALL(stream, write_buffer) + .WillByDefault(Invoke(write_failer(errc::broken_pipe))); + channel.write(buffer(std::vector(10, 0x01)), errc); + EXPECT_EQ(errc, errc::make_error_code(errc::broken_pipe)); +} + +TEST_F(MysqlChannelWriteTest, SyncWrite_WriteErrorInPacket_ReturnsErrorCode) +{ + EXPECT_CALL(stream, write_buffer) + .WillOnce(Return(4)) + .WillOnce(Invoke(write_failer(errc::broken_pipe))); + channel.write(buffer(std::vector(10, 0x01)), errc); + EXPECT_EQ(errc, errc::make_error_code(errc::broken_pipe)); } +} // anon namespace + +