fix inheritance lattice deduction

This commit is contained in:
Jean-Louis Leroy
2025-05-12 16:53:05 -04:00
committed by Jean-Louis Leroy
parent e8ec2611cb
commit 40c90777aa
2 changed files with 80 additions and 20 deletions

View File

@@ -96,15 +96,12 @@ struct generic_compiler {
boost::dynamic_bitset<> used_slots;
boost::dynamic_bitset<> reserved_slots;
int first_slot = 0;
std::size_t mark = 0; // temporary mark to detect cycles
std::size_t weight = 0; // number of proper direct or indirect bases
std::size_t mark = 0; // temporary mark to detect cycles
std::vector<vtbl_entry> vtbl;
vptr_type* static_vptr;
auto is_base_of(class_* other) const -> bool {
return std::find(
transitive_derived.begin(), transitive_derived.end(),
other) != transitive_derived.end();
return transitive_derived.find(other) != transitive_derived.end();
}
auto vptr() const -> const vptr_type& {
@@ -243,6 +240,7 @@ struct compiler : detail::generic_compiler {
void resolve_static_type_ids();
void augment_classes();
void collect_transitive_bases(class_* cls, class_* base);
void calculate_transitive_derived(class_& cls);
void augment_methods();
void assign_slots();
@@ -352,6 +350,20 @@ void compiler<Policy>::resolve_static_type_ids() {
}
}
template<class Policy>
void compiler<Policy>::collect_transitive_bases(class_* cls, class_* base) {
if (base->mark == class_mark) {
return;
}
cls->transitive_bases.push_back(base);
base->mark = class_mark;
for (auto base_base : base->transitive_bases) {
collect_transitive_bases(cls, base_base);
}
}
template<class Policy>
void compiler<Policy>::augment_classes() {
using namespace detail;
@@ -395,15 +407,14 @@ void compiler<Policy>::augment_classes() {
// map. Collect the bases.
for (auto& cr : Policy::classes) {
auto& rtc = class_map[Policy::type_index(cr.type)];
auto rtc = class_map[Policy::type_index(cr.type)];
for (auto base_iter = cr.first_base; base_iter != cr.last_base;
++base_iter) {
auto rtb = class_map[Policy::type_index(*base_iter)];
for (auto& base : range{cr.first_base, cr.last_base}) {
auto rtb = class_map[Policy::type_index(base)];
if (!rtb) {
unknown_class_error error;
error.type = *base_iter;
error.type = base;
if constexpr (Policy::template has_facet<
policies::error_handler>) {
@@ -416,7 +427,8 @@ void compiler<Policy>::augment_classes() {
if (rtc != rtb) {
// At compile time we collected the class as its own
// improper base, as per std::is_base_of. Eliminate that.
rtc->transitive_bases.push_back(rtb);
++class_mark;
collect_transitive_bases(rtc, rtb);
}
}
}
@@ -437,18 +449,17 @@ void compiler<Policy>::augment_classes() {
}
}
// Record the "weight" of the class, i.e. the total number of direct
// and indirect proper bases it has.
rtc.weight = bases.size();
rtc.transitive_bases.swap(bases);
}
for (auto& rtc : classes) {
// Sort base classes by weight. This ensures that a base class is
// never preceded by one if its own bases classes.
// Sort base classes by number of transitive bases. This ensures that a
// base class is never preceded by one if its own base classes.
std::sort(
rtc.transitive_bases.begin(), rtc.transitive_bases.end(),
[](auto a, auto b) { return a->weight > b->weight; });
[](auto a, auto b) {
return a->transitive_bases.size() > b->transitive_bases.size();
});
mark = ++class_mark;
// Collect the direct base classes. The first base is certainly a
@@ -490,7 +501,7 @@ void compiler<Policy>::augment_classes() {
indent _(trace);
++trace << "bases: " << rtc.direct_bases << "\n";
++trace << "derived: " << rtc.direct_derived << "\n";
++trace << "covariant: " << rtc.transitive_derived << "\n";
++trace << "covariant: " << rtc.transitive_derived << "\n";
}
}
}
@@ -543,7 +554,7 @@ void compiler<Policy>::augment_methods() {
for (auto ti : range{meth_info.vp_begin, meth_info.vp_end}) {
auto class_ = class_map[Policy::type_index(ti)];
if (!class_) {
++trace << "unkown class " << ti << "(" << type_name(ti)
++trace << "unknown class " << ti << "(" << type_name(ti)
<< ") for parameter #" << (param_index + 1) << "\n";
unknown_class_error error;
error.type = ti;

View File

@@ -93,7 +93,56 @@ struct F : C, E {};
// ============================================================================
// Test use_classes.
BOOST_AUTO_TEST_CASE(test_use_classes) {
BOOST_AUTO_TEST_CASE(test_use_classes_linear) {
struct Base {
virtual ~Base() = default;
};
struct D1 : Base {};
struct D2 : D1 {};
struct D3 : D2 {};
struct D4 : D3 {};
struct D5 : D4 {};
using policy = test_policy_<__COUNTER__>;
BOOST_OPENMETHOD_CLASSES(Base, D1, D2, D3, policy);
BOOST_OPENMETHOD_CLASSES(D2, D3, policy);
BOOST_OPENMETHOD_CLASSES(D3, D4, policy);
BOOST_OPENMETHOD_CLASSES(D4, D5, D3, policy);
auto comp = initialize<policy>();
auto base = get_class<Base>(comp);
auto d1 = get_class<D1>(comp);
auto d2 = get_class<D2>(comp);
auto d3 = get_class<D3>(comp);
auto d4 = get_class<D4>(comp);
auto d5 = get_class<D5>(comp);
BOOST_CHECK_EQUAL(sstr(base->direct_bases), empty);
BOOST_CHECK_EQUAL(sstr(base->direct_derived), sstr(d1));
BOOST_CHECK_EQUAL(
sstr(base->transitive_derived), sstr(base, d1, d2, d3, d4, d5));
BOOST_CHECK_EQUAL(sstr(d1->direct_derived), sstr(d2));
BOOST_CHECK_EQUAL(sstr(d1->direct_bases), sstr(base));
BOOST_CHECK_EQUAL(sstr(d1->transitive_derived), sstr(d1, d2, d3, d4, d5));
BOOST_CHECK_EQUAL(sstr(d2->direct_derived), sstr(d3));
BOOST_CHECK_EQUAL(sstr(d2->direct_bases), sstr(d1));
BOOST_CHECK_EQUAL(sstr(d2->transitive_derived), sstr(d2, d3, d4, d5));
BOOST_CHECK_EQUAL(sstr(d3->direct_derived), sstr(d4));
BOOST_CHECK_EQUAL(sstr(d3->direct_bases), sstr(d2));
BOOST_CHECK_EQUAL(sstr(d3->transitive_derived), sstr(d3, d4, d5));
BOOST_CHECK_EQUAL(sstr(d4->direct_derived), sstr(d5));
BOOST_CHECK_EQUAL(sstr(d4->direct_bases), sstr(d3));
BOOST_CHECK_EQUAL(sstr(d4->transitive_derived), sstr(d4, d5));
}
BOOST_AUTO_TEST_CASE(test_use_classes_diamond) {
using test_policy = test_policy_<__COUNTER__>;
BOOST_OPENMETHOD_REGISTER(use_classes<A, B, AB, C, D, E, test_policy>);