diff --git a/include/boost/openmethod/compiler.hpp b/include/boost/openmethod/compiler.hpp index 883c6e0..77ae5e7 100644 --- a/include/boost/openmethod/compiler.hpp +++ b/include/boost/openmethod/compiler.hpp @@ -96,15 +96,12 @@ struct generic_compiler { boost::dynamic_bitset<> used_slots; boost::dynamic_bitset<> reserved_slots; int first_slot = 0; - std::size_t mark = 0; // temporary mark to detect cycles - std::size_t weight = 0; // number of proper direct or indirect bases + std::size_t mark = 0; // temporary mark to detect cycles std::vector vtbl; vptr_type* static_vptr; auto is_base_of(class_* other) const -> bool { - return std::find( - transitive_derived.begin(), transitive_derived.end(), - other) != transitive_derived.end(); + return transitive_derived.find(other) != transitive_derived.end(); } auto vptr() const -> const vptr_type& { @@ -243,6 +240,7 @@ struct compiler : detail::generic_compiler { void resolve_static_type_ids(); void augment_classes(); + void collect_transitive_bases(class_* cls, class_* base); void calculate_transitive_derived(class_& cls); void augment_methods(); void assign_slots(); @@ -352,6 +350,20 @@ void compiler::resolve_static_type_ids() { } } +template +void compiler::collect_transitive_bases(class_* cls, class_* base) { + if (base->mark == class_mark) { + return; + } + + cls->transitive_bases.push_back(base); + base->mark = class_mark; + + for (auto base_base : base->transitive_bases) { + collect_transitive_bases(cls, base_base); + } +} + template void compiler::augment_classes() { using namespace detail; @@ -395,15 +407,14 @@ void compiler::augment_classes() { // map. Collect the bases. for (auto& cr : Policy::classes) { - auto& rtc = class_map[Policy::type_index(cr.type)]; + auto rtc = class_map[Policy::type_index(cr.type)]; - for (auto base_iter = cr.first_base; base_iter != cr.last_base; - ++base_iter) { - auto rtb = class_map[Policy::type_index(*base_iter)]; + for (auto& base : range{cr.first_base, cr.last_base}) { + auto rtb = class_map[Policy::type_index(base)]; if (!rtb) { unknown_class_error error; - error.type = *base_iter; + error.type = base; if constexpr (Policy::template has_facet< policies::error_handler>) { @@ -416,7 +427,8 @@ void compiler::augment_classes() { if (rtc != rtb) { // At compile time we collected the class as its own // improper base, as per std::is_base_of. Eliminate that. - rtc->transitive_bases.push_back(rtb); + ++class_mark; + collect_transitive_bases(rtc, rtb); } } } @@ -437,18 +449,17 @@ void compiler::augment_classes() { } } - // Record the "weight" of the class, i.e. the total number of direct - // and indirect proper bases it has. - rtc.weight = bases.size(); rtc.transitive_bases.swap(bases); } for (auto& rtc : classes) { - // Sort base classes by weight. This ensures that a base class is - // never preceded by one if its own bases classes. + // Sort base classes by number of transitive bases. This ensures that a + // base class is never preceded by one if its own base classes. std::sort( rtc.transitive_bases.begin(), rtc.transitive_bases.end(), - [](auto a, auto b) { return a->weight > b->weight; }); + [](auto a, auto b) { + return a->transitive_bases.size() > b->transitive_bases.size(); + }); mark = ++class_mark; // Collect the direct base classes. The first base is certainly a @@ -490,7 +501,7 @@ void compiler::augment_classes() { indent _(trace); ++trace << "bases: " << rtc.direct_bases << "\n"; ++trace << "derived: " << rtc.direct_derived << "\n"; - ++trace << "covariant: " << rtc.transitive_derived << "\n"; + ++trace << "covariant: " << rtc.transitive_derived << "\n"; } } } @@ -543,7 +554,7 @@ void compiler::augment_methods() { for (auto ti : range{meth_info.vp_begin, meth_info.vp_end}) { auto class_ = class_map[Policy::type_index(ti)]; if (!class_) { - ++trace << "unkown class " << ti << "(" << type_name(ti) + ++trace << "unknown class " << ti << "(" << type_name(ti) << ") for parameter #" << (param_index + 1) << "\n"; unknown_class_error error; error.type = ti; diff --git a/test/test_compiler.cpp b/test/test_compiler.cpp index 6fea4bb..05ca7b6 100644 --- a/test/test_compiler.cpp +++ b/test/test_compiler.cpp @@ -93,7 +93,56 @@ struct F : C, E {}; // ============================================================================ // Test use_classes. -BOOST_AUTO_TEST_CASE(test_use_classes) { +BOOST_AUTO_TEST_CASE(test_use_classes_linear) { + struct Base { + virtual ~Base() = default; + }; + + struct D1 : Base {}; + struct D2 : D1 {}; + struct D3 : D2 {}; + struct D4 : D3 {}; + struct D5 : D4 {}; + + using policy = test_policy_<__COUNTER__>; + + BOOST_OPENMETHOD_CLASSES(Base, D1, D2, D3, policy); + BOOST_OPENMETHOD_CLASSES(D2, D3, policy); + BOOST_OPENMETHOD_CLASSES(D3, D4, policy); + BOOST_OPENMETHOD_CLASSES(D4, D5, D3, policy); + + auto comp = initialize(); + + auto base = get_class(comp); + auto d1 = get_class(comp); + auto d2 = get_class(comp); + auto d3 = get_class(comp); + auto d4 = get_class(comp); + auto d5 = get_class(comp); + + BOOST_CHECK_EQUAL(sstr(base->direct_bases), empty); + BOOST_CHECK_EQUAL(sstr(base->direct_derived), sstr(d1)); + BOOST_CHECK_EQUAL( + sstr(base->transitive_derived), sstr(base, d1, d2, d3, d4, d5)); + + BOOST_CHECK_EQUAL(sstr(d1->direct_derived), sstr(d2)); + BOOST_CHECK_EQUAL(sstr(d1->direct_bases), sstr(base)); + BOOST_CHECK_EQUAL(sstr(d1->transitive_derived), sstr(d1, d2, d3, d4, d5)); + + BOOST_CHECK_EQUAL(sstr(d2->direct_derived), sstr(d3)); + BOOST_CHECK_EQUAL(sstr(d2->direct_bases), sstr(d1)); + BOOST_CHECK_EQUAL(sstr(d2->transitive_derived), sstr(d2, d3, d4, d5)); + + BOOST_CHECK_EQUAL(sstr(d3->direct_derived), sstr(d4)); + BOOST_CHECK_EQUAL(sstr(d3->direct_bases), sstr(d2)); + BOOST_CHECK_EQUAL(sstr(d3->transitive_derived), sstr(d3, d4, d5)); + + BOOST_CHECK_EQUAL(sstr(d4->direct_derived), sstr(d5)); + BOOST_CHECK_EQUAL(sstr(d4->direct_bases), sstr(d3)); + BOOST_CHECK_EQUAL(sstr(d4->transitive_derived), sstr(d4, d5)); +} + +BOOST_AUTO_TEST_CASE(test_use_classes_diamond) { using test_policy = test_policy_<__COUNTER__>; BOOST_OPENMETHOD_REGISTER(use_classes);