// Tests for bp17 transform() driven by user-defined transform objects.
//
// A transform object may offer "tag" overloads (an expression-kind tag
// followed by the node's unwrapped operand values) and/or "expression"
// overloads (the whole expression node).  The *_both transforms offer both;
// their expression overloads throw, so those tests pass only if transform()
// selects the tag overload.
#include "expression.hpp"

#include <gtest/gtest.h>

#include <memory>     // std::unique_ptr (move_only test)
#include <stdexcept>  // std::logic_error (wrong-overload sentinels)


template <typename T>
using term = boost::proto17::terminal<T>;

namespace bp17 = boost::proto17;
namespace bh = boost::hana;


namespace user {

    /// Minimal arithmetic value type used as the terminal payload in these tests.
    struct number
    {
        double value;

        friend number operator+ (number lhs, number rhs)
        { return number{lhs.value + rhs.value}; }

        friend number operator- (number lhs, number rhs)
        { return number{lhs.value - rhs.value}; }

        friend number operator* (number lhs, number rhs)
        { return number{lhs.value * rhs.value}; }
    };

    /// "Not quite axpy": computes a * x + y + 10.  The extra 10 makes a
    /// rewritten expression distinguishable from plain evaluation of
    /// `a * x + y` (55 instead of 45 for the inputs used below).
    number naxpy (number a, number x, number y)
    { return number{a.value * x.value + y.value + 10.0}; }

    /// Transform with no overloads; the tests expect the expression to
    /// evaluate exactly as before transformation.
    struct empty_xform {};

    /// Tag-form transform: replaces a number terminal with its value.
    struct eval_xform_tag
    {
        decltype(auto) operator() (bp17::terminal_tag, user::number const & n)
        { return n; }
    };

    /// Expression-form transform: replaces a number terminal with its value.
    struct eval_xform_expr
    {
        decltype(auto) operator() (term<user::number> const & expr)
        { return ::boost::proto17::value(expr); }
    };

    /// Offers both forms; the test requires the tag form to be selected.
    struct eval_xform_both
    {
        decltype(auto) operator() (bp17::terminal_tag, user::number const & n)
        { return n; }

        decltype(auto) operator() (term<user::number> const & expr)
        {
            throw std::logic_error("Oops! Picked the wrong overload!");
            // Unreachable; kept so decltype(auto) deduces a value type
            // rather than void.
            return ::boost::proto17::value(expr);
        }
    };

    /// Tag-form transform: rewrites `lhs + rhs` over numbers as `lhs - rhs`.
    struct plus_to_minus_xform_tag
    {
        decltype(auto) operator() (bp17::plus_tag, user::number const & lhs, user::number const & rhs)
        {
            return bp17::make_expression<bp17::expr_kind::minus>(
                term<user::number>{lhs},
                term<user::number>{rhs}
            );
        }
    };

    /// Expression-form transform: rewrites any binary `+` node as `-`,
    /// keeping its original operands.
    struct plus_to_minus_xform_expr
    {
        template <typename Expr1, typename Expr2>
        decltype(auto) operator() (bp17::expression<bp17::expr_kind::plus, bh::tuple<Expr1, Expr2>> const & expr)
        {
            return bp17::make_expression<bp17::expr_kind::minus>(
                ::boost::proto17::left(expr),
                ::boost::proto17::right(expr)
            );
        }
    };

    /// Offers both forms; the test requires the tag form to be selected.
    struct plus_to_minus_xform_both
    {
        decltype(auto) operator() (bp17::plus_tag, user::number const & lhs, user::number const & rhs)
        {
            return bp17::make_expression<bp17::expr_kind::minus>(
                term<user::number>{lhs},
                term<user::number>{rhs}
            );
        }

        template <typename Expr1, typename Expr2>
        decltype(auto) operator() (bp17::expression<bp17::expr_kind::plus, bh::tuple<Expr1, Expr2>> const & expr)
        {
            throw std::logic_error("Oops! Picked the wrong overload!");
            // Unreachable; kept so decltype(auto) deduces a value type
            // rather than void.
            return bp17::make_expression<bp17::expr_kind::minus>(
                ::boost::proto17::left(expr),
                ::boost::proto17::right(expr)
            );
        }
    };

    /// The exact (non-template) node type of `a * x + y`, where a, x and y
    /// are named number terminals captured by reference.  Shared by the two
    /// non-template naxpy transforms below.
    using naxpy_nontemplate_expr =
        bp17::expression<
            bp17::expr_kind::plus,
            bh::tuple<
                bp17::expression<
                    bp17::expr_kind::multiplies,
                    bh::tuple<
                        bp17::expression_ref<term<user::number> &>,
                        bp17::expression_ref<term<user::number> &>
                    >
                >,
                bp17::expression_ref<term<user::number> &>
            >
        >;

    /// Eagerly evaluates the operands of `a * x + y` and returns a terminal
    /// holding naxpy(a, x, y).
    decltype(auto) naxpy_eager_nontemplate_xform (naxpy_nontemplate_expr const & expr)
    {
        auto a = evaluate(expr.left().left());
        auto x = evaluate(expr.left().right());
        auto y = evaluate(expr.right());
        return bp17::make_terminal(naxpy(a, x, y));
    }

    /// Lazily rewrites `a * x + y` into a call expression naxpy(a, x, y);
    /// the call happens only when the result is evaluated.
    decltype(auto) naxpy_lazy_nontemplate_xform (naxpy_nontemplate_expr const & expr)
    {
        decltype(auto) a = expr.left().left().value();
        decltype(auto) x = expr.left().right().value();
        decltype(auto) y = expr.right().value();
        return bp17::make_terminal(naxpy)(a, x, y);
    }

    /// Generic form: matches any `E1 * E2 + E3` and rewrites it as a lazy
    /// naxpy call, recursively transforming each operand first so nested
    /// matches are rewritten too.
    struct naxpy_xform
    {
        template <typename Expr1, typename Expr2, typename Expr3>
        decltype(auto) operator() (
            bp17::expression<
                bp17::expr_kind::plus,
                bh::tuple<
                    bp17::expression<
                        bp17::expr_kind::multiplies,
                        bh::tuple<Expr1, Expr2>
                    >,
                    Expr3
                >
            > const & expr
        ) {
            return bp17::make_terminal(naxpy)(
                transform(expr.left().left(), naxpy_xform{}),
                transform(expr.left().right(), naxpy_xform{}),
                transform(expr.right(), naxpy_xform{})
            );
        }
    };

}


TEST(user_expression_transform_3, test_user_expression_transform_3)
{
    term<user::number> a{{1.0}};
    term<user::number> x{{42.0}};
    term<user::number> y{{3.0}};

    // Single terminal: every eval transform yields the number itself.
    {
        auto expr = a;
        {
            user::number result = evaluate(expr);
            EXPECT_EQ(result.value, 1);
        }

        {
            auto transformed_expr = transform(expr, user::empty_xform{});
            user::number result = evaluate(transformed_expr);
            EXPECT_EQ(result.value, 1);
        }

        {
            auto transformed_expr = transform(expr, user::eval_xform_tag{});
            EXPECT_EQ(transformed_expr.value, 1);
        }

        {
            auto transformed_expr = transform(expr, user::eval_xform_expr{});
            EXPECT_EQ(transformed_expr.value, 1);
        }

        {
            auto transformed_expr = transform(expr, user::eval_xform_both{});
            EXPECT_EQ(transformed_expr.value, 1);
        }
    }

    // x + y (45) rewritten as x - y (39).
    {
        auto expr = x + y;
        {
            user::number result = evaluate(expr);
            EXPECT_EQ(result.value, 45);
        }

        {
            auto transformed_expr = transform(expr, user::plus_to_minus_xform_tag{});
            user::number result = evaluate(transformed_expr);
            EXPECT_EQ(result.value, 39);
        }

        {
            auto transformed_expr = transform(expr, user::plus_to_minus_xform_expr{});
            user::number result = evaluate(transformed_expr);
            EXPECT_EQ(result.value, 39);
        }

        {
            auto transformed_expr = transform(expr, user::plus_to_minus_xform_both{});
            user::number result = evaluate(transformed_expr);
            EXPECT_EQ(result.value, 39);
        }
    }

    // Eager non-template naxpy: the whole expression matches.
    {
        auto expr = a * x + y;
        {
            user::number result = evaluate(expr);
            EXPECT_EQ(result.value, 45);
        }

        auto transformed_expr = transform(expr, user::naxpy_eager_nontemplate_xform);
        {
            user::number result = evaluate(transformed_expr);
            EXPECT_EQ(result.value, 55);
        }
    }

    // Eager non-template naxpy: only the nested `a * x + y` is rewritten (46 -> 56).
    {
        auto expr = a + (a * x + y);
        {
            user::number result = evaluate(expr);
            EXPECT_EQ(result.value, 46);
        }

        auto transformed_expr = transform(expr, user::naxpy_eager_nontemplate_xform);
        {
            user::number result = evaluate(transformed_expr);
            EXPECT_EQ(result.value, 56);
        }
    }

    // Lazy non-template naxpy: the whole expression matches.
    {
        auto expr = a * x + y;
        {
            user::number result = evaluate(expr);
            EXPECT_EQ(result.value, 45);
        }

        auto transformed_expr = transform(expr, user::naxpy_lazy_nontemplate_xform);
        {
            user::number result = evaluate(transformed_expr);
            EXPECT_EQ(result.value, 55);
        }
    }

    // Lazy non-template naxpy: only the nested `a * x + y` is rewritten (46 -> 56).
    {
        auto expr = a + (a * x + y);
        {
            user::number result = evaluate(expr);
            EXPECT_EQ(result.value, 46);
        }

        auto transformed_expr = transform(expr, user::naxpy_lazy_nontemplate_xform);
        {
            user::number result = evaluate(transformed_expr);
            EXPECT_EQ(result.value, 56);
        }
    }

    // Generic naxpy: each `a * x + y` becomes 55, then the outer
    // `E1 * E2 + E3` itself matches and becomes naxpy(55, 55, 55).
    {
        auto expr = (a * x + y) * (a * x + y) + (a * x + y);
        {
            user::number result = evaluate(expr);
            EXPECT_EQ(result.value, 45 * 45 + 45);
        }

        auto transformed_expr = transform(expr, user::naxpy_xform{});
        {
            user::number result = evaluate(transformed_expr);
            EXPECT_EQ(result.value, 55 * 55 + 55 + 10);
        }
    }
}


/// Transform that replaces a double terminal with a float terminal.
auto double_to_float (term<double> expr)
{ return term<float>{static_cast<float>(expr.value())}; }

/// Transform that checks a move-only terminal still owns the value 7, then
/// hands the terminal back unchanged.
auto check_unique_ptrs_equal_7 (term<std::unique_ptr<int>> && expr)
{
    using namespace boost::hana::literals;
    EXPECT_EQ(*expr.elements[0_c], 7);
    // `expr` is a named rvalue reference (an lvalue here); in C++17 it is not
    // implicitly moved on return, and the unique_ptr payload cannot be copied.
    return std::move(expr);
}

TEST(move_only, test_user_expression_transform_3)
{
    term<double> unity{1.0};
    term<std::unique_ptr<int>> i{new int{7}};
    bp17::expression<
        bp17::expr_kind::plus,
        bh::tuple<
            bp17::expression_ref<term<double> &>,
            term<std::unique_ptr<int>>
        >
    > expr_1 = unity + std::move(i);

    bp17::expression<
        bp17::expr_kind::plus,
        bh::tuple<
            bp17::expression_ref<term<double> &>,
            bp17::expression<
                bp17::expr_kind::plus,
                bh::tuple<
                    bp17::expression_ref<term<double> &>,
                    term<std::unique_ptr<int>>
                >
            >
        >
    > expr_2 = unity + std::move(expr_1);

    // Both transforms take the tree by rvalue so the move-only terminal is
    // moved through, never copied.
    auto transformed_expr = transform(std::move(expr_2), double_to_float);
    transform(std::move(transformed_expr), check_unique_ptrs_equal_7);
}