diff --git a/boost/yap/detail/default_eval.hpp b/boost/yap/detail/default_eval.hpp index fbcdb74..6bfa1ee 100644 --- a/boost/yap/detail/default_eval.hpp +++ b/boost/yap/detail/default_eval.hpp @@ -234,7 +234,7 @@ namespace boost::yap { struct default_transform_expression< Expr, Transform, - expr_arity::two, // TODO: Add an overload for ::three! + expr_arity::two, std::void_t()( detail::tag_for::kind>(), @@ -254,6 +254,32 @@ namespace boost::yap { } }; + template + struct default_transform_expression< + Expr, + Transform, + expr_arity::three, + std::void_t()( + detail::tag_for::kind>(), + ::boost::yap::value(::boost::yap::cond(std::declval())), + ::boost::yap::value(::boost::yap::then(std::declval())), + ::boost::yap::value(::boost::yap::else_(std::declval())) + ) + )> + > + { + decltype(auto) operator() (Expr && expr, Transform && transform) + { + return static_cast(transform)( + detail::tag_for::kind>(), + ::boost::yap::value(::boost::yap::cond(static_cast(expr))), + ::boost::yap::value(::boost::yap::then(static_cast(expr))), + ::boost::yap::value(::boost::yap::else_(static_cast(expr))) + ); + } + }; + template struct transform_call_unpacker { diff --git a/boost/yap/expression.hpp b/boost/yap/expression.hpp index 7cf0b93..0545391 100644 --- a/boost/yap/expression.hpp +++ b/boost/yap/expression.hpp @@ -324,9 +324,10 @@ namespace boost::yap { if constexpr (detail::is_expr::value) { using namespace hana::literals; constexpr expr_kind kind = detail::remove_cv_ref_t::kind; + constexpr detail::expr_arity arity = detail::arity_of(); if constexpr (kind == expr_kind::expr_ref) { return ::boost::yap::value(::boost::yap::deref(static_cast(x))); - } else if constexpr (kind == expr_kind::terminal || kind == expr_kind::placeholder) { + } else if constexpr (arity == detail::expr_arity::one) { if constexpr (std::is_lvalue_reference{}) { return x.elements[0_c]; } else { @@ -340,69 +341,112 @@ namespace boost::yap { } } - template - decltype(auto) left (Expr && expr) + template + decltype(auto) get (Expr && expr, hana::llong i) { static_assert( detail::is_expr::value, - "left() is only defined for expressions." + "get() is only defined for expressions." ); using namespace hana::literals; constexpr expr_kind kind = detail::remove_cv_ref_t::kind; if constexpr (kind == expr_kind::expr_ref) { - return ::boost::yap::left(::boost::yap::deref(static_cast(expr))); + return ::boost::yap::get(::boost::yap::deref(static_cast(expr)), i); } else { + static_assert( + 0 <= I && I < decltype(hana::size(expr.elements))::value, + "In get(expr, I), I must be nonnegative, and less " + "than hana::size(expr.elements)." + ); + if constexpr (std::is_lvalue_reference{}) { + return expr.elements[i]; + } else { + return std::move(expr.elements[i]); + } + } + } + + template + decltype(auto) get_c (Expr && expr) + { return ::boost::yap::get(static_cast(expr), hana::llong_c); } + + template + decltype(auto) left (Expr && expr) + { + using namespace hana::literals; + return ::boost::yap::get(static_cast(expr), 0_c); + constexpr expr_kind kind = detail::remove_cv_ref_t::kind; + if constexpr (kind != expr_kind::expr_ref) { static_assert( detail::arity_of() == detail::expr_arity::two, "left() is only defined for binary expressions." ); - if constexpr (std::is_lvalue_reference{}) { - return expr.elements[0_c]; - } else { - return std::move(expr.elements[0_c]); - } } } template decltype(auto) right (Expr && expr) { - static_assert( - detail::is_expr::value, - "right() is only defined for expressions." - ); - using namespace hana::literals; + return ::boost::yap::get(static_cast(expr), 1_c); constexpr expr_kind kind = detail::remove_cv_ref_t::kind; - if constexpr (kind == expr_kind::expr_ref) { - return ::boost::yap::right(::boost::yap::deref(static_cast(expr))); - } else { + if constexpr (kind != expr_kind::expr_ref) { static_assert( detail::arity_of() == detail::expr_arity::two, "right() is only defined for binary expressions." ); - if constexpr (std::is_lvalue_reference{}) { - return expr.elements[1_c]; - } else { - return std::move(expr.elements[1_c]); - } + } + } + + template + decltype(auto) cond (Expr && expr) + { + using namespace hana::literals; + return ::boost::yap::get(static_cast(expr), 0_c); + constexpr expr_kind kind = detail::remove_cv_ref_t::kind; + if constexpr (kind != expr_kind::expr_ref) { + static_assert( + kind == expr_kind::if_else, + "cond() is only defined for if_else expressions." + ); + } + } + + template + decltype(auto) then (Expr && expr) + { + using namespace hana::literals; + return ::boost::yap::get(static_cast(expr), 1_c); + constexpr expr_kind kind = detail::remove_cv_ref_t::kind; + if constexpr (kind != expr_kind::expr_ref) { + static_assert( + kind == expr_kind::if_else, + "then() is only defined for if_else expressions." + ); + } + } + + template + decltype(auto) else_ (Expr && expr) + { + using namespace hana::literals; + return ::boost::yap::get(static_cast(expr), 2_c); + constexpr expr_kind kind = detail::remove_cv_ref_t::kind; + if constexpr (kind != expr_kind::expr_ref) { + static_assert( + kind == expr_kind::if_else, + "else_() is only defined for if_else expressions." + ); } } template decltype(auto) argument (Expr && expr, hana::llong i) { - static_assert( - detail::is_expr::value, - "argument() is only defined for expressions." - ); - - using namespace hana::literals; + return ::boost::yap::get(static_cast(expr), i); constexpr expr_kind kind = detail::remove_cv_ref_t::kind; - if constexpr (kind == expr_kind::expr_ref) { - return ::boost::yap::argument(::boost::yap::deref(static_cast(expr)), i); - } else { + if constexpr (kind != expr_kind::expr_ref) { static_assert( detail::arity_of() == detail::expr_arity::n, "argument() is only defined for call expressions." @@ -412,11 +456,6 @@ namespace boost::yap { "In argument(expr, I), I must be nonnegative, and less " "than hana::size(expr.elements)." ); - if constexpr (std::is_lvalue_reference{}) { - return expr.elements[i]; - } else { - return std::move(expr.elements[i]); - } } } diff --git a/test/user_expression_transform_3.cpp b/test/user_expression_transform_3.cpp index 9e5bd19..b09a79a 100644 --- a/test/user_expression_transform_3.cpp +++ b/test/user_expression_transform_3.cpp @@ -24,6 +24,15 @@ namespace user { friend number operator* (number lhs, number rhs) { return number{lhs.value * rhs.value}; } + + friend number operator- (number n) + { return number{-n.value}; } + + friend bool operator< (number lhs, double rhs) + { return lhs.value < rhs; } + + friend bool operator< (double lhs, number rhs) + { return lhs < rhs.value; } }; number naxpy (number a, number x, number y) @@ -167,6 +176,74 @@ namespace user { } }; + + // unary transforms + + struct disable_negate_xform_tag + { + decltype(auto) operator() (yap::negate_tag, user::number const & value) + { return yap::make_terminal(value); } + + template + decltype(auto) operator() (yap::negate_tag, Expr const & expr) + { return expr; } + }; + + struct disable_negate_xform_expr + { + template + decltype(auto) operator() (yap::expression> const & expr) + { return ::boost::yap::value(expr); } + }; + + struct disable_negate_xform_both + { + decltype(auto) operator() (yap::negate_tag, user::number const & value) + { return yap::make_terminal(value); } + + template + decltype(auto) operator() (yap::negate_tag, Expr const & expr) + { return expr; } + + template + decltype(auto) operator() (yap::expression> const & expr) + { + throw std::logic_error("Oops! Picked the wrong overload!"); + return ::boost::yap::value(expr); + } + }; + + + // ternary transforms + + struct ternary_to_else_xform_tag + { + template + decltype(auto) operator() (yap::if_else_tag, Expr const & cond, user::number const & then, user::number const & else_) + { return yap::make_terminal(else_); } + }; + + struct ternary_to_else_xform_expr + { + template + decltype(auto) operator() (yap::expression> const & expr) + { return ::boost::yap::else_(expr); } + }; + + struct ternary_to_else_xform_both + { + template + decltype(auto) operator() (yap::if_else_tag, Expr const & cond, user::number const & then, user::number const & else_) + { return yap::make_terminal(else_); } + + template + decltype(auto) operator() (yap::expression> const & expr) + { + throw std::logic_error("Oops! Picked the wrong overload!"); + return ::boost::yap::else_(expr); + } + }; + } TEST(user_expression_transform_3, test_user_expression_transform_3) @@ -341,3 +418,147 @@ TEST(move_only, test_user_expression_transform_3) transform(std::move(transformed_expr), check_unique_ptrs_equal_7); } + +TEST(unary_transforms, test_user_expression_transform_3) +{ + term a{{1.0}}; + term x{{42.0}}; + term y{{3.0}}; + + { + auto expr = -x; + { + user::number result = evaluate(expr); + EXPECT_EQ(result.value, -42); + } + + { + auto transformed_expr = transform(expr, user::disable_negate_xform_tag{}); + user::number result = evaluate(transformed_expr); + EXPECT_EQ(result.value, 42); + } + + { + auto transformed_expr = transform(expr, user::disable_negate_xform_expr{}); + user::number result = evaluate(transformed_expr); + EXPECT_EQ(result.value, 42); + } + + { + auto transformed_expr = transform(expr, user::disable_negate_xform_both{}); + user::number result = evaluate(transformed_expr); + EXPECT_EQ(result.value, 42); + } + } + + { + auto expr = a * -x + y; + { + user::number result = evaluate(expr); + EXPECT_EQ(result.value, -39); + } + + { + auto transformed_expr = transform(expr, user::disable_negate_xform_tag{}); + user::number result = evaluate(transformed_expr); + EXPECT_EQ(result.value, 45); + } + + { + auto transformed_expr = transform(expr, user::disable_negate_xform_expr{}); + user::number result = evaluate(transformed_expr); + EXPECT_EQ(result.value, 45); + } + + { + auto transformed_expr = transform(expr, user::disable_negate_xform_both{}); + user::number result = evaluate(transformed_expr); + EXPECT_EQ(result.value, 45); + } + } + + { + auto expr = -(x + y); + { + user::number result = evaluate(expr); + EXPECT_EQ(result.value, -45); + } + + { + auto transformed_expr = transform(expr, user::disable_negate_xform_tag{}); + user::number result = evaluate(transformed_expr); + EXPECT_EQ(result.value, 45); + } + + { + auto transformed_expr = transform(expr, user::disable_negate_xform_expr{}); + user::number result = evaluate(transformed_expr); + EXPECT_EQ(result.value, 45); + } + + { + auto transformed_expr = transform(expr, user::disable_negate_xform_both{}); + user::number result = evaluate(transformed_expr); + EXPECT_EQ(result.value, 45); + } + } +} + +TEST(ternary_transforms, test_user_expression_transform_3) +{ + term a{{1.0}}; + term x{{42.0}}; + term y{{3.0}}; + + { + auto expr = if_else(0 < a, x, y); + { + user::number result = evaluate(expr); + EXPECT_EQ(result.value, 42); + } + + { + auto transformed_expr = transform(expr, user::ternary_to_else_xform_tag{}); + user::number result = evaluate(transformed_expr); + EXPECT_EQ(result.value, 3); + } + + { + auto transformed_expr = transform(expr, user::ternary_to_else_xform_expr{}); + user::number result = evaluate(transformed_expr); + EXPECT_EQ(result.value, 3); + } + + { + auto transformed_expr = transform(expr, user::ternary_to_else_xform_both{}); + user::number result = evaluate(transformed_expr); + EXPECT_EQ(result.value, 3); + } + } + + { + auto expr = y * if_else(0 < a, x, y) + user::number{0.0}; + { + user::number result = evaluate(expr); + EXPECT_EQ(result.value, 126); + } + + { + auto transformed_expr = transform(expr, user::ternary_to_else_xform_tag{}); + user::number result = evaluate(transformed_expr); + EXPECT_EQ(result.value, 9); + } + + { + auto transformed_expr = transform(expr, user::ternary_to_else_xform_expr{}); + user::number result = evaluate(transformed_expr); + EXPECT_EQ(result.value, 9); + } + + { + auto transformed_expr = transform(expr, user::ternary_to_else_xform_both{}); + user::number result = evaluate(transformed_expr); + EXPECT_EQ(result.value, 9); + } + } +}