mirror of
https://github.com/boostorg/mysql.git
synced 2026-02-14 00:42:53 +00:00
318 lines
9.4 KiB
C++
318 lines
9.4 KiB
C++
#ifndef TEST_SERIALIZATION_TEST_COMMON_HPP_
|
|
#define TEST_SERIALIZATION_TEST_COMMON_HPP_
|
|
|
|
#include "mysql/impl/messages.hpp"
|
|
#include "mysql/impl/constants.hpp"
|
|
#include <gtest/gtest.h>
|
|
#include <string>
|
|
#include <any>
|
|
#include <boost/type_index.hpp>
|
|
#include "mysql/impl/serialization.hpp"
|
|
#include "mysql/value.hpp"
|
|
#include "test_common.hpp"
|
|
|
|
namespace mysql
|
|
{
|
|
namespace detail
|
|
{
|
|
|
|
// Operator << for some basic types
|
|
/// Prints a fixed-size character array as text.
///
/// The array is not required to be NUL-terminated: a fixed-length field may
/// use every slot. Printing stops at the first NUL character, or after N
/// characters if there is none, so the array is never read past its end
/// (the previous `os << v.data()` did exactly that for full arrays).
template <std::size_t N>
std::ostream& operator<<(std::ostream& os, const std::array<char, N>& v)
{
	std::size_t length = 0;
	while (length < N && v[length] != '\0')
	{
		++length;
	}
	// Going through std::string keeps the usual formatting (width/fill) behavior.
	return os << std::string(v.data(), length);
}
|
|
|
|
/// Prints a std::uint8_t as its decimal value.
///
/// std::uint8_t is normally an alias for unsigned char, which the standard
/// inserters would output as a raw character; converting to int first makes
/// byte values readable in assertion failure messages.
inline std::ostream& operator<<(std::ostream& os, std::uint8_t value)
{
	const int as_number = static_cast<int>(value);
	os << as_number;
	return os;
}
|
|
|
|
// Re-export the date library's stream inserters into this namespace.
// NOTE(review): presumably required because the operator<< overloads declared
// in mysql::detail hide outer-namespace overloads from unqualified lookup here,
// and ADL alone would not reach ::date for std::chrono-based types — confirm.
using ::date::operator<<;
|
|
|
|
// Operator == for structs
|
|
template <std::size_t index, typename T>
|
|
bool equals_struct(const T& lhs, const T& rhs)
|
|
{
|
|
constexpr auto fields = get_struct_fields<T>::value;
|
|
if constexpr (index == std::tuple_size<decltype(fields)>::value)
|
|
{
|
|
return true;
|
|
}
|
|
else
|
|
{
|
|
constexpr auto pmem = std::get<index>(fields);
|
|
return (rhs.*pmem == lhs.*pmem) && equals_struct<index+1>(lhs, rhs);
|
|
}
|
|
}
|
|
|
|
template <typename T>
|
|
std::enable_if_t<is_struct_with_fields<T>::value, bool>
|
|
operator==(const T& lhs, const T& rhs)
|
|
{
|
|
return equals_struct<0>(lhs, rhs);
|
|
}
|
|
|
|
// Operator << for value_holder<T>
|
|
template <typename T>
|
|
std::ostream& operator<<(std::ostream& os, const value_holder<T>& value)
|
|
{
|
|
return os << value.value;
|
|
}
|
|
|
|
template <typename T>
|
|
std::enable_if_t<std::is_enum_v<T>, std::ostream&>
|
|
operator<<(std::ostream& os, T value)
|
|
{
|
|
return os << boost::typeindex::type_id<T>().pretty_name() << "(" <<
|
|
static_cast<std::underlying_type_t<T>>(value) << ")";
|
|
}
|
|
|
|
// Operator << for structs
|
|
template <std::size_t index, typename T>
|
|
void print_struct(std::ostream& os, const T& value)
|
|
{
|
|
constexpr auto fields = get_struct_fields<T>::value;
|
|
if constexpr (index < std::tuple_size<decltype(fields)>::value)
|
|
{
|
|
constexpr auto pmem = std::get<index>(fields);
|
|
os << " " << (value.*pmem) << ",\n";
|
|
print_struct<index+1>(os, value);
|
|
}
|
|
}
|
|
|
|
template <typename T>
|
|
std::enable_if_t<is_struct_with_fields<T>::value, std::ostream&>
|
|
operator<<(std::ostream& os, const T& value)
|
|
{
|
|
os << boost::typeindex::type_id<T>().pretty_name() << "(\n";
|
|
print_struct<0>(os, value);
|
|
os << ")\n";
|
|
return os;
|
|
}
|
|
|
|
class TypeErasedValue
|
|
{
|
|
public:
|
|
virtual ~TypeErasedValue() {}
|
|
virtual void serialize(SerializationContext& ctx) const = 0;
|
|
virtual std::size_t get_size(const SerializationContext& ctx) const = 0;
|
|
virtual Error deserialize(DeserializationContext& ctx) = 0;
|
|
virtual std::shared_ptr<TypeErasedValue> default_construct() const = 0;
|
|
virtual bool equals(const TypeErasedValue& rhs) const = 0;
|
|
virtual void print(std::ostream& os) const = 0;
|
|
|
|
bool operator==(const TypeErasedValue& rhs) const { return equals(rhs); }
|
|
};
|
|
inline std::ostream& operator<<(std::ostream& os, const TypeErasedValue& value)
|
|
{
|
|
value.print(os);
|
|
return os;
|
|
}
|
|
|
|
template <typename T>
|
|
class TypeErasedValueImpl : public TypeErasedValue
|
|
{
|
|
T value_;
|
|
public:
|
|
TypeErasedValueImpl(const T& v): value_(v) {};
|
|
void serialize(SerializationContext& ctx) const override { ::mysql::detail::serialize(value_, ctx); }
|
|
std::size_t get_size(const SerializationContext& ctx) const override { return ::mysql::detail::get_size(value_, ctx); }
|
|
Error deserialize(DeserializationContext& ctx) override { return ::mysql::detail::deserialize(value_, ctx); }
|
|
std::shared_ptr<TypeErasedValue> default_construct() const override
|
|
{
|
|
return std::make_shared<TypeErasedValueImpl<T>>(T{});
|
|
}
|
|
bool equals(const TypeErasedValue& rhs) const override
|
|
{
|
|
auto typed_value = dynamic_cast<const TypeErasedValueImpl<T>*>(&rhs);
|
|
return typed_value && (typed_value->value_ == value_);
|
|
}
|
|
void print(std::ostream& os) const override
|
|
{
|
|
os << value_;
|
|
}
|
|
};
|
|
|
|
struct SerializeParams : test::named_param
|
|
{
|
|
std::shared_ptr<TypeErasedValue> value;
|
|
std::vector<uint8_t> expected_buffer;
|
|
std::string name;
|
|
capabilities caps;
|
|
std::any additional_storage;
|
|
|
|
template <typename T>
|
|
SerializeParams(const T& v, std::vector<uint8_t>&& buff,
|
|
std::string&& name, std::uint32_t caps=0, std::any storage = {}):
|
|
value(std::make_shared<TypeErasedValueImpl<T>>(v)),
|
|
expected_buffer(move(buff)),
|
|
name(move(name)),
|
|
caps(caps),
|
|
additional_storage(std::move(storage))
|
|
{
|
|
}
|
|
};
|
|
|
|
/// Returns the bytes of lhs followed by the bytes of rhs.
///
/// lhs is taken by rvalue so its storage can be reused for the result.
/// Declared inline because it is defined in a header: otherwise every
/// translation unit including this header emits its own definition and the
/// link fails with duplicate symbols. Appending with insert() also avoids
/// calling memcpy with a null source pointer when rhs is empty.
inline std::vector<std::uint8_t> concat(std::vector<std::uint8_t>&& lhs, const std::vector<std::uint8_t>& rhs)
{
	std::vector<std::uint8_t> res (std::move(lhs));
	res.insert(res.end(), rhs.begin(), rhs.end());
	return res;
}
|
|
|
|
// Test fixtures
|
|
/// Parameterized fixture bundling the checks that can be run against a
/// SerializeParams case. The fixtures below register (via TEST_P) the subset
/// of checks that applies to each group of cases.
struct SerializationFixture : public testing::TestWithParam<SerializeParams>
{
	// Filler byte pre-written into scratch buffers, and how many filler bytes
	// past the expected output are checked to detect overruns.
	static constexpr std::uint8_t overrun_marker = 0x7a;
	static constexpr std::size_t overrun_guard_size = 8;

	// get_size: the computed size must equal the length of the expected encoding.
	void get_size_test()
	{
		const SerializeParams& param = GetParam();
		SerializationContext ctx (param.caps, nullptr);
		auto size = param.value->get_size(ctx);
		EXPECT_EQ(size, param.expected_buffer.size());
	}

	// serialize: must write exactly the expected bytes, advance the context
	// past them, and leave the bytes after them untouched.
	void serialize_test()
	{
		const SerializeParams& param = GetParam();
		const std::size_t expected_size = param.expected_buffer.size();
		std::vector<uint8_t> buffer (expected_size + overrun_guard_size, overrun_marker);
		SerializationContext ctx (param.caps, buffer.data());
		param.value->serialize(ctx);

		// Iterator
		EXPECT_EQ(ctx.first(), buffer.data() + expected_size) << "Iterator not updated correctly";

		// Buffer
		std::string_view expected_populated = test::makesv(param.expected_buffer.data(), expected_size);
		std::string_view actual_populated = test::makesv(buffer.data(), expected_size);
		test::compare_buffers(expected_populated, actual_populated, "Buffer contents incorrect");

		// Check for buffer overruns
		std::string expected_clean (overrun_guard_size, static_cast<char>(overrun_marker));
		std::string_view actual_clean = test::makesv(buffer.data() + expected_size, overrun_guard_size);
		test::compare_buffers(expected_clean, actual_clean, "Buffer overrun");
	}

	// deserialize: decoding the expected bytes must succeed, consume exactly
	// those bytes and reproduce the original value.
	void deserialize_test()
	{
		const SerializeParams& param = GetParam();
		auto first = param.expected_buffer.data();
		auto size = param.expected_buffer.size();
		DeserializationContext ctx (first, first + size, param.caps);
		auto actual_value = param.value->default_construct();
		auto err = actual_value->deserialize(ctx);

		// No error
		EXPECT_EQ(err, Error::ok);

		// Iterator advanced
		EXPECT_EQ(ctx.first(), first + size);

		// Actual value
		EXPECT_EQ(*actual_value, *param.value);
	}

	// deserialize with one trailing byte: decoding must succeed and leave the
	// extra byte unconsumed.
	void deserialize_extra_space_test()
	{
		const SerializeParams& param = GetParam();
		std::vector<uint8_t> buffer (param.expected_buffer);
		buffer.push_back(0xff);
		auto first = buffer.data();
		DeserializationContext ctx (first, first + buffer.size(), param.caps);
		auto actual_value = param.value->default_construct();
		auto err = actual_value->deserialize(ctx);

		// No error
		EXPECT_EQ(err, Error::ok);

		// Iterator advanced
		EXPECT_EQ(ctx.first(), first + param.expected_buffer.size());

		// Actual value
		EXPECT_EQ(*actual_value, *param.value);
	}

	// deserialize with the last byte cut off: must report incomplete_message.
	void deserialize_not_enough_space_test()
	{
		const SerializeParams& param = GetParam();
		std::vector<uint8_t> buffer (param.expected_buffer);

		// An empty encoding cannot be truncated; without this guard back() and
		// the "size - 1" end pointer below would be undefined behavior.
		ASSERT_FALSE(buffer.empty()) << "deserialize_not_enough_space requires a non-empty expected_buffer";

		buffer.back() = overrun_marker; // try to detect any overruns
		DeserializationContext ctx (buffer.data(), buffer.data() + buffer.size() - 1, param.caps);
		auto actual_value = param.value->default_construct();
		auto err = actual_value->deserialize(ctx);
		EXPECT_EQ(err, Error::incomplete_message);
	}
};
|
|
|
|
// Only serialization
|
|
struct SerializeTest : SerializationFixture {};
|
|
TEST_P(SerializeTest, get_size) { get_size_test(); }
|
|
TEST_P(SerializeTest, serialize) { serialize_test(); }
|
|
|
|
// Only deserialization
|
|
struct DeserializeTest : SerializationFixture {};
|
|
TEST_P(DeserializeTest, deserialize) { deserialize_test(); }
|
|
|
|
// Deserialization + extra/insufficient space
|
|
struct DeserializeSpaceTest : SerializationFixture {};
|
|
TEST_P(DeserializeSpaceTest, deserialize) { deserialize_test(); }
|
|
TEST_P(DeserializeSpaceTest, deserialize_extra_space) { deserialize_extra_space_test(); }
|
|
TEST_P(DeserializeSpaceTest, deserialize_not_enough_space) { deserialize_not_enough_space_test(); }
|
|
|
|
// Serialization + deserialization
|
|
struct SerializeDeserializeTest : SerializationFixture {};
|
|
TEST_P(SerializeDeserializeTest, get_size) { get_size_test(); }
|
|
TEST_P(SerializeDeserializeTest, serialize) { serialize_test(); }
|
|
TEST_P(SerializeDeserializeTest, deserialize) { deserialize_test(); }
|
|
|
|
// All
|
|
struct FullSerializationTest : SerializationFixture {};
|
|
TEST_P(FullSerializationTest, get_size) { get_size_test(); }
|
|
TEST_P(FullSerializationTest, serialize) { serialize_test(); }
|
|
TEST_P(FullSerializationTest, deserialize) { deserialize_test(); }
|
|
TEST_P(FullSerializationTest, deserialize_extra_space) { deserialize_extra_space_test(); }
|
|
TEST_P(FullSerializationTest, deserialize_not_enough_space) { deserialize_not_enough_space_test(); }
|
|
|
|
|
|
// Error tests
|
|
struct DeserializeErrorParams : test::named_param
|
|
{
|
|
std::shared_ptr<TypeErasedValue> value;
|
|
std::vector<uint8_t> buffer;
|
|
std::string name;
|
|
Error expected_error;
|
|
|
|
template <typename T>
|
|
DeserializeErrorParams(
|
|
std::vector<uint8_t>&& buffer,
|
|
std::string&& test_name,
|
|
Error err = Error::incomplete_message
|
|
) :
|
|
value(std::make_shared<TypeErasedValueImpl<T>>(T{})),
|
|
buffer(std::move(buffer)),
|
|
name(std::move(name)),
|
|
expected_error(err)
|
|
{
|
|
}
|
|
};
|
|
|
|
struct DeserializeErrorTest : testing::TestWithParam<DeserializeErrorParams> {};
|
|
|
|
TEST_P(DeserializeErrorTest, Deserialize_ErrorCondition_ReturnsErrorCode)
|
|
{
|
|
auto first = GetParam().buffer.data();
|
|
auto last = GetParam().buffer.data() + GetParam().buffer.size();
|
|
DeserializationContext ctx (first, last, capabilities(0));
|
|
auto value = GetParam().value->default_construct();
|
|
auto err = value->deserialize(ctx);
|
|
EXPECT_EQ(err, GetParam().expected_error);
|
|
}
|
|
|
|
|
|
}
|
|
}
|
|
|
|
|
|
#endif /* TEST_SERIALIZATION_TEST_COMMON_HPP_ */
|