From a4ea007d595da2fc5d20ea8e53105bbd9e164f22 Mon Sep 17 00:00:00 2001 From: Zach Laine Date: Tue, 29 Nov 2016 17:36:17 -0600 Subject: [PATCH] Add support for alternate transform function syntax, a la Boost.Proto, using the new tag types. --- detail/default_eval.hpp | 107 +++++++++++++++++++-- test/call_expr.cpp | 135 +++++++++++++++++++++++++++ test/user_expression_transform_3.cpp | 99 ++++++++++++++++++-- 3 files changed, 327 insertions(+), 14 deletions(-) diff --git a/detail/default_eval.hpp b/detail/default_eval.hpp index 22f9cf7..74284f8 100644 --- a/detail/default_eval.hpp +++ b/detail/default_eval.hpp @@ -164,11 +164,8 @@ namespace boost::proto17 { constexpr expr_kind kind = remove_cv_ref_t::kind; if constexpr (kind == expr_kind::expr_ref) { decltype(auto) ref = ::boost::proto17::value(expr); - default_transform_expression< - decltype(ref), - Transform, - detail::arity_of() - > transformer; + constexpr expr_kind kind = remove_cv_ref_t::kind; + default_transform_expression()> transformer; return transformer(ref, static_cast(transform)); } else if constexpr (kind == expr_kind::terminal || kind == expr_kind::placeholder) { return static_cast(expr); @@ -198,10 +195,105 @@ namespace boost::proto17 { std::void_t()(std::declval()))> > { - auto operator() (Expr && expr, Transform && transform) + decltype(auto) operator() (Expr && expr, Transform && transform) { return static_cast(transform)(static_cast(expr)); } }; + template + struct default_transform_expression< + Expr, + Transform, + expr_arity::one, + std::void_t()( + detail::tag_for::kind>(), + deref(::boost::proto17::value(std::declval())) + ) + )> + > + { + decltype(auto) operator() (Expr && expr, Transform && transform) + { + return static_cast(transform)( + detail::tag_for::kind>(), + deref(::boost::proto17::value(static_cast(expr))) + ); + } + }; + + template + struct default_transform_expression< + Expr, + Transform, + expr_arity::two, + std::void_t()( + detail::tag_for::kind>(), + deref(::boost::proto17::left(std::declval())), + deref(::boost::proto17::right(std::declval())) + ) + )> + > + { + decltype(auto) operator() (Expr && expr, Transform && transform) + { + return static_cast(transform)( + detail::tag_for::kind>(), + deref(::boost::proto17::left(static_cast(expr))), + deref(::boost::proto17::right(static_cast(expr))) + ); + } + }; + + template + struct transform_call_unpacker + { + template + auto operator() (Tuple && tuple, TransformT && transform, std::integer_sequence) + -> decltype( + static_cast(transform)( + detail::tag_for(), + static_cast(tuple)[hana::llong_c]... + ) + ) + { + return static_cast(transform)( + detail::tag_for(), + static_cast(tuple)[hana::llong_c]...); + } + }; + + template + constexpr auto indices_for (Expr const & expr) + { + constexpr long long size = decltype(hana::size(expr))::value; + return std::make_integer_sequence(); + } + + template + struct default_transform_expression< + Expr, + Transform, + expr_arity::n, + std::void_t::kind>{}( + std::declval().elements, + std::declval(), + indices_for() + ) + )> + > + { + decltype(auto) operator() (Expr && expr, Transform && transform) + { + return transform_call_unpacker::kind>{}( + static_cast(expr).elements, + static_cast(transform), + indices_for() + ); + } + }; + template < template class ExprTemplate, expr_kind Kind, @@ -218,7 +310,8 @@ namespace boost::proto17 { static_cast(tuple), [&transform](auto && element) { using element_t = decltype(element); - default_transform_expression()> transformer; + constexpr expr_kind kind = remove_cv_ref_t::kind; + default_transform_expression()> transformer; return transformer( static_cast(element), static_cast(transform) diff --git a/test/call_expr.cpp b/test/call_expr.cpp index 060def5..ad83496 100644 --- a/test/call_expr.cpp +++ b/test/call_expr.cpp @@ -33,6 +33,94 @@ namespace user { inline number tag_function (double a, double b) { return number{a + b}; } + struct eval_xform_tag + { + decltype(auto) operator() (bp17::call_tag, tag_type, double a, double b) + { return tag_function(a, b); } + }; + + struct empty_xform {}; + + struct eval_xform_expr + { + decltype(auto) operator() ( + bp17::expression< + bp17::expr_kind::call, + bh::tuple< + bp17::expression_ref >, + term, + term + > + > const & expr + ) { + using namespace boost::hana::literals; + return tag_function( + (double)bp17::value(expr.elements[1_c]).value, + (double)bp17::value(expr.elements[2_c]) + ); + } + + decltype(auto) operator() ( + bp17::expression< + bp17::expr_kind::call, + bh::tuple< + bp17::expression_ref >, + bp17::expression_ref>, + term + > + > const & expr + ) { + using namespace boost::hana::literals; + return tag_function( + (double)bp17::deref(expr.elements[1_c]).value, + (double)bp17::value(expr.elements[2_c]) + ); + } + }; + + struct eval_xform_both + { + decltype(auto) operator() (bp17::call_tag, tag_type, double a, double b) + { + throw std::logic_error("Oops! Picked the wrong overload!"); + return tag_function(a, b); + } + + decltype(auto) operator() ( + bp17::expression< + bp17::expr_kind::call, + bh::tuple< + bp17::expression_ref >, + term, + term + > + > const & expr + ) { + using namespace boost::hana::literals; + return tag_function( + (double)bp17::value(expr.elements[1_c]).value, + (double)bp17::value(expr.elements[2_c]) + ); + } + + decltype(auto) operator() ( + bp17::expression< + bp17::expr_kind::call, + bh::tuple< + bp17::expression_ref >, + bp17::expression_ref>, + term + > + > const & expr + ) { + using namespace boost::hana::literals; + return tag_function( + (double)bp17::deref(expr.elements[1_c]).value, + (double)bp17::value(expr.elements[2_c]) + ); + } + }; + template inline auto eval_call (tag_type, T && ...t) { @@ -211,6 +299,53 @@ TEST(call_expr, test_call_expr) EXPECT_EQ(result.value, 1); } } + + { + auto plus = bp17::make_terminal(user::tag_type{}); + auto expr = plus(user::number{13}, 1); + + { + auto transformed_expr = transform(expr, user::empty_xform{}); + user::number result = transformed_expr; + EXPECT_EQ(result.value, 14); + } + + { + user::number result = transform(expr, user::eval_xform_tag{}); + EXPECT_EQ(result.value, 14); + } + + { + user::number result = transform(expr, user::eval_xform_expr{}); + EXPECT_EQ(result.value, 14); + } + + { + user::number result = transform(expr, user::eval_xform_both{}); + EXPECT_EQ(result.value, 14); + } + } + + { + auto plus = bp17::make_terminal(user::tag_type{}); + auto thirteen = bp17::make_terminal(user::number{13}); + auto expr = plus(thirteen, 1); + + { + user::number result = transform(expr, user::eval_xform_tag{}); + EXPECT_EQ(result.value, 14); + } + + { + user::number result = transform(expr, user::eval_xform_expr{}); + EXPECT_EQ(result.value, 14); + } + + { + user::number result = transform(expr, user::eval_xform_both{}); + EXPECT_EQ(result.value, 14); + } + } } { diff --git a/test/user_expression_transform_3.cpp b/test/user_expression_transform_3.cpp index c207eb8..732a42c 100644 --- a/test/user_expression_transform_3.cpp +++ b/test/user_expression_transform_3.cpp @@ -29,17 +29,74 @@ namespace user { number naxpy (number a, number x, number y) { return number{a.value * x.value + y.value + 10.0}; } - struct eval_xform + struct empty_xform {}; + + struct eval_xform_tag { - decltype(auto) operator() (term const & expr) - { return expr.value(); } + decltype(auto) operator() (bp17::terminal_tag, user::number const & n) + { return n; } }; - struct plus_to_minus_xform + struct eval_xform_expr + { + decltype(auto) operator() (term const & expr) + { return ::boost::proto17::value(expr); } + }; + + struct eval_xform_both + { + decltype(auto) operator() (bp17::terminal_tag, user::number const & n) + { return n; } + + decltype(auto) operator() (term const & expr) + { + throw std::logic_error("Oops! Picked the wrong overload!"); + return ::boost::proto17::value(expr); + } + }; + + struct plus_to_minus_xform_tag + { + decltype(auto) operator() (bp17::plus_tag, user::number const & lhs, user::number const & rhs) + { + return bp17::make_expression( + term{lhs}, + term{rhs} + ); + } + }; + + struct plus_to_minus_xform_expr { template decltype(auto) operator() (bp17::expression> const & expr) - { return bp17::make_expression(expr.left(), expr.right()); } + { + return bp17::make_expression( + ::boost::proto17::left(expr), + ::boost::proto17::right(expr) + ); + } + }; + + struct plus_to_minus_xform_both + { + decltype(auto) operator() (bp17::plus_tag, user::number const & lhs, user::number const & rhs) + { + return bp17::make_expression( + term{lhs}, + term{rhs} + ); + } + + template + decltype(auto) operator() (bp17::expression> const & expr) + { + throw std::logic_error("Oops! Picked the wrong overload!"); + return bp17::make_expression( + ::boost::proto17::left(expr), + ::boost::proto17::right(expr) + ); + } }; decltype(auto) naxpy_eager_nontemplate_xform ( @@ -125,8 +182,24 @@ TEST(user_expression_transform_3, test_user_expression_transform_3) EXPECT_EQ(result.value, 1); } - auto transformed_expr = transform(expr, user::eval_xform{}); { + auto transformed_expr = transform(expr, user::empty_xform{}); + user::number result = evaluate(transformed_expr); + EXPECT_EQ(result.value, 1); + } + + { + auto transformed_expr = transform(expr, user::eval_xform_tag{}); + EXPECT_EQ(transformed_expr.value, 1); + } + + { + auto transformed_expr = transform(expr, user::eval_xform_expr{}); + EXPECT_EQ(transformed_expr.value, 1); + } + + { + auto transformed_expr = transform(expr, user::eval_xform_both{}); EXPECT_EQ(transformed_expr.value, 1); } } @@ -138,8 +211,20 @@ TEST(user_expression_transform_3, test_user_expression_transform_3) EXPECT_EQ(result.value, 45); } - auto transformed_expr = transform(expr, user::plus_to_minus_xform{}); { + auto transformed_expr = transform(expr, user::plus_to_minus_xform_tag{}); + user::number result = evaluate(transformed_expr); + EXPECT_EQ(result.value, 39); + } + + { + auto transformed_expr = transform(expr, user::plus_to_minus_xform_expr{}); + user::number result = evaluate(transformed_expr); + EXPECT_EQ(result.value, 39); + } + + { + auto transformed_expr = transform(expr, user::plus_to_minus_xform_both{}); user::number result = evaluate(transformed_expr); EXPECT_EQ(result.value, 39); }