diff --git a/detail/default_eval.hpp b/detail/default_eval.hpp index 93aaced..2e8df57 100644 --- a/detail/default_eval.hpp +++ b/detail/default_eval.hpp @@ -156,6 +156,39 @@ namespace boost::proto17 { } } + template > + struct default_transform_expression + { + auto operator() (Expr && expr, Transform && transform) + { + constexpr expr_kind kind = kind_of>::value; + if constexpr (kind == expr_kind::terminal || kind == expr_kind::placeholder) { + return expr; + } else { + auto tuple = hana::transform( + expr.elements, + [&transform](auto && element) { + default_transform_expression transformer; + return transformer(element, static_cast(transform)); + } + ); + using return_type = typename expression_from_tuple::type; + return return_type(std::move(tuple)); + } + } + }; + + template + struct default_transform_expression< + Expr, + Transform, + std::void_t()(std::declval()))> + > + { + auto operator() (Expr && expr, Transform && transform) + { return static_cast(transform)(static_cast(expr)); } + }; + } } diff --git a/detail/expression.hpp b/detail/expression.hpp index b4afc0a..991ba61 100644 --- a/detail/expression.hpp +++ b/detail/expression.hpp @@ -54,15 +54,22 @@ namespace boost::proto17 { struct is_expr> { static bool const value = true; }; + template + struct remove_cv_ref : std::remove_cv> + {}; + + template + using remove_cv_ref_t = typename remove_cv_ref::type; + template ::type, bool RemoveRefs = std::is_rvalue_reference_v, - bool IsExpr = is_expr>::value> + bool IsExpr = is_expr>::value> struct operand_type; template struct operand_type - { using type = std::remove_cv_t>; }; + { using type = remove_cv_ref_t; }; template struct operand_type @@ -72,17 +79,17 @@ namespace boost::proto17 { struct operand_type { using type = terminal; }; - template - struct call_expression_from_tuple; + template + struct expression_from_tuple; - template - struct call_expression_from_tuple> - { using type = expression; }; + template + struct expression_from_tuple> + { using type = expression; }; template constexpr auto make_call_expression (T && ...args) { - return typename call_expression_from_tuple::type{ + return typename expression_from_tuple::type{ Tuple{static_cast(args)...} }; } diff --git a/expression.hpp b/expression.hpp index 1e2eecf..49da47d 100644 --- a/expression.hpp +++ b/expression.hpp @@ -14,8 +14,7 @@ namespace boost::proto17 { namespace adl_detail { template - constexpr decltype(auto) eval_expression_as (E const & expr, hana::basic_type, T &&... args) - { return static_cast(detail::default_eval_expr(expr, static_cast(args)...)); } + constexpr decltype(auto) eval_expression_as (E const & expr, hana::basic_type, T &&... args); struct eval_expression_as_fn { @@ -34,15 +33,14 @@ namespace boost::proto17 { } - // TODO: static assert/SFINAE sizeof...(T) >= highest-indexed placeholder + 1 template - decltype(auto) evaluate (Expr const & expr, T && ...t) - { return detail::default_eval_expr(expr, static_cast(t)...); } + decltype(auto) evaluate (Expr const & expr, T && ...t); - // TODO: static assert/SFINAE sizeof...(T) >= highest-indexed placeholder + 1 template - decltype(auto) evaluate_as (Expr const & expr, T && ...t) - { return eval_expression_as(expr, hana::basic_type{}, static_cast(t)...); } + decltype(auto) evaluate_as (Expr const & expr, T && ...t); + + template + auto transform (Expr && expr, Transform && transform); template struct expression @@ -56,20 +54,10 @@ namespace boost::proto17 { elements (static_cast(t)...) {} - expression (hana::tuple const & rhs) : - elements (rhs) - {} - expression (hana::tuple && rhs) : elements (std::move(rhs)) {} - expression & operator= (hana::tuple const & rhs) - { elements = rhs.elements; } - - expression & operator= (hana::tuple && rhs) - { elements = std::move(rhs.elements); } - tuple_type elements; #ifdef BOOST_PROTO17_CONVERSION_OPERATOR_TEMPLATE @@ -228,7 +216,7 @@ namespace boost::proto17 { namespace detail { template >::value> + bool Expr = detail::is_expr>::value> struct binary_op_result { using lhs_type = typename detail::operand_type::type; @@ -285,6 +273,14 @@ namespace boost::proto17 { #undef BOOST_PROTO17_BINARY_NON_MEMBER_OPERATOR + template + auto make_expression (T &&... t) + { + return expression{ + hana::tuple{static_cast(t)...} + }; + } + template auto make_terminal (T && t) { @@ -297,4 +293,35 @@ namespace boost::proto17 { #include "detail/default_eval.hpp" +namespace boost::proto17 { + + // TODO: static assert/SFINAE sizeof...(T) >= highest-indexed placeholder + 1 + template + decltype(auto) evaluate (Expr const & expr, T && ...t) + { return detail::default_eval_expr(expr, static_cast(t)...); } + + // TODO: static assert/SFINAE sizeof...(T) >= highest-indexed placeholder + 1 + template + decltype(auto) evaluate_as (Expr const & expr, T && ...t) + { return eval_expression_as(expr, hana::basic_type{}, static_cast(t)...); } + + template + auto transform (Expr && expr, Transform && transform) + { + return detail::default_transform_expression{}( + static_cast(expr), + static_cast(transform) + ); + } + + namespace adl_detail { + + template + constexpr decltype(auto) eval_expression_as (E const & expr, hana::basic_type, T &&... args) + { return static_cast(detail::default_eval_expr(expr, static_cast(args)...)); } + + } + +} + #endif diff --git a/expression_fwd.hpp b/expression_fwd.hpp index 8324f0f..24d72e9 100644 --- a/expression_fwd.hpp +++ b/expression_fwd.hpp @@ -80,13 +80,6 @@ namespace boost::proto17 { } - namespace detail { - - template - decltype(auto) default_eval_expr (Expr const & expr, T &&... args); - - } - } #endif diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 0d21ec6..4e34343 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -23,6 +23,7 @@ add_test_executable(user_eval_expression_as) add_test_executable(user_operator_and_eval_expression_as) add_test_executable(user_expression_transform) add_test_executable(user_expression_transform_2) +add_test_executable(user_expression_transform_3) add_test_executable(placeholder_eval) add_test_executable(call_expr) add_test_executable(reference_returns) diff --git a/test/user_expression_transform_3.cpp b/test/user_expression_transform_3.cpp new file mode 100644 index 0000000..c6250b9 --- /dev/null +++ b/test/user_expression_transform_3.cpp @@ -0,0 +1,213 @@ +#include "expression.hpp" + +#include + +#include + + +template +using term = boost::proto17::terminal; + +namespace bp17 = boost::proto17; + + +namespace user { + + struct number + { + double value; + + friend number operator+ (number lhs, number rhs) + { return number{lhs.value + rhs.value}; } + + friend number operator- (number lhs, number rhs) + { return number{lhs.value - rhs.value}; } + + friend number operator* (number lhs, number rhs) + { return number{lhs.value * rhs.value}; } + + friend std::ostream & operator<< (std::ostream & os, number x) // TODO + { return os << x.value; } + }; + + number naxpy (number a, number x, number y) + { return number{a.value * x.value + y.value + 10.0}; } + + struct eval_xform + { + decltype(auto) operator() (term const & expr) + { return expr.value(); } + }; + + struct plus_to_minus_xform + { + template + decltype(auto) operator() (bp17::expression const & expr) + { return bp17::make_expression(expr.left(), expr.right()); } + }; + + decltype(auto) naxpy_eager_nontemplate_xform ( + bp17::expression< + bp17::expr_kind::plus, + bp17::expression< + bp17::expr_kind::multiplies, + term, + term + >, + term + > const & expr + ) { + auto a = evaluate(expr.left().left()); + auto x = evaluate(expr.left().right()); + auto y = evaluate(expr.right()); + return bp17::make_terminal(naxpy(a, x, y)); + } + + decltype(auto) naxpy_lazy_nontemplate_xform ( + bp17::expression< + bp17::expr_kind::plus, + bp17::expression< + bp17::expr_kind::multiplies, + term, + term + >, + term + > const & expr + ) { + auto a = expr.left().left(); + auto x = expr.left().right(); + auto y = expr.right(); + return bp17::make_terminal(naxpy)(a, x, y); + } + + struct naxpy_xform + { + template + decltype(auto) operator() ( + bp17::expression< + bp17::expr_kind::plus, + bp17::expression< + bp17::expr_kind::multiplies, + Expr1, + Expr2 + >, + Expr3 + > const & expr + ) { + auto a = transform(expr.left().left(), naxpy_xform{}); + auto x = transform(expr.left().right(), naxpy_xform{}); + auto y = transform(expr.right(), naxpy_xform{}); + return bp17::make_terminal(naxpy)(a, x, y); + } + }; + +} + +// TODO: Add a unit test for moved expressions and expressions containing +// move-only types or rvalue refs. + +// TODO: Use move-only type in other tests that test moves as well. + +TEST(user_expression_transform_3, test_user_expression_transform_3) +{ + term a{{1.0}}; + term x{{42.0}}; + term y{{3.0}}; + + { + auto expr = a; + { + user::number result = evaluate(expr); + EXPECT_EQ(result.value, 1); + } + + auto transformed_expr = transform(expr, user::eval_xform{}); + { + EXPECT_EQ(transformed_expr.value, 1); + } + } + + { + auto expr = x + y; + { + user::number result = evaluate(expr); + EXPECT_EQ(result.value, 45); + } + + auto transformed_expr = transform(expr, user::plus_to_minus_xform{}); + { + user::number result = evaluate(transformed_expr); + EXPECT_EQ(result.value, 39); + } + } + + { + auto expr = a * x + y; + { + user::number result = evaluate(expr); + EXPECT_EQ(result.value, 45); + } + + auto transformed_expr = transform(expr, user::naxpy_eager_nontemplate_xform); + { + user::number result = evaluate(transformed_expr); + EXPECT_EQ(result.value, 55); + } + } + + { + auto expr = a + (a * x + y); + { + user::number result = evaluate(expr); + EXPECT_EQ(result.value, 46); + } + + auto transformed_expr = transform(expr, user::naxpy_eager_nontemplate_xform); + { + user::number result = evaluate(transformed_expr); + EXPECT_EQ(result.value, 56); + } + } + + { + auto expr = a * x + y; + { + user::number result = evaluate(expr); + EXPECT_EQ(result.value, 45); + } + + auto transformed_expr = transform(expr, user::naxpy_lazy_nontemplate_xform); + { + user::number result = evaluate(transformed_expr); + EXPECT_EQ(result.value, 55); + } + } + + { + auto expr = a + (a * x + y); + { + user::number result = evaluate(expr); + EXPECT_EQ(result.value, 46); + } + + auto transformed_expr = transform(expr, user::naxpy_lazy_nontemplate_xform); + { + user::number result = evaluate(transformed_expr); + EXPECT_EQ(result.value, 56); + } + } + + { + auto expr = (a * x + y) * (a * x + y) + (a * x + y); + { + user::number result = evaluate(expr); + EXPECT_EQ(result.value, 45 * 45 + 45); + } + + auto transformed_expr = transform(expr, user::naxpy_xform{}); + { + user::number result = evaluate(transformed_expr); + EXPECT_EQ(result.value, 55 * 55 + 55 + 10); + } + } +}