From b53cfc25c9e841c41fa7000ea9fa5361eb211427 Mon Sep 17 00:00:00 2001 From: mzhelyez Date: Thu, 21 Aug 2025 11:26:32 +0200 Subject: [PATCH] modified template variable names, changed copies to references --- .../math/differentiation/autodiff_reverse.hpp | 368 +++++----- ...mode_autodiff_basic_operator_overloads.hpp | 320 ++++---- ...autodiff_comparison_operator_overloads.hpp | 176 ++--- .../reverse_mode_autodiff_erf_overloads.hpp | 143 ++-- ...mode_autodiff_expression_template_base.hpp | 158 ++-- ...everse_mode_autodiff_memory_management.hpp | 62 +- .../reverse_mode_autodiff_stl_overloads.hpp | 684 ++++++++++-------- 7 files changed, 1049 insertions(+), 862 deletions(-) diff --git a/include/boost/math/differentiation/autodiff_reverse.hpp b/include/boost/math/differentiation/autodiff_reverse.hpp index 62405764f..b64988104 100644 --- a/include/boost/math/differentiation/autodiff_reverse.hpp +++ b/include/boost/math/differentiation/autodiff_reverse.hpp @@ -35,40 +35,46 @@ namespace differentiation { namespace reverse_mode { /* forward declarations for utitlity functions */ -template +template struct expression; -template +template class rvar; -template +template struct abstract_binary_expression; -template +template struct abstract_unary_expression; -template +template class gradient_node; // forward declaration for tape // manages nodes in computational graph -template +template class gradient_tape { /** @brief tape (graph) management class for autodiff * holds all the data structures for autodiff */ private: /* type decays to order - 1 to support higher order derivatives */ - using inner_t = rvar_t; + using inner_t = rvar_t; /* adjoints are the overall derivative, and derivatives are the "local" * derivative */ detail::flat_linear_allocator adjoints_; detail::flat_linear_allocator derivatives_; - detail::flat_linear_allocator, buffer_size> gradient_nodes_; - detail::flat_linear_allocator *, buffer_size> argument_nodes_; + detail::flat_linear_allocator, buffer_size> + gradient_nodes_; + detail::flat_linear_allocator *, buffer_size> + argument_nodes_; // compile time check if emplace_back calls on zero template - gradient_node *fill_node_at_compile_time(std::true_type, - gradient_node *node_ptr) + gradient_node *fill_node_at_compile_time( + std::true_type, gradient_node *node_ptr) { node_ptr->derivatives_ = derivatives_.template emplace_back_n(); node_ptr->argument_nodes_ = argument_nodes_.template emplace_back_n(); @@ -76,8 +82,8 @@ private: } template - gradient_node *fill_node_at_compile_time(std::false_type, - gradient_node *node_ptr) + gradient_node *fill_node_at_compile_time( + std::false_type, gradient_node *node_ptr) { node_ptr->derivatives_ = nullptr; node_ptr->argument_adjoints_ = nullptr; @@ -94,9 +100,11 @@ public: using derivatives_iterator = typename detail::flat_linear_allocator::iterator; using gradient_nodes_iterator = - typename detail::flat_linear_allocator, buffer_size>::iterator; + typename detail::flat_linear_allocator, + buffer_size>::iterator; using argument_nodes_iterator = - typename detail::flat_linear_allocator *, buffer_size>::iterator; + typename detail::flat_linear_allocator *, + buffer_size>::iterator; gradient_tape() { clear(); }; @@ -114,9 +122,9 @@ public: } // no derivatives or arguments - gradient_node *emplace_leaf_node() + gradient_node *emplace_leaf_node() { - gradient_node *node = &*gradient_nodes_.emplace_back(); + gradient_node *node = &*gradient_nodes_.emplace_back(); node->adjoint_ = adjoints_.emplace_back(); node->derivatives_ = derivatives_iterator(); // nullptr; node->argument_nodes_ = argument_nodes_iterator(); // nullptr; @@ -125,9 +133,9 @@ public: }; // single argument, single derivative - gradient_node *emplace_active_unary_node() + gradient_node *emplace_active_unary_node() { - gradient_node *node = &*gradient_nodes_.emplace_back(); + gradient_node *node = &*gradient_nodes_.emplace_back(); node->n_ = 1; node->adjoint_ = adjoints_.emplace_back(); node->derivatives_ = derivatives_.emplace_back(); @@ -137,9 +145,9 @@ public: // arbitrary number of arguments/derivatives (compile time) template - gradient_node *emplace_active_multi_node() + gradient_node *emplace_active_multi_node() { - gradient_node *node = &*gradient_nodes_.emplace_back(); + gradient_node *node = &*gradient_nodes_.emplace_back(); node->n_ = n; node->adjoint_ = adjoints_.emplace_back(); // emulate if constexpr @@ -147,9 +155,9 @@ public: } // same as above at runtime - gradient_node *emplace_active_multi_node(size_t n) + gradient_node *emplace_active_multi_node(size_t n) { - gradient_node *node = &*gradient_nodes_.emplace_back(); + gradient_node *node = &*gradient_nodes_.emplace_back(); node->n_ = n; node->adjoint_ = adjoints_.emplace_back(); if (n > 0) { @@ -161,14 +169,17 @@ public: /* manual reset button for all adjoints */ void zero_grad() { - const T zero = T(0.0); + const RealType zero = RealType(0.0); adjoints_.fill(zero); } // return type is an iterator auto begin() { return gradient_nodes_.begin(); } auto end() { return gradient_nodes_.end(); } - auto find(gradient_node *node) { return gradient_nodes_.find(node); }; + auto find(gradient_node *node) + { + return gradient_nodes_.find(node); + }; void add_checkpoint() { gradient_nodes_.add_checkpoint(); @@ -206,11 +217,14 @@ public: } // random acces - gradient_node &operator[](size_t i) { return gradient_nodes_[i]; } - const gradient_node &operator[](size_t i) const { return gradient_nodes_[i]; } + gradient_node &operator[](size_t i) { return gradient_nodes_[i]; } + const gradient_node &operator[](size_t i) const + { + return gradient_nodes_[i]; + } }; // class rvar; -template // no CRTP, just storage +template // no CRTP, just storage class gradient_node { /* @@ -218,13 +232,15 @@ class gradient_node * adjoints pointers to arguments aren't needed here * */ public: - using adjoint_iterator = typename gradient_tape::adjoint_iterator; - using derivatives_iterator = typename gradient_tape::derivatives_iterator; - using argument_nodes_iterator = typename gradient_tape::argument_nodes_iterator; + using adjoint_iterator = typename gradient_tape::adjoint_iterator; + using derivatives_iterator = + typename gradient_tape::derivatives_iterator; + using argument_nodes_iterator = + typename gradient_tape::argument_nodes_iterator; private: size_t n_; - using inner_t = rvar_t; + using inner_t = rvar_t; /* these are iterators in case * flat linear allocator is at capacity, and needs to allocate a new block of * memory. */ @@ -233,8 +249,8 @@ private: argument_nodes_iterator argument_nodes_; public: - friend class gradient_tape; - friend class rvar; + friend class gradient_tape; + friend class rvar; gradient_node() = default; explicit gradient_node(const size_t n) @@ -242,7 +258,10 @@ public: , adjoint_(nullptr) , derivatives_(nullptr) {} - explicit gradient_node(const size_t n, T *adjoint, T *derivatives, rvar **arguments) + explicit gradient_node(const size_t n, + RealType *adjoint, + RealType *derivatives, + rvar **arguments) : n_(n) , adjoint_(adjoint) , derivatives_(derivatives) @@ -263,7 +282,7 @@ public: { argument_nodes_[static_cast(arg_id)]->update_adjoint_v(value); }; - void update_argument_ptr_at(size_t arg_id, gradient_node *node_ptr) + void update_argument_ptr_at(size_t arg_id, gradient_node *node_ptr) { argument_nodes_[static_cast(arg_id)] = node_ptr; } @@ -275,7 +294,7 @@ public: using boost::math::differentiation::reverse_mode::fabs; using std::fabs; - if (!adjoint_ || fabs(*adjoint_) < 2 * std::numeric_limits::epsilon()) + if (!adjoint_ || fabs(*adjoint_) < 2 * std::numeric_limits::epsilon()) return; if (!argument_nodes_) // no arguments @@ -294,21 +313,22 @@ public: }; /****************************************************************************************************************/ -template -inline gradient_tape &get_active_tape() +template +inline gradient_tape &get_active_tape() { - static BOOST_MATH_THREAD_LOCAL gradient_tape tape; + static BOOST_MATH_THREAD_LOCAL gradient_tape + tape; return tape; } -template -class rvar : public expression> +template +class rvar : public expression> { private: - using inner_t = rvar_t; - friend class gradient_node; + using inner_t = rvar_t; + friend class gradient_node; inner_t value_; - gradient_node *node_ = nullptr; + gradient_node *node_ = nullptr; template friend class rvar; /*****************************************************************************************/ @@ -322,13 +342,13 @@ private: /** @return value_ at rvar_t */ - static auto &get(rvar &v) + static auto &get(rvar &v) { return get_value_at_impl::get(v.get_value()); } /** @return const value_ at rvar_t */ - static const auto &get(const rvar &v) + static const auto &get(const rvar &v) { return get_value_at_impl::get(v.get_value()); } @@ -341,65 +361,69 @@ private: { /** @return value_ at rvar_t */ - static auto &get(rvar &v) { return v; } + static auto &get(rvar &v) { return v; } /** @return const value_ at rvar_t */ - static const auto &get(const rvar &v) { return v; } + static const auto &get(const rvar &v) { return v; } }; /*****************************************************************************************/ void make_leaf_node() { - gradient_tape &tape = get_active_tape(); + gradient_tape &tape + = get_active_tape(); node_ = tape.emplace_leaf_node(); } void make_unary_node() { - gradient_tape &tape = get_active_tape(); + gradient_tape &tape + = get_active_tape(); node_ = tape.emplace_active_unary_node(); } void make_multi_node(size_t n) { - gradient_tape &tape = get_active_tape(); + gradient_tape &tape + = get_active_tape(); node_ = tape.emplace_active_multi_node(n); } template void make_multi_node() { - gradient_tape &tape = get_active_tape(); + gradient_tape &tape + = get_active_tape(); node_ = tape.template emplace_active_multi_node(); } template - void make_rvar_from_expr(const expression &expr) + void make_rvar_from_expr(const expression &expr) { - make_multi_node>(); + make_multi_node>(); expr.template propagatex<0>(node_, inner_t(1.0)); } - T get_item_impl(std::true_type) const + RealType get_item_impl(std::true_type) const { - return value_.get_item_impl(std::integral_constant 1)>{}); + return value_.get_item_impl(std::integral_constant 1)>{}); } - T get_item_impl(std::false_type) const { return value_; } + RealType get_item_impl(std::false_type) const { return value_; } public: - using value_type = T; - static constexpr size_t order_v = order; + using value_type = RealType; + static constexpr size_t DerivativeOrder_v = DerivativeOrder; rvar() : value_() { make_leaf_node(); } - rvar(const T value) + rvar(const RealType value) : value_(inner_t{value}) { make_leaf_node(); } - rvar &operator=(T v) + rvar &operator=(RealType v) { value_ = inner_t(v); if (node_ == nullptr) { @@ -407,24 +431,24 @@ public: } return *this; } - rvar(const rvar &other) = default; - rvar &operator=(const rvar &other) = default; + rvar(const rvar &other) = default; + rvar &operator=(const rvar &other) = default; template - void propagatex(gradient_node *node, inner_t adj) const + void propagatex(gradient_node *node, inner_t adj) const { node->update_derivative_v(arg_index, adj); node->update_argument_ptr_at(arg_index, node_); } template - rvar(const expression &expr) + rvar(const expression &expr) { value_ = expr.evaluate(); make_rvar_from_expr(expr); } template - rvar &operator=(const expression &expr) + rvar &operator=(const expression &expr) { value_ = expr.evaluate(); make_rvar_from_expr(expr); @@ -432,52 +456,52 @@ public: } /***************************************************************************************************/ template - rvar &operator+=(const expression &expr) + rvar &operator+=(const expression &expr) { *this = *this + expr; return *this; } template - rvar &operator*=(const expression &expr) + rvar &operator*=(const expression &expr) { *this = *this * expr; return *this; } template - rvar &operator-=(const expression &expr) + rvar &operator-=(const expression &expr) { *this = *this - expr; return *this; } template - rvar &operator/=(const expression &expr) + rvar &operator/=(const expression &expr) { *this = *this / expr; return *this; } /***************************************************************************************************/ - rvar &operator+=(const T &v) + rvar &operator+=(const RealType &v) { *this = *this + v; return *this; } - rvar &operator*=(const T &v) + rvar &operator*=(const RealType &v) { *this = *this * v; return *this; } - rvar &operator-=(const T &v) + rvar &operator-=(const RealType &v) { *this = *this - v; return *this; } - rvar &operator/=(const T &v) + rvar &operator/=(const RealType &v) { *this = *this / v; return *this; @@ -490,7 +514,7 @@ public: const inner_t &evaluate() const { return value_; }; inner_t &get_value() { return value_; }; - explicit operator T() const { return item(); } + explicit operator RealType() const { return item(); } explicit operator int() const { return static_cast(item()); } explicit operator long() const { return static_cast(item()); } @@ -503,23 +527,27 @@ public: template auto &get_value_at() { - static_assert(N <= order, "Requested depth exceeds variable order."); - return get_value_at_impl::get(*this); + static_assert(N <= DerivativeOrder, "Requested depth exceeds variable order."); + return get_value_at_impl::get(*this); } /** @brief same as above but const */ template const auto &get_value_at() const { - static_assert(N <= order, "Requested depth exceeds variable order."); - return get_value_at_impl::get(*this); + static_assert(N <= DerivativeOrder, "Requested depth exceeds variable order."); + return get_value_at_impl::get(*this); } - T item() const { return get_item_impl(std::integral_constant 1)>{}); } + RealType item() const + { + return get_item_impl(std::integral_constant 1)>{}); + } void backward() { - gradient_tape &tape = get_active_tape(); + gradient_tape &tape + = get_active_tape(); auto it = tape.find(node_); it->update_adjoint_v(inner_t(1.0)); while (it != tape.begin()) { @@ -530,32 +558,32 @@ public: } }; -template -std::ostream &operator<<(std::ostream &os, const rvar var) +template +std::ostream &operator<<(std::ostream &os, const rvar var) { - os << "rvar<" << order << ">(" << var.item() << "," << var.adjoint() << ")"; + os << "rvar<" << DerivativeOrder << ">(" << var.item() << "," << var.adjoint() << ")"; return os; } -template -std::ostream &operator<<(std::ostream &os, const expression &expr) +template +std::ostream &operator<<(std::ostream &os, const expression &expr) { - rvar tmp = expr; - os << "rvar<" << order << ">(" << tmp.item() << "," << tmp.adjoint() << ")"; + rvar tmp = expr; + os << "rvar<" << DerivativeOrder << ">(" << tmp.item() << "," << tmp.adjoint() << ")"; return os; } -template -rvar make_rvar(const T v) +template +rvar make_rvar(const RealType v) { - static_assert(order > 0, "rvar order must be >= 1"); - return rvar(v); + static_assert(DerivativeOrder > 0, "rvar order must be >= 1"); + return rvar(v); } -template -rvar make_rvar(const expression &expr) +template +rvar make_rvar(const expression &expr) { - static_assert(order > 0, "rvar order must be >= 1"); - return rvar(expr); + static_assert(DerivativeOrder > 0, "rvar order must be >= 1"); + return rvar(expr); } namespace detail { @@ -565,37 +593,24 @@ namespace detail { * specialization for autodiffing through autodiff. i.e. being able to * compute higher order grads */ -template +template struct grad_op_impl { - std::vector> operator()(rvar &f, std::vector *> &x) + std::vector> operator()( + rvar &f, std::vector *> &x) { - auto &tape = get_active_tape(); + auto &tape = get_active_tape(); tape.zero_grad(); f.backward(); - std::vector> gradient_vector; + std::vector> gradient_vector; gradient_vector.reserve(x.size()); - for (auto xi : x) { - // make a new rvar holding the adjoint value + for (auto &xi : x) { gradient_vector.emplace_back(xi->adjoint()); } return gradient_vector; } - /* - std::vector *> operator()(rvar &f, - std::vector *> &x) - { - gradient_tape &tape = get_active_tape(); - tape.zero_grad(); - f.backward(); - std::vector *> gradient_vector; - for (auto xi : x) { - gradient_vector.push_back(&(xi->adjoint())); - } - return gradient_vector; - }*/ }; /** @brief helper overload for grad implementation. * @return vector of gradients of the autodiff graph. @@ -610,7 +625,7 @@ struct grad_op_impl tape.zero_grad(); f.backward(); std::vector gradient_vector; - for (auto xi : x) { + for (auto &xi : x) { gradient_vector.push_back(xi->adjoint()); } return gradient_vector; @@ -621,52 +636,51 @@ struct grad_op_impl * @return nested vector representing N-d tensor of * higher order derivatives */ -template +template struct grad_nd_impl { - auto operator()(rvar &f, std::vector *> &x) + auto operator()(rvar &f, + std::vector *> &x) { static_assert(N > 1, "N must be greater than 1 for this template"); - auto current_grad = grad(f, x); // vector> or vector + auto current_grad = grad(f, x); // vector> or vector - std::vector()(current_grad[0], x))> + std::vector()( + current_grad[0], x))> result; result.reserve(current_grad.size()); for (auto &g_i : current_grad) { - result.push_back(grad_nd_impl()(g_i, x)); + result.push_back( + grad_nd_impl()(g_i, x)); } return result; } - /* - auto operator()(rvar &f, std::vector *> &x) - { - static_assert(N > 1, "N must be greater than 1 for this template"); - auto current_grad = grad(f, x); - std::vector()(*current_grad[0], x))> - result; - for (auto &g_i : current_grad) { - result.push_back(grad_nd_impl()(*g_i, x)); - } - return result; - }*/ }; /** @brief spcialization for order = 1, - * @return vector> gradients */ -template -struct grad_nd_impl<1, T, order_1, order_2> + * @return vector> gradients */ +template +struct grad_nd_impl<1, RealType, DerivativeOrder_1, DerivativeOrder_2> { - auto operator()(rvar &f, std::vector *> &x) { return grad(f, x); } + auto operator()(rvar &f, + std::vector *> &x) + { + return grad(f, x); + } }; template struct rvar_order; -template -struct rvar_order *> +template +struct rvar_order *> { - static constexpr size_t value = order; + static constexpr size_t value = DerivativeOrder; }; } // namespace detail @@ -676,52 +690,52 @@ struct rvar_order *> * @param f -> computational graph * @param x -> variables gradients to record. Note ALL gradients of the graph * are computed simultaneously, only the ones w.r.t. x are returned - * @return vector of gradients. in the case of order_1 = 1 - * rvar decays to T + * @return vector of gradients. in the case of DerivativeOrder_1 = 1 + * rvar decays to T * * safe to call recursively with grad(grad(grad... */ -template -auto grad(rvar &f, std::vector *> &x) +template +auto grad(rvar &f, std::vector *> &x) { - static_assert(order_1 <= order_2, + static_assert(DerivativeOrder_1 <= DerivativeOrder_2, "variable differentiating w.r.t. must have order >= function order"); - std::vector *> xx; + std::vector *> xx; for (auto &xi : x) - xx.push_back(&(xi->template get_value_at())); - return detail::grad_op_impl{}(f, xx); + xx.push_back(&(xi->template get_value_at())); + return detail::grad_op_impl{}(f, xx); } /** @brief variadic overload of above */ -template -auto grad(rvar &f, First first, Other... other) +template +auto grad(rvar &f, First first, Other... other) { - constexpr size_t order_2 = detail::rvar_order::value; - static_assert(order_1 <= order_2, + constexpr size_t DerivativeOrder_2 = detail::rvar_order::value; + static_assert(DerivativeOrder_1 <= DerivativeOrder_2, "variable differentiating w.r.t. must have order >= function order"); - std::vector *> x_vec = {first, other...}; + std::vector *> x_vec = {first, other...}; return grad(f, x_vec); } /** @brief computes hessian matrix of computational graph w.r.t. * vector of variables x. - * @return std::vector> hessian matrix + * @return std::vector> hessian matrix * rvar decays to T * * NOT recursion safe, cannot do hess(hess( */ -template -auto hess(rvar &f, std::vector *> &x) +template +auto hess(rvar &f, std::vector *> &x) { - return detail::grad_nd_impl<2, T, order_1, order_2>{}(f, x); + return detail::grad_nd_impl<2, RealType, DerivativeOrder_1, DerivativeOrder_2>{}(f, x); } /** @brief variadic overload of above */ -template -auto hess(rvar &f, First first, Other... other) +template +auto hess(rvar &f, First first, Other... other) { - constexpr size_t order_2 = detail::rvar_order::value; - std::vector *> x_vec = {first, other...}; + constexpr size_t DerivativeOrder_2 = detail::rvar_order::value; + std::vector *> x_vec = {first, other...}; return hess(f, x_vec); } @@ -731,13 +745,15 @@ auto hess(rvar &f, First first, Other... other) * * NOT recursively safe, cannot do grad_nd(grad_nd(... etc... */ -template -auto grad_nd(rvar &f, std::vector *> &x) +template +auto grad_nd(rvar &f, + std::vector *> &x) { - static_assert(order_1 >= N, "Function order must be at least N"); - static_assert(order_2 >= order_1, "Variable order must be at least function order"); + static_assert(DerivativeOrder_1 >= N, "Function order must be at least N"); + static_assert(DerivativeOrder_2 >= DerivativeOrder_1, + "Variable order must be at least function order"); - return detail::grad_nd_impl()(f, x); + return detail::grad_nd_impl()(f, x); } /** @brief variadic overload of above @@ -745,11 +761,11 @@ auto grad_nd(rvar &f, std::vector *> &x) template auto grad_nd(ftype &f, First first, Other... other) { - using T = typename ftype::value_type; - constexpr size_t order_1 = detail::rvar_order::value; - constexpr size_t order_2 = detail::rvar_order::value; - std::vector *> x_vec = {first, other...}; - return detail::grad_nd_impl{}(f, x_vec); + using RealType = typename ftype::value_type; + constexpr size_t DerivativeOrder_1 = detail::rvar_order::value; + constexpr size_t DerivativeOrder_2 = detail::rvar_order::value; + std::vector *> x_vec = {first, other...}; + return detail::grad_nd_impl{}(f, x_vec); } } // namespace reverse_mode } // namespace differentiation @@ -758,10 +774,10 @@ auto grad_nd(ftype &f, First first, Other... other) namespace std { // copied from forward mode -template -class numeric_limits> - : public numeric_limits< - typename boost::math::differentiation::reverse_mode::rvar::value_type> +template +class numeric_limits> + : public numeric_limits::value_type> {}; } // namespace std #endif diff --git a/include/boost/math/differentiation/detail/reverse_mode_autodiff_basic_operator_overloads.hpp b/include/boost/math/differentiation/detail/reverse_mode_autodiff_basic_operator_overloads.hpp index 5f6a9075d..1acb2439f 100644 --- a/include/boost/math/differentiation/detail/reverse_mode_autodiff_basic_operator_overloads.hpp +++ b/include/boost/math/differentiation/detail/reverse_mode_autodiff_basic_operator_overloads.hpp @@ -12,19 +12,26 @@ namespace math { namespace differentiation { namespace reverse_mode { /****************************************************************************************************************/ -template -struct add_expr - : public abstract_binary_expression> +template +struct add_expr : public abstract_binary_expression> { /* @brief addition * rvar+rvar * */ - using inner_t = rvar_t; + using inner_t = rvar_t; // Explicitly define constructor to forward to base class - explicit add_expr(const expression &left_hand_expr, - const expression &right_hand_expr) - : abstract_binary_expression>( - left_hand_expr, right_hand_expr) + explicit add_expr(const expression &left_hand_expr, + const expression &right_hand_expr) + : abstract_binary_expression>(left_hand_expr, + right_hand_expr) {} inner_t evaluate() const { return this->lhs.evaluate() + this->rhs.evaluate(); } @@ -41,37 +48,51 @@ struct add_expr return inner_t(1.0); } }; -template +template struct add_const_expr - : public abstract_unary_expression> + : public abstract_unary_expression> { /* @brief * rvar+float or float+rvar * */ - using inner_t = rvar_t; - explicit add_const_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + using inner_t = rvar_t; + explicit add_const_expr(const expression &arg_expr, + const RealType v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const { return this->arg.evaluate() + inner_t(this->constant); } static const inner_t derivative(const inner_t & /*argv*/, const inner_t & /*v*/, - const T & /*constant*/) + const RealType & /*constant*/) { return inner_t(1.0); } }; /****************************************************************************************************************/ -template -struct mult_expr - : public abstract_binary_expression> +template +struct mult_expr : public abstract_binary_expression> { /* @brief multiplication * rvar * rvar * */ - using inner_t = rvar_t; - explicit mult_expr(const expression &left_hand_expr, - const expression &right_hand_expr) - : abstract_binary_expression>( - left_hand_expr, right_hand_expr) + using inner_t = rvar_t; + explicit mult_expr(const expression &left_hand_expr, + const expression &right_hand_expr) + : abstract_binary_expression>(left_hand_expr, + right_hand_expr) {} inner_t evaluate() const { return this->lhs.evaluate() * this->rhs.evaluate(); }; @@ -88,40 +109,54 @@ struct mult_expr return l; }; }; -template +template struct mult_const_expr - : public abstract_unary_expression> + : public abstract_unary_expression> { /* @brief * rvar+float or float+rvar * */ - using inner_t = rvar_t; + using inner_t = rvar_t; - explicit mult_const_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + explicit mult_const_expr(const expression &arg_expr, + const RealType v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const { return this->arg.evaluate() * inner_t(this->constant); } static const inner_t derivative(const inner_t & /*argv*/, const inner_t & /*v*/, - const T &constant) + const RealType &constant) { return inner_t(constant); } }; /****************************************************************************************************************/ -template -struct sub_expr - : public abstract_binary_expression> +template +struct sub_expr : public abstract_binary_expression> { /* @brief addition * rvar-rvar * */ - using inner_t = rvar_t; + using inner_t = rvar_t; // Explicitly define constructor to forward to base class - explicit sub_expr(const expression &left_hand_expr, - const expression &right_hand_expr) - : abstract_binary_expression>( - left_hand_expr, right_hand_expr) + explicit sub_expr(const expression &left_hand_expr, + const expression &right_hand_expr) + : abstract_binary_expression>(left_hand_expr, + right_hand_expr) {} inner_t evaluate() const { return this->lhs.evaluate() - this->rhs.evaluate(); } @@ -140,19 +175,26 @@ struct sub_expr }; /****************************************************************************************************************/ -template -struct div_expr - : public abstract_binary_expression> +template +struct div_expr : public abstract_binary_expression> { /* @brief multiplication * rvar / rvar * */ - using inner_t = rvar_t; + using inner_t = rvar_t; // Explicitly define constructor to forward to base class - explicit div_expr(const expression &left_hand_expr, - const expression &right_hand_expr) - : abstract_binary_expression>( - left_hand_expr, right_hand_expr) + explicit div_expr(const expression &left_hand_expr, + const expression &right_hand_expr) + : abstract_binary_expression>(left_hand_expr, + right_hand_expr) {} inner_t evaluate() const { return this->lhs.evaluate() / this->rhs.evaluate(); }; @@ -160,182 +202,212 @@ struct div_expr const inner_t &r, const inner_t & /*v*/) { - return static_cast(1.0) / r; + return static_cast(1.0) / r; }; static const inner_t right_derivative(const inner_t &l, const inner_t &r, const inner_t & /*v*/) { return -l / (r * r); }; }; -template +template struct div_by_const_expr - : public abstract_unary_expression> + : public abstract_unary_expression> { /* @brief * rvar/float * */ - using inner_t = rvar_t; + using inner_t = rvar_t; - explicit div_by_const_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + explicit div_by_const_expr(const expression &arg_expr, + const RealType v) + : abstract_unary_expression>(arg_expr, + v){}; inner_t evaluate() const { return this->arg.evaluate() / inner_t(this->constant); } static const inner_t derivative(const inner_t & /*argv*/, const inner_t & /*v*/, - const T &constant) + const RealType &constant) { return inner_t(1.0 / constant); } }; -template +template struct const_div_by_expr - : public abstract_unary_expression> + : public abstract_unary_expression> { /** @brief * float/rvar * */ - using inner_t = rvar_t; + using inner_t = rvar_t; - explicit const_div_by_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + explicit const_div_by_expr(const expression &arg_expr, + const RealType v) + : abstract_unary_expression>(arg_expr, + v){}; inner_t evaluate() const { return inner_t(this->constant) / this->arg.evaluate(); } - static const inner_t derivative(const inner_t &argv, const inner_t & /*v*/, const T &constant) + static const inner_t derivative(const inner_t &argv, + const inner_t & /*v*/, + const RealType &constant) { return -inner_t{constant} / (argv * argv); } }; /****************************************************************************************************************/ -template -mult_expr operator*(const expression &lhs, - const expression &rhs) +template +mult_expr operator*( + const expression &lhs, + const expression &rhs) { - return mult_expr(lhs, rhs); + return mult_expr(lhs, rhs); } /** @brief type promotion is handled by casting the numeric type to * the type inside expression. This is to avoid converting the * entire tape in case you have something like double * rvar * */ -template::value>::type> -mult_const_expr operator*(const expression &arg, const U &v) + typename = typename std::enable_if::value>::type> +mult_const_expr operator*( + const expression &arg, const RealType2 &v) { - return mult_const_expr(arg, static_cast(v)); + return mult_const_expr(arg, static_cast(v)); } -template::value>::type> -mult_const_expr operator*(const U &v, const expression &arg) + typename = typename std::enable_if::value>::type> +mult_const_expr operator*( + const RealType2 &v, const expression &arg) { - return mult_const_expr(arg, static_cast(v)); + return mult_const_expr(arg, static_cast(v)); } /****************************************************************************************************************/ /* + */ -template -add_expr operator+(const expression &lhs, - const expression &rhs) +template +add_expr operator+( + const expression &lhs, + const expression &rhs) { - return add_expr(lhs, rhs); + return add_expr(lhs, rhs); } -template::value>::type> -add_const_expr operator+(const expression &arg, const U &v) + typename = typename std::enable_if::value>::type> +add_const_expr operator+( + const expression &arg, const RealType2 &v) { - return add_const_expr(arg, static_cast(v)); + return add_const_expr(arg, static_cast(v)); } -template::value>::type> -add_const_expr operator+(const U &v, const expression &arg) + typename = typename std::enable_if::value>::type> +add_const_expr operator+( + const RealType2 &v, const expression &arg) { - return add_const_expr(arg, static_cast(v)); + return add_const_expr(arg, static_cast(v)); } /****************************************************************************************************************/ /* - overload */ /** @brief * negation (-1.0*rvar) */ -template -mult_const_expr operator-(const expression &arg) +template +mult_const_expr operator-( + const expression &arg) { - return mult_const_expr(arg, static_cast(-1.0)); + return mult_const_expr(arg, static_cast(-1.0)); } /** @brief * subtraction rvar-rvar */ -template -sub_expr operator-(const expression &lhs, - const expression &rhs) +template +sub_expr operator-( + const expression &lhs, + const expression &rhs) { - return sub_expr(lhs, rhs); + return sub_expr(lhs, rhs); } /** @brief * subtraction float - rvar */ -template::value>::type> -add_const_expr operator-(const expression &arg, const U &v) + typename = typename std::enable_if::value>::type> +add_const_expr operator-( + const expression &arg, const RealType2 &v) { /* rvar - float = rvar + (-float) */ - return add_const_expr(arg, static_cast(-v)); + return add_const_expr(arg, static_cast(-v)); } /** @brief * subtraction float - rvar * @return add_expr> */ -template::value>::type> -auto operator-(const U &v, const expression &arg) + typename = typename std::enable_if::value>::type> +auto operator-(const RealType2 &v, const expression &arg) { auto neg = -arg; - return neg + static_cast(v); + return neg + static_cast(v); } /****************************************************************************************************************/ /* / */ -template -div_expr operator/(const expression &lhs, - const expression &rhs) +template +div_expr operator/( + const expression &lhs, + const expression &rhs) { - return div_expr(lhs, rhs); + return div_expr(lhs, rhs); } -template::value>::type> -const_div_by_expr operator/(const U &v, const expression &arg) + typename = typename std::enable_if::value>::type> +const_div_by_expr operator/( + const RealType2 &v, const expression &arg) { - return const_div_by_expr(arg, static_cast(v)); + return const_div_by_expr(arg, static_cast(v)); } -template::value>::type> -div_by_const_expr operator/(const expression &arg, const U &v) + typename = typename std::enable_if::value>::type> +div_by_const_expr operator/( + const expression &arg, const RealType2 &v) { - return div_by_const_expr(arg, static_cast(v)); + return div_by_const_expr(arg, static_cast(v)); } } // namespace reverse_mode diff --git a/include/boost/math/differentiation/detail/reverse_mode_autodiff_comparison_operator_overloads.hpp b/include/boost/math/differentiation/detail/reverse_mode_autodiff_comparison_operator_overloads.hpp index cedaefe08..79616b317 100644 --- a/include/boost/math/differentiation/detail/reverse_mode_autodiff_comparison_operator_overloads.hpp +++ b/include/boost/math/differentiation/detail/reverse_mode_autodiff_comparison_operator_overloads.hpp @@ -10,147 +10,153 @@ namespace boost { namespace math { namespace differentiation { namespace reverse_mode { -template -bool operator==(const expression &lhs, const expression &rhs) +template +bool operator==(const expression &lhs, + const expression &rhs) { return lhs.evaluate() == rhs.evaluate(); } -template::value>::type> -bool operator==(const expression &lhs, const U &rhs) +template::value>::type> +bool operator==(const expression &lhs, const RealType2 &rhs) { - return lhs.evaluate() == static_cast(rhs); + return lhs.evaluate() == static_cast(rhs); } -template::value>::type> -bool operator==(const U &lhs, const expression &rhs) +template::value>::type> +bool operator==(const RealType2 &lhs, const expression &rhs) { return lhs == rhs.evaluate(); } -template -bool operator!=(const expression &lhs, const expression &rhs) +template +bool operator!=(const expression &lhs, + const expression &rhs) { return lhs.evaluate() != rhs.evaluate(); } -template::value>::type> -bool operator!=(const expression &lhs, const U &rhs) +template::value>::type> +bool operator!=(const expression &lhs, const RealType2 &rhs) { return lhs.evaluate() != rhs; } -template::value>::type> -bool operator!=(const U &lhs, const expression &rhs) +template::value>::type> +bool operator!=(const RealType2 &lhs, const expression &rhs) { return lhs != rhs.evaluate(); } -template -bool operator<(const expression &lhs, const expression &rhs) +template +bool operator<(const expression &lhs, + const expression &rhs) { return lhs.evaluate() < rhs.evaluate(); } -template::value>::type> -bool operator<(const expression &lhs, const U &rhs) +template::value>::type> +bool operator<(const expression &lhs, const RealType2 &rhs) { return lhs.evaluate() < rhs; } -template::value>::type> -bool operator<(const U &lhs, const expression &rhs) +template::value>::type> +bool operator<(const RealType2 &lhs, const expression &rhs) { return lhs < rhs.evaluate(); } -template -bool operator>(const expression &lhs, const expression &rhs) +template +bool operator>(const expression &lhs, + const expression &rhs) { return lhs.evaluate() > rhs.evaluate(); } -template::value>::type> -bool operator>(const expression &lhs, const U &rhs) +template::value>::type> +bool operator>(const expression &lhs, const RealType2 &rhs) { return lhs.evaluate() > rhs; } -template::value>::type> -bool operator>(const U &lhs, const expression &rhs) +template::value>::type> +bool operator>(const RealType2 &lhs, const expression &rhs) { return lhs > rhs.evaluate(); } -template -bool operator<=(const expression &lhs, const expression &rhs) +template +bool operator<=(const expression &lhs, + const expression &rhs) { return lhs.evaluate() <= rhs.evaluate(); } -template::value>::type> -bool operator<=(const expression &lhs, const U &rhs) +template::value>::type> +bool operator<=(const expression &lhs, const RealType2 &rhs) { return lhs.evaluate() <= rhs; } -template::value>::type> -bool operator<=(const U &lhs, const expression &rhs) +template::value>::type> +bool operator<=(const RealType2 &lhs, const expression &rhs) { return lhs <= rhs.evaluate(); } -template -bool operator>=(const expression &lhs, const expression &rhs) +template +bool operator>=(const expression &lhs, + const expression &rhs) { return lhs.evaluate() >= rhs.evaluate(); } -template::value>::type> -bool operator>=(const expression &lhs, const U &rhs) +template::value>::type> +bool operator>=(const expression &lhs, const RealType2 &rhs) { return lhs.evaluate() >= rhs; } -template::value>::type> -bool operator>=(const U &lhs, const expression &rhs) +template::value>::type> +bool operator>=(const RealType2 &lhs, const expression &rhs) { return lhs >= rhs.evaluate(); } diff --git a/include/boost/math/differentiation/detail/reverse_mode_autodiff_erf_overloads.hpp b/include/boost/math/differentiation/detail/reverse_mode_autodiff_erf_overloads.hpp index 27c84ebb6..8e6697ae5 100644 --- a/include/boost/math/differentiation/detail/reverse_mode_autodiff_erf_overloads.hpp +++ b/include/boost/math/differentiation/detail/reverse_mode_autodiff_erf_overloads.hpp @@ -16,73 +16,84 @@ namespace math { namespace differentiation { namespace reverse_mode { -template +template struct erf_expr; -template +template struct erfc_expr; -template +template struct erf_inv_expr; -template +template struct erfc_inv_expr; -template -erf_expr erf(const expression &arg) +template +erf_expr erf(const expression &arg) { - return erf_expr(arg, 0.0); + return erf_expr(arg, 0.0); } -template -erfc_expr erfc(const expression &arg) +template +erfc_expr erfc(const expression &arg) { - return erfc_expr(arg, 0.0); + return erfc_expr(arg, 0.0); } -template -erf_inv_expr erf_inv(const expression &arg) +template +erf_inv_expr erf_inv( + const expression &arg) { - return erf_inv_expr(arg, 0.0); + return erf_inv_expr(arg, 0.0); } -template -erfc_inv_expr erfc_inv(const expression &arg) +template +erfc_inv_expr erfc_inv( + const expression &arg) { - return erfc_inv_expr(arg, 0.0); + return erfc_inv_expr(arg, 0.0); } -template -struct erf_expr : public abstract_unary_expression> +template +struct erf_expr : public abstract_unary_expression> { /** @brief erf(x) * * d/dx erf(x) = 2*exp(x^2)/sqrt(pi) * * */ - using inner_t = rvar_t; + using inner_t = rvar_t; - explicit erf_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + explicit erf_expr(const expression &arg_expr, const RealType v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const { - return detail::if_functional_dispatch<(order > 1)>( + return detail::if_functional_dispatch<(DerivativeOrder > 1)>( [this](auto &&x) { return reverse_mode::erf(std::forward(x)); }, [this](auto &&x) { return boost::math::erf(std::forward(x)); }, this->arg.evaluate()); } static const inner_t derivative(const inner_t &argv, const inner_t & /*v*/, - const T & /*constant*/) + const RealType & /*constant*/) { BOOST_MATH_STD_USING - return static_cast(2.0) * exp(-argv * argv) / sqrt(constants::pi()); + return static_cast(2.0) * exp(-argv * argv) / sqrt(constants::pi()); } }; -template -struct erfc_expr : public abstract_unary_expression> +template +struct erfc_expr : public abstract_unary_expression> { /** @brief erfc(x) * @@ -90,29 +101,35 @@ struct erfc_expr : public abstract_unary_expression; + using inner_t = rvar_t; - explicit erfc_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + explicit erfc_expr(const expression &arg_expr, const RealType v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const { - return detail::if_functional_dispatch<(order > 1)>( + return detail::if_functional_dispatch<((DerivativeOrder > 1))>( [this](auto &&x) { return reverse_mode::erfc(std::forward(x)); }, [this](auto &&x) { return boost::math::erfc(std::forward(x)); }, this->arg.evaluate()); } static const inner_t derivative(const inner_t &argv, const inner_t & /*v*/, - const T & /*constant*/) + const RealType & /*constant*/) { BOOST_MATH_STD_USING - return static_cast(-2.0) * exp(-argv * argv) / sqrt(constants::pi()); + return static_cast(-2.0) * exp(-argv * argv) / sqrt(constants::pi()); } }; -template -struct erf_inv_expr : public abstract_unary_expression> +template +struct erf_inv_expr : public abstract_unary_expression> { /** @brief erf(x) * @@ -120,39 +137,47 @@ struct erf_inv_expr : public abstract_unary_expression; + using inner_t = rvar_t; - explicit erf_inv_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + explicit erf_inv_expr(const expression &arg_expr, + const RealType v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const { - return detail::if_functional_dispatch<(order > 1)>( + return detail::if_functional_dispatch<((DerivativeOrder > 1))>( [this](auto &&x) { return reverse_mode::erf_inv(std::forward(x)); }, [this](auto &&x) { return boost::math::erf_inv(std::forward(x)); }, this->arg.evaluate()); } static const inner_t derivative(const inner_t &argv, const inner_t & /*v*/, - const T & /*constant*/) + const RealType & /*constant*/) { BOOST_MATH_STD_USING - return detail::if_functional_dispatch<(order > 1)>( + return detail::if_functional_dispatch<((DerivativeOrder > 1))>( [](auto &&x) { - return static_cast(0.5) * sqrt(constants::pi()) + return static_cast(0.5) * sqrt(constants::pi()) * reverse_mode::exp( - reverse_mode::pow(reverse_mode::erf_inv(x), static_cast(2.0))); + reverse_mode::pow(reverse_mode::erf_inv(x), static_cast(2.0))); }, [](auto &&x) { - return static_cast(0.5) * sqrt(constants::pi()) - * exp(pow(boost::math::erf_inv(x), static_cast(2.0))); + return static_cast(0.5) * sqrt(constants::pi()) + * exp(pow(boost::math::erf_inv(x), static_cast(2.0))); }, argv); } }; -template -struct erfc_inv_expr : public abstract_unary_expression> +template +struct erfc_inv_expr + : public abstract_unary_expression> { /** @brief erfc(x) * @@ -160,32 +185,36 @@ struct erfc_inv_expr : public abstract_unary_expression; + using inner_t = rvar_t; - explicit erfc_inv_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + explicit erfc_inv_expr(const expression &arg_expr, + const RealType v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const { - return detail::if_functional_dispatch<(order > 1)>( + return detail::if_functional_dispatch<((DerivativeOrder > 1))>( [this](auto &&x) { return reverse_mode::erfc_inv(std::forward(x)); }, [this](auto &&x) { return boost::math::erfc_inv(std::forward(x)); }, this->arg.evaluate()); } static const inner_t derivative(const inner_t &argv, const inner_t & /*v*/, - const T & /*constant*/) + const RealType & /*constant*/) { BOOST_MATH_STD_USING - return detail::if_functional_dispatch<(order > 1)>( + return detail::if_functional_dispatch<((DerivativeOrder > 1))>( [](auto &&x) { - return static_cast(-0.5) * sqrt(constants::pi()) - * reverse_mode::exp( - reverse_mode::pow(reverse_mode::erfc_inv(x), static_cast(2.0))); + return static_cast(-0.5) * sqrt(constants::pi()) + * reverse_mode::exp(reverse_mode::pow(reverse_mode::erfc_inv(x), + static_cast(2.0))); }, [](auto &&x) { - return static_cast(-0.5) * sqrt(constants::pi()) - * exp(pow(boost::math::erfc_inv(x), static_cast(2.0))); + return static_cast(-0.5) * sqrt(constants::pi()) + * exp(pow(boost::math::erfc_inv(x), static_cast(2.0))); }, argv); } diff --git a/include/boost/math/differentiation/detail/reverse_mode_autodiff_expression_template_base.hpp b/include/boost/math/differentiation/detail/reverse_mode_autodiff_expression_template_base.hpp index 31d357ca9..ccdabf118 100644 --- a/include/boost/math/differentiation/detail/reverse_mode_autodiff_expression_template_base.hpp +++ b/include/boost/math/differentiation/detail/reverse_mode_autodiff_expression_template_base.hpp @@ -16,19 +16,24 @@ namespace reverse_mode { struct expression_base {}; -template +template struct expression; -template +template class rvar; -template +template struct abstract_binary_expression; -template +template + struct abstract_unary_expression; -template +template class gradient_node; // forward declaration for tape namespace detail { @@ -59,50 +64,54 @@ struct count_rvar_impl { static constexpr std::size_t value = 0; }; -template -struct count_rvar_impl, order> +template +struct count_rvar_impl, DerivativeOrder> { static constexpr std::size_t value = 1; }; -template -struct count_rvar_impl::value - && !std::is_same>::value - && !has_unary_sub_type::value>> +template +struct count_rvar_impl< + RealType, + DerivativeOrder, + std::enable_if_t::value + && !std::is_same>::value + && !has_unary_sub_type::value>> { - static constexpr std::size_t value = count_rvar_impl::value - + count_rvar_impl::value; + static constexpr std::size_t value + = count_rvar_impl::value + + count_rvar_impl::value; }; -template +template struct count_rvar_impl< - T, - order, - typename std::enable_if_t::value - && !std::is_same>::value - && !has_binary_sub_types::value>> + RealType, + DerivativeOrder, + typename std::enable_if_t< + has_unary_sub_type::value + && !std::is_same>::value + && !has_binary_sub_types::value>> { - static constexpr std::size_t value = count_rvar_impl::value; + static constexpr std::size_t value + = count_rvar_impl::value; }; -template -constexpr std::size_t count_rvars = detail::count_rvar_impl::value; +template +constexpr std::size_t count_rvars = detail::count_rvar_impl::value; template struct is_expression : std::is_base_of::type> {}; -template +template struct rvar_type_impl { - using type = rvar; + using type = rvar; }; -template -struct rvar_type_impl +template +struct rvar_type_impl { - using type = T; + using type = RealType; }; } // namespace detail @@ -110,63 +119,69 @@ struct rvar_type_impl template using rvar_t = typename detail::rvar_type_impl::type; -template +template struct expression : expression_base { /* @brief * base expression class * */ - using value_type = T; - static constexpr size_t order_v = order; - using derived_type = derived_expression; + using value_type = RealType; + static constexpr size_t order_v = DerivativeOrder; + using derived_type = DerivedExpression; static constexpr size_t num_literals = 0; - using inner_t = rvar_t; - inner_t evaluate() const { return static_cast(this)->evaluate(); } + using inner_t = rvar_t; + inner_t evaluate() const { return static_cast(this)->evaluate(); } template - void propagatex(gradient_node *node, inner_t adj) const + void propagatex(gradient_node *node, inner_t adj) const { - return static_cast(this)->template propagatex(node, - adj); - } + return static_cast(this)->template propagatex(node, + adj); + }; }; -template +template struct abstract_binary_expression - : public expression> + : public expression< + RealType, + DerivativeOrder, + abstract_binary_expression> { using lhs_type = LHS; using rhs_type = RHS; - using value_type = T; - using inner_t = rvar_t; + using value_type = RealType; + using inner_t = rvar_t; const lhs_type lhs; const rhs_type rhs; - explicit abstract_binary_expression(const expression &left_hand_expr, - const expression &right_hand_expr) + explicit abstract_binary_expression( + const expression &left_hand_expr, + const expression &right_hand_expr) : lhs(static_cast(left_hand_expr)) , rhs(static_cast(right_hand_expr)){}; inner_t evaluate() const { - return static_cast(this)->evaluate(); + return static_cast(this)->evaluate(); }; template - void propagatex(gradient_node *node, inner_t adj) const + void propagatex(gradient_node *node, inner_t adj) const { - const inner_t lv = lhs.evaluate(); - const inner_t rv = rhs.evaluate(); - const inner_t v = evaluate(); - const inner_t partial_l = concrete_binary_operation::left_derivative(lv, rv, v); - const inner_t partial_r = concrete_binary_operation::right_derivative(lv, rv, v); + inner_t lv = lhs.evaluate(); + inner_t rv = rhs.evaluate(); + inner_t v = evaluate(); + inner_t partial_l = ConcreteBinaryOperation::left_derivative(lv, rv, v); + inner_t partial_r = ConcreteBinaryOperation::right_derivative(lv, rv, v); - constexpr size_t num_lhs_args = detail::count_rvars; - constexpr size_t num_rhs_args = detail::count_rvars; + constexpr size_t num_lhs_args = detail::count_rvars; + constexpr size_t num_rhs_args = detail::count_rvars; propagate_lhs(node, adj * partial_l); propagate_rhs(node, adj * partial_r); @@ -178,7 +193,7 @@ private: template 0), int>::type = 0> - void propagate_lhs(gradient_node *node, inner_t adj) const + void propagate_lhs(gradient_node *node, inner_t adj) const { lhs.template propagatex(node, adj); } @@ -186,13 +201,13 @@ private: template::type = 0> - void propagate_lhs(gradient_node *, inner_t) const + void propagate_lhs(gradient_node *, inner_t) const {} template 0), int>::type = 0> - void propagate_rhs(gradient_node *node, inner_t adj) const + void propagate_rhs(gradient_node *node, inner_t adj) const { rhs.template propagatex(node, adj); } @@ -200,32 +215,37 @@ private: template::type = 0> - void propagate_rhs(gradient_node *, inner_t) const + void propagate_rhs(gradient_node *, inner_t) const {} }; -template +template + struct abstract_unary_expression - : public expression> + : public expression< + RealType, + DerivativeOrder, + abstract_unary_expression> { using arg_type = ARG; - using value_type = T; - using inner_t = rvar_t; + using value_type = RealType; + using inner_t = rvar_t; const arg_type arg; - const T constant; - explicit abstract_unary_expression(const expression &arg_expr, const T &constant) + const RealType constant; + explicit abstract_unary_expression(const expression &arg_expr, + const RealType &constant) : arg(static_cast(arg_expr)) , constant(constant){}; inner_t evaluate() const { - return static_cast(this)->evaluate(); + return static_cast(this)->evaluate(); }; template - void propagatex(gradient_node *node, inner_t adj) const + void propagatex(gradient_node *node, inner_t adj) const { inner_t argv = arg.evaluate(); inner_t v = evaluate(); - inner_t partial_arg = concrete_unary_operation::derivative(argv, v, constant); + inner_t partial_arg = ConcreteUnaryOperation::derivative(argv, v, constant); arg.template propagatex(node, adj * partial_arg); } diff --git a/include/boost/math/differentiation/detail/reverse_mode_autodiff_memory_management.hpp b/include/boost/math/differentiation/detail/reverse_mode_autodiff_memory_management.hpp index 0c95128fb..b0b5ab055 100644 --- a/include/boost/math/differentiation/detail/reverse_mode_autodiff_memory_management.hpp +++ b/include/boost/math/differentiation/detail/reverse_mode_autodiff_memory_management.hpp @@ -188,7 +188,7 @@ public: bool operator!() const noexcept { return storage_ == nullptr; } }; /* memory management helps for tape */ -template +template class flat_linear_allocator { /** @brief basically a vector> @@ -198,8 +198,8 @@ class flat_linear_allocator public: // store vector of unique pointers to arrays // to avoid vector reference invalidation - using buffer_type = std::array; - using buffer_ptr = std::unique_ptr>; + using buffer_type = std::array; + using buffer_ptr = std::unique_ptr>; private: std::vector data_; @@ -207,14 +207,16 @@ private: std::vector checkpoints_; //{0}; public: - friend class flat_linear_allocator_iterator, buffer_size>; - friend class flat_linear_allocator_iterator, + friend class flat_linear_allocator_iterator, buffer_size>; - using value_type = T; + friend class flat_linear_allocator_iterator, + buffer_size>; + using value_type = RealType; using iterator - = flat_linear_allocator_iterator, buffer_size>; + = flat_linear_allocator_iterator, buffer_size>; using const_iterator - = flat_linear_allocator_iterator, buffer_size>; + = flat_linear_allocator_iterator, + buffer_size>; size_t buffer_id() const noexcept { return total_size_ / buffer_size; } size_t item_id() const noexcept { return total_size_ % buffer_size; } @@ -242,7 +244,7 @@ public: for (size_t i = 0; i < total_size_; ++i) { size_t bid = i / buffer_size; size_t iid = i % buffer_size; - (*data_[bid])[iid].~T(); + (*data_[bid])[iid].~RealType(); } } /** @brief @@ -277,7 +279,7 @@ public: void rewind_to_last_checkpoint() { total_size_ = checkpoints_.back(); } void rewind_to_checkpoint_at(size_t index) { total_size_ = checkpoints_[index]; } - void fill(const T &val) + void fill(const RealType &val) { for (size_t i = 0; i < total_size_; ++i) { size_t bid = i / buffer_size; @@ -296,8 +298,8 @@ public: size_t bid = buffer_id(); size_t iid = item_id(); - T *ptr = &(*data_[bid])[iid]; - new (ptr) T(); + RealType *ptr = &(*data_[bid])[iid]; + new (ptr) RealType(); ++total_size_; return iterator(this, total_size_ - 1); }; @@ -312,8 +314,8 @@ public: } BOOST_MATH_ASSERT(buffer_id() < data_.size()); BOOST_MATH_ASSERT(item_id() < buffer_size); - T *ptr = &(*data_[buffer_id()])[item_id()]; - new (ptr) T(std::forward(args)...); + RealType *ptr = &(*data_[buffer_id()])[item_id()]; + new (ptr) RealType(std::forward(args)...); ++total_size_; return iterator(this, total_size_ - 1); } @@ -325,27 +327,27 @@ public: size_t bid = buffer_id(); size_t iid = item_id(); if (iid + n < buffer_size) { - T *ptr = &(*data_[bid])[iid]; + RealType *ptr = &(*data_[bid])[iid]; for (size_t i = 0; i < n; ++i) { - new (ptr + i) T(); + new (ptr + i) RealType(); } total_size_ += n; return iterator(this, total_size_ - n, total_size_ - n, total_size_); } else { size_t allocs_in_curr_buffer = buffer_size - iid; size_t allocs_in_next_buffer = n - (buffer_size - iid); - T *ptr = &(*data_[bid])[iid]; + RealType *ptr = &(*data_[bid])[iid]; for (size_t i = 0; i < allocs_in_curr_buffer; ++i) { - new (ptr + i) T(); + new (ptr + i) RealType(); } allocate_buffer(); bid = data_.size() - 1; iid = 0; total_size_ += n; - T *ptr2 = &(*data_[bid])[iid]; + RealType *ptr2 = &(*data_[bid])[iid]; for (size_t i = 0; i < allocs_in_next_buffer; i++) { - new (ptr2 + i) T(); + new (ptr2 + i) RealType(); } return iterator(this, total_size_ - n, total_size_ - n, total_size_); } @@ -358,27 +360,27 @@ public: size_t bid = buffer_id(); size_t iid = item_id(); if (iid + n < buffer_size) { - T *ptr = &(*data_[bid])[iid]; + RealType *ptr = &(*data_[bid])[iid]; for (size_t i = 0; i < n; ++i) { - new (ptr + i) T(); + new (ptr + i) RealType(); } total_size_ += n; return iterator(this, total_size_ - n, total_size_ - n, total_size_); } else { size_t allocs_in_curr_buffer = buffer_size - iid; size_t allocs_in_next_buffer = n - (buffer_size - iid); - T *ptr = &(*data_[bid])[iid]; + RealType *ptr = &(*data_[bid])[iid]; for (size_t i = 0; i < allocs_in_curr_buffer; ++i) { - new (ptr + i) T(); + new (ptr + i) RealType(); } allocate_buffer(); bid = data_.size() - 1; iid = 0; total_size_ += n; - T *ptr2 = &(*data_[bid])[iid]; + RealType *ptr2 = &(*data_[bid])[iid]; for (size_t i = 0; i < allocs_in_next_buffer; i++) { - new (ptr2 + i) T(); + new (ptr2 + i) RealType(); } return iterator(this, total_size_ - n, total_size_ - n, total_size_); } @@ -405,19 +407,19 @@ public: /** @brief searches for item in allocator * only used to find gradient nodes for propagation */ - iterator find(const T *const item) + iterator find(const RealType *const item) { - return std::find_if(begin(), end(), [&](const T &val) { return &val == item; }); + return std::find_if(begin(), end(), [&](const RealType &val) { return &val == item; }); } /** @brief vector like access, * currently unused anywhere but very useful for debugging */ - T &operator[](std::size_t i) + RealType &operator[](std::size_t i) { BOOST_MATH_ASSERT(i < total_size_); return (*data_[i / buffer_size])[i % buffer_size]; } - const T &operator[](std::size_t i) const + const RealType &operator[](std::size_t i) const { BOOST_MATH_ASSERT(i < total_size_); return (*data_[i / buffer_size])[i % buffer_size]; diff --git a/include/boost/math/differentiation/detail/reverse_mode_autodiff_stl_overloads.hpp b/include/boost/math/differentiation/detail/reverse_mode_autodiff_stl_overloads.hpp index 1c9c4a154..aa8716d4a 100644 --- a/include/boost/math/differentiation/detail/reverse_mode_autodiff_stl_overloads.hpp +++ b/include/boost/math/differentiation/detail/reverse_mode_autodiff_stl_overloads.hpp @@ -15,8 +15,8 @@ namespace boost { namespace math { namespace differentiation { namespace reverse_mode { -template -struct fabs_expr : public abstract_unary_expression> +template +struct fabs_expr : public abstract_unary_expression> { /** @brief * |x| @@ -27,10 +27,13 @@ struct fabs_expr : public abstract_unary_expression; + using inner_t = rvar_t; - explicit fabs_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + explicit fabs_expr(const expression &arg_expr, const RealType &v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const { @@ -39,14 +42,14 @@ struct fabs_expr : public abstract_unary_expression 0.0 ? inner_t{1.0} : inner_t{-1.0}; } }; -template -struct ceil_expr : public abstract_unary_expression> +template +struct ceil_expr : public abstract_unary_expression> { /** @brief ceil(1.11) = 2.0 * @@ -56,10 +59,13 @@ struct ceil_expr : public abstract_unary_expression; + using inner_t = rvar_t; - explicit ceil_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + explicit ceil_expr(const expression &arg_expr, const RealType &v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const { @@ -68,14 +74,14 @@ struct ceil_expr : public abstract_unary_expression -struct floor_expr : public abstract_unary_expression> +template +struct floor_expr : public abstract_unary_expression> { /** @brief floor(1.11) = 1.0, floor(-1.11) = 2 * @@ -85,10 +91,14 @@ struct floor_expr : public abstract_unary_expression; + using inner_t = rvar_t; - explicit floor_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + explicit floor_expr(const expression &arg_expr, + const RealType &v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const { @@ -97,14 +107,14 @@ struct floor_expr : public abstract_unary_expression -struct trunc_expr : public abstract_unary_expression> +template +struct trunc_expr : public abstract_unary_expression> { /** @brief trunc(1.11) = 1.0, trunc(-1.11) = -1.0 * @@ -114,10 +124,14 @@ struct trunc_expr : public abstract_unary_expression; + using inner_t = rvar_t; - explicit trunc_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + explicit trunc_expr(const expression &arg_expr, + const RealType &v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const { @@ -126,24 +140,27 @@ struct trunc_expr : public abstract_unary_expression -struct exp_expr : public abstract_unary_expression> +template +struct exp_expr : public abstract_unary_expression> { /** @brief exp(x) * * d/dx exp(x) = exp(x) * * */ - using inner_t = rvar_t; + using inner_t = rvar_t; - explicit exp_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + explicit exp_expr(const expression &arg_expr, const RealType &v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const { @@ -152,26 +169,26 @@ struct exp_expr : public abstract_unary_expression +template struct pow_expr - : public abstract_binary_expression> + : public abstract_binary_expression> { /** @brief pow(x,y) * d/dx pow(x,y) = y pow (x, y-1) * d/dy pow(x,y) = pow(x,y) log(x) * */ - using inner_t = rvar_t; + using inner_t = rvar_t; // Explicitly define constructor to forward to base class - explicit pow_expr(const expression &left_hand_expr, - const expression &right_hand_expr) - : abstract_binary_expression>( + explicit pow_expr(const expression &left_hand_expr, + const expression &right_hand_expr) + : abstract_binary_expression>( left_hand_expr, right_hand_expr) {} @@ -183,7 +200,7 @@ struct pow_expr static const inner_t left_derivative(const inner_t &l, const inner_t &r, const inner_t & /*v*/) { BOOST_MATH_STD_USING - return r * pow(l, r - static_cast(1.0)); + return r * pow(l, r - static_cast(1.0)); }; static const inner_t right_derivative(const inner_t &l, const inner_t &r, const inner_t & /*v*/) { @@ -192,64 +209,75 @@ struct pow_expr }; }; -template +template struct expr_pow_float_expr - : public abstract_unary_expression> + : public abstract_unary_expression> { /** @brief pow(rvar,float) */ - using inner_t = rvar_t; + using inner_t = rvar_t; - explicit expr_pow_float_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, - v){}; + explicit expr_pow_float_expr(const expression &arg_expr, + const RealType &v) + : abstract_unary_expression>(arg_expr, + v){}; inner_t evaluate() const { BOOST_MATH_STD_USING return pow(this->arg.evaluate(), this->constant); } - static const inner_t derivative(const inner_t &argv, const inner_t & /*v*/, const T &constant) + static const inner_t derivative(const inner_t &argv, const inner_t & /*v*/, const RealType &constant) { BOOST_MATH_STD_USING return inner_t{constant} * pow(argv, inner_t{constant - 1}); } }; -template +template struct float_pow_expr_expr - : public abstract_unary_expression> + : public abstract_unary_expression> { /** @brief pow(float, rvar) * */ - using inner_t = rvar_t; + using inner_t = rvar_t; - explicit float_pow_expr_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, - v){}; + explicit float_pow_expr_expr(const expression &arg_expr, + const RealType &v) + : abstract_unary_expression>(arg_expr, + v){}; inner_t evaluate() const { BOOST_MATH_STD_USING return pow(this->constant, this->arg.evaluate()); } - static const inner_t derivative(const inner_t &argv, const inner_t & /*v*/, const T &constant) + static const inner_t derivative(const inner_t &argv, const inner_t & /*v*/, const RealType &constant) { BOOST_MATH_STD_USING return pow(constant, argv) * log(constant); } }; -template -struct sqrt_expr : public abstract_unary_expression> +template +struct sqrt_expr : public abstract_unary_expression> { /** @brief sqrt(x) * d/dx sqrt(x) = 1/(2 sqrt(x)) * */ - using inner_t = rvar_t; + using inner_t = rvar_t; - explicit sqrt_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + explicit sqrt_expr(const expression &arg_expr, const RealType &v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const { @@ -258,23 +286,26 @@ struct sqrt_expr : public abstract_unary_expression(1.0) / (static_cast(2.0) * sqrt(argv)); + return static_cast(1.0) / (static_cast(2.0) * sqrt(argv)); } }; -template -struct log_expr : public abstract_unary_expression> +template +struct log_expr : public abstract_unary_expression> { /** @brief log(x) * d/dx log(x) = 1/x * */ - using inner_t = rvar_t; + using inner_t = rvar_t; - explicit log_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + explicit log_expr(const expression &arg_expr, const RealType &v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const { @@ -283,22 +314,25 @@ struct log_expr : public abstract_unary_expression(1.0) / argv; + return static_cast(1.0) / argv; } }; -template -struct cos_expr : public abstract_unary_expression> +template +struct cos_expr : public abstract_unary_expression> { /** @brief cos(x) * d/dx cos(x) = -sin(x) * */ - using inner_t = rvar_t; + using inner_t = rvar_t; - explicit cos_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + explicit cos_expr(const expression &arg_expr, const RealType &v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const { @@ -307,25 +341,28 @@ struct cos_expr : public abstract_unary_expression -struct sin_expr : public abstract_unary_expression> +template +struct sin_expr : public abstract_unary_expression> { /** @brief sin(x) * d/dx sin(x) = cos(x) * */ using arg_type = ARG; - using value_type = T; - using inner_t = rvar_t; + using value_type = RealType; + using inner_t = rvar_t; - explicit sin_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + explicit sin_expr(const expression &arg_expr, const RealType &v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const { @@ -334,23 +371,23 @@ struct sin_expr : public abstract_unary_expression -struct tan_expr : public abstract_unary_expression> +template +struct tan_expr : public abstract_unary_expression> { /** @brief tan(x) * d/dx tan(x) = 1/cos^2(x) * */ - using inner_t = rvar_t; + using inner_t = rvar_t; - explicit tan_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + explicit tan_expr(const expression &arg_expr, const RealType &v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const { @@ -359,23 +396,23 @@ struct tan_expr : public abstract_unary_expression(1.0) / (cos(argv) * cos(argv)); + return static_cast(1.0) / (cos(argv) * cos(argv)); } }; -template -struct acos_expr : public abstract_unary_expression> +template +struct acos_expr : public abstract_unary_expression> { /** @brief acos(x) * d/dx acos(x) = -1/sqrt(1-x^2) * */ - using inner_t = rvar_t; + using inner_t = rvar_t; - explicit acos_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + explicit acos_expr(const expression &arg_expr, const RealType &v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const { @@ -384,25 +421,25 @@ struct acos_expr : public abstract_unary_expression(-1.0) / sqrt(static_cast(1.0) - argv * argv); + return static_cast(-1.0) / sqrt(static_cast(1.0) - argv * argv); } }; -template -struct asin_expr : public abstract_unary_expression> +template +struct asin_expr : public abstract_unary_expression> { /** @brief asin(x) * d/dx asin = 1/sqrt(1-x^2) * */ using arg_type = ARG; - using value_type = T; - using inner_t = rvar_t; + using value_type = RealType; + using inner_t = rvar_t; - explicit asin_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + explicit asin_expr(const expression &arg_expr, const RealType &v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const { @@ -411,23 +448,23 @@ struct asin_expr : public abstract_unary_expression(1.0) / sqrt(static_cast(1.0) - argv * argv); + return static_cast(1.0) / sqrt(static_cast(1.0) - argv * argv); } }; -template -struct atan_expr : public abstract_unary_expression> +template +struct atan_expr : public abstract_unary_expression> { /** @brief atan(x) * d/dx atan(x) = 1/x^2+1 * */ - using inner_t = rvar_t; + using inner_t = rvar_t; - explicit atan_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + explicit atan_expr(const expression &arg_expr, const RealType &v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const { @@ -436,23 +473,23 @@ struct atan_expr : public abstract_unary_expression(1.0) / (static_cast(1.0) + argv * argv); + return static_cast(1.0) / (static_cast(1.0) + argv * argv); } }; -template +template struct atan2_expr - : public abstract_binary_expression> + : public abstract_binary_expression> { /** @brief atan2(x,y) * */ - using inner_t = rvar_t; + using inner_t = rvar_t; // Explicitly define constructor to forward to base class - explicit atan2_expr(const expression &left_hand_expr, - const expression &right_hand_expr) - : abstract_binary_expression>( + explicit atan2_expr(const expression &left_hand_expr, + const expression &right_hand_expr) + : abstract_binary_expression>( left_hand_expr, right_hand_expr) {} @@ -471,16 +508,16 @@ struct atan2_expr }; }; -template +template struct atan2_left_float_expr - : public abstract_unary_expression> + : public abstract_unary_expression> { /** @brief atan2(float,rvar) * */ - using inner_t = rvar_t; + using inner_t = rvar_t; - explicit atan2_left_float_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, + explicit atan2_left_float_expr(const expression &arg_expr, const RealType &v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const @@ -488,22 +525,22 @@ struct atan2_left_float_expr BOOST_MATH_STD_USING return atan2(this->constant, this->arg.evaluate()); } - static const inner_t derivative(const inner_t &argv, const inner_t & /*v*/, const T &constant) + static const inner_t derivative(const inner_t &argv, const inner_t & /*v*/, const RealType &constant) { return -constant / (constant * constant + argv * argv); } }; -template +template struct atan2_right_float_expr - : public abstract_unary_expression> + : public abstract_unary_expression> { /** @brief atan2(rvar,float) * */ - using inner_t = rvar_t; + using inner_t = rvar_t; - explicit atan2_right_float_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, + explicit atan2_right_float_expr(const expression &arg_expr, const RealType &v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const @@ -511,22 +548,22 @@ struct atan2_right_float_expr BOOST_MATH_STD_USING return atan2(this->arg.evaluate(), this->constant); } - static const inner_t derivative(const inner_t &argv, const inner_t & /*v*/, const T &constant) + static const inner_t derivative(const inner_t &argv, const inner_t & /*v*/, const RealType &constant) { return constant / (constant * constant + argv * argv); } }; -template -struct round_expr : public abstract_unary_expression> +template +struct round_expr : public abstract_unary_expression> { /** @brief round(x) * d/dx round = 0 * */ - using inner_t = rvar_t; + using inner_t = rvar_t; - explicit round_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + explicit round_expr(const expression &arg_expr, const RealType &v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const { @@ -535,22 +572,22 @@ struct round_expr : public abstract_unary_expression -struct sinh_expr : public abstract_unary_expression> +template +struct sinh_expr : public abstract_unary_expression> { /** @brief sinh(x) * d/dx sinh(x) = cosh * */ - using inner_t = rvar_t; + using inner_t = rvar_t; - explicit sinh_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + explicit sinh_expr(const expression &arg_expr, const RealType &v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const { @@ -559,23 +596,23 @@ struct sinh_expr : public abstract_unary_expression -struct cosh_expr : public abstract_unary_expression> +template +struct cosh_expr : public abstract_unary_expression> { /** @brief cosh(x) * d/dx cosh(x) = sinh * */ - using inner_t = rvar_t; + using inner_t = rvar_t; - explicit cosh_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + explicit cosh_expr(const expression &arg_expr, const RealType &v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const { @@ -584,22 +621,22 @@ struct cosh_expr : public abstract_unary_expression -struct tanh_expr : public abstract_unary_expression> +template +struct tanh_expr : public abstract_unary_expression> { /** @brief tanh(x) * d/dx tanh(x) = 1/cosh^2 * */ - using inner_t = rvar_t; + using inner_t = rvar_t; - explicit tanh_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + explicit tanh_expr(const expression &arg_expr, const RealType &v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const { @@ -608,23 +645,23 @@ struct tanh_expr : public abstract_unary_expression(1.0) / (cosh(argv) * cosh(argv)); + return static_cast(1.0) / (cosh(argv) * cosh(argv)); } }; -template -struct log10_expr : public abstract_unary_expression> +template +struct log10_expr : public abstract_unary_expression> { /** @brief log10(x) * d/dx log10(x) = 1/(x * log(10)) * */ - using inner_t = rvar_t; + using inner_t = rvar_t; - explicit log10_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + explicit log10_expr(const expression &arg_expr, const RealType &v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const { @@ -633,23 +670,23 @@ struct log10_expr : public abstract_unary_expression(1.0) / (argv * log(static_cast(10.0))); + return static_cast(1.0) / (argv * log(static_cast(10.0))); } }; -template -struct acosh_expr : public abstract_unary_expression> +template +struct acosh_expr : public abstract_unary_expression> { /** @brief acosh(x) * d/dx acosh(x) = 1/(sqrt(x-1)sqrt(x+1) * */ - using inner_t = rvar_t; + using inner_t = rvar_t; - explicit acosh_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + explicit acosh_expr(const expression &arg_expr, const RealType &v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const { @@ -658,24 +695,24 @@ struct acosh_expr : public abstract_unary_expression(1.0) - / (sqrt(argv - static_cast(1.0)) * sqrt(argv + static_cast(1.0))); + return static_cast(1.0) + / (sqrt(argv - static_cast(1.0)) * sqrt(argv + static_cast(1.0))); } }; -template -struct asinh_expr : public abstract_unary_expression> +template +struct asinh_expr : public abstract_unary_expression> { /** @brief asinh(x) * d/dx asinh(x) = 1/(sqrt(1+x^2)) * */ - using inner_t = rvar_t; + using inner_t = rvar_t; - explicit asinh_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + explicit asinh_expr(const expression &arg_expr, const RealType &v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const { @@ -684,23 +721,23 @@ struct asinh_expr : public abstract_unary_expression(1.0) / (sqrt(static_cast(1.0) + argv * argv)); + return static_cast(1.0) / (sqrt(static_cast(1.0) + argv * argv)); } }; -template -struct atanh_expr : public abstract_unary_expression> +template +struct atanh_expr : public abstract_unary_expression> { /** @brief atanh(x) * d/dx atanh(x) = 1/(1-x^2) * */ - using inner_t = rvar_t; + using inner_t = rvar_t; - explicit atanh_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, v){}; + explicit atanh_expr(const expression &arg_expr, const RealType &v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const { @@ -709,23 +746,23 @@ struct atanh_expr : public abstract_unary_expression(1.0) / (static_cast(1.0) - argv * argv); + return static_cast(1.0) / (static_cast(1.0) - argv * argv); } }; -template +template struct fmod_expr - : public abstract_binary_expression> + : public abstract_binary_expression> { /** @brief * */ - using inner_t = rvar_t; + using inner_t = rvar_t; // Explicitly define constructor to forward to base class - explicit fmod_expr(const expression &left_hand_expr, - const expression &right_hand_expr) - : abstract_binary_expression>( + explicit fmod_expr(const expression &left_hand_expr, + const expression &right_hand_expr) + : abstract_binary_expression>( left_hand_expr, right_hand_expr) {} @@ -743,20 +780,20 @@ struct fmod_expr static const inner_t right_derivative(const inner_t &l, const inner_t &r, const inner_t & /*v*/) { BOOST_MATH_STD_USING - return static_cast(-1.0) * trunc(l / r); + return static_cast(-1.0) * trunc(l / r); }; }; -template +template struct fmod_left_float_expr - : public abstract_unary_expression> + : public abstract_unary_expression> { /** @brief * */ - using inner_t = rvar_t; + using inner_t = rvar_t; - explicit fmod_left_float_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, + explicit fmod_left_float_expr(const expression &arg_expr, const RealType &v) + : abstract_unary_expression>(arg_expr, v){}; inner_t evaluate() const @@ -764,23 +801,27 @@ struct fmod_left_float_expr BOOST_MATH_STD_USING return fmod(this->constant, this->arg.evaluate()); } - static const inner_t derivative(const inner_t &argv, const inner_t & /*v*/, const T &constant) + static const inner_t derivative(const inner_t &argv, const inner_t & /*v*/, const RealType &constant) { - return static_cast(-1.0) * trunc(constant / argv); + return static_cast(-1.0) * trunc(constant / argv); } }; -template +template struct fmod_right_float_expr - : public abstract_unary_expression> + : public abstract_unary_expression> { /** @brief * */ - using inner_t = rvar_t; + using inner_t = rvar_t; - explicit fmod_right_float_expr(const expression &arg_expr, const T v) - : abstract_unary_expression>(arg_expr, - v){}; + explicit fmod_right_float_expr(const expression &arg_expr, + const RealType &v) + : abstract_unary_expression>(arg_expr, + v){}; inner_t evaluate() const { @@ -789,253 +830,254 @@ struct fmod_right_float_expr } static const inner_t derivative(const inner_t & /*argv*/, const inner_t & /*v*/, - const T & /*constant*/) + const RealType & /*constant*/) { return inner_t{1.0}; } }; /**************************************************************************************************/ -template -fabs_expr fabs(const expression &arg) +template +fabs_expr fabs(const expression &arg) { - return fabs_expr(arg, static_cast(0.0)); + return fabs_expr(arg, static_cast(0.0)); } -template -auto abs(const expression &arg) +template +auto abs(const expression &arg) { return fabs(arg); } -template -ceil_expr ceil(const expression &arg) +template +ceil_expr ceil(const expression &arg) { - return ceil_expr(arg, static_cast(0.0)); + return ceil_expr(arg, static_cast(0.0)); } -template -floor_expr floor(const expression &arg) +template +floor_expr floor(const expression &arg) { - return floor_expr(arg, static_cast(0.0)); + return floor_expr(arg, static_cast(0.0)); } -template -exp_expr exp(const expression &arg) +template +exp_expr exp(const expression &arg) { - return exp_expr(arg, static_cast(0.0)); + return exp_expr(arg, static_cast(0.0)); } -template -pow_expr pow(const expression &lhs, - const expression &rhs) +template +pow_expr pow(const expression &lhs, + const expression &rhs) { - return pow_expr(lhs, rhs); + return pow_expr(lhs, rhs); } -template::value>::type> -expr_pow_float_expr pow(const expression &arg, const U &v) + typename = typename std::enable_if::value>::type> +expr_pow_float_expr pow( + const expression &arg, const RealType2 &v) { - return expr_pow_float_expr(arg, static_cast(v)); + return expr_pow_float_expr(arg, static_cast(v)); }; -template -float_pow_expr_expr pow(const T &v, const expression &arg) +template +float_pow_expr_expr pow(const RealType &v, const expression &arg) { - return float_pow_expr_expr(arg, v); + return float_pow_expr_expr(arg, v); }; -template -log_expr log(const expression &arg) +template +log_expr log(const expression &arg) { - return log_expr(arg, static_cast(0.0)); + return log_expr(arg, static_cast(0.0)); }; -template -sqrt_expr sqrt(const expression &arg) +template +sqrt_expr sqrt(const expression &arg) { - return sqrt_expr(arg, static_cast(0.0)); + return sqrt_expr(arg, static_cast(0.0)); }; -template -auto frexp(const expression &arg, int *i) +template +auto frexp(const expression &arg, int *i) { BOOST_MATH_STD_USING frexp(arg.evaluate(), i); - return arg / pow(static_cast(2.0), *i); + return arg / pow(static_cast(2.0), *i); } -template -auto ldexp(const expression &arg, const int &i) +template +auto ldexp(const expression &arg, const int &i) { BOOST_MATH_STD_USING - return arg * pow(static_cast(2.0), i); + return arg * pow(static_cast(2.0), i); } -template -cos_expr cos(const expression &arg) +template +cos_expr cos(const expression &arg) { - return cos_expr(arg, static_cast(0.0)); + return cos_expr(arg, static_cast(0.0)); }; -template -sin_expr sin(const expression &arg) +template +sin_expr sin(const expression &arg) { - return sin_expr(arg, static_cast(0.0)); + return sin_expr(arg, static_cast(0.0)); }; -template -tan_expr tan(const expression &arg) +template +tan_expr tan(const expression &arg) { - return tan_expr(arg, static_cast(0.0)); + return tan_expr(arg, static_cast(0.0)); }; -template -acos_expr acos(const expression &arg) +template +acos_expr acos(const expression &arg) { - return acos_expr(arg, static_cast(0.0)); + return acos_expr(arg, static_cast(0.0)); }; -template -asin_expr asin(const expression &arg) +template +asin_expr asin(const expression &arg) { - return asin_expr(arg, static_cast(0.0)); + return asin_expr(arg, static_cast(0.0)); }; -template -atan_expr atan(const expression &arg) +template +atan_expr atan(const expression &arg) { - return atan_expr(arg, static_cast(0.0)); + return atan_expr(arg, static_cast(0.0)); }; -template -atan2_expr atan2(const expression &lhs, - const expression &rhs) +template +atan2_expr atan2(const expression &lhs, + const expression &rhs) { - return atan2_expr(lhs, rhs); + return atan2_expr(lhs, rhs); } -template -atan2_right_float_expr atan2(const expression &arg, const T &v) +template +atan2_right_float_expr atan2(const expression &arg, const RealType &v) { - return atan2_right_float_expr(arg, v); + return atan2_right_float_expr(arg, v); }; -template -atan2_left_float_expr atan2(const T &v, const expression &arg) +template +atan2_left_float_expr atan2(const RealType &v, const expression &arg) { - return atan2_left_float_expr(arg, v); + return atan2_left_float_expr(arg, v); }; -template -trunc_expr trunc(const expression &arg) +template +trunc_expr trunc(const expression &arg) { - return trunc_expr(arg, static_cast(0.0)); + return trunc_expr(arg, static_cast(0.0)); } -template -auto fmod(const expression &lhs, const expression &rhs) +template +auto fmod(const expression &lhs, const expression &rhs) { - return fmod_expr(lhs, rhs); + return fmod_expr(lhs, rhs); } -template -auto fmod(const expression &lhs, const T rhs) +template +auto fmod(const expression &lhs, const RealType rhs) { - return fmod_right_float_expr(lhs, rhs); + return fmod_right_float_expr(lhs, rhs); } -template -auto fmod(const T lhs, const expression &rhs) +template +auto fmod(const RealType lhs, const expression &rhs) { - return fmod_left_float_expr(rhs, lhs); + return fmod_left_float_expr(rhs, lhs); } -template -round_expr round(const expression &arg) +template +round_expr round(const expression &arg) { - return round_expr(arg, static_cast(0.0)); + return round_expr(arg, static_cast(0.0)); } -template -int iround(const expression &arg) +template +int iround(const expression &arg) { - rvar tmp = arg.evaluate(); + rvar tmp = arg.evaluate(); return iround(tmp.item()); } -template -long lround(const expression &arg) +template +long lround(const expression &arg) { BOOST_MATH_STD_USING - rvar tmp = arg.evaluate(); + rvar tmp = arg.evaluate(); return lround(tmp.item()); } -template -long long llround(const expression &arg) +template +long long llround(const expression &arg) { - rvar tmp = arg.evaluate(); + rvar tmp = arg.evaluate(); return llround(tmp.item()); } -template -int itrunc(const expression &arg) +template +int itrunc(const expression &arg) { - rvar tmp = arg.evaluate(); + rvar tmp = arg.evaluate(); return itrunc(tmp.item()); } -template -long ltrunc(const expression &arg) +template +long ltrunc(const expression &arg) { - rvar tmp = arg.evaluate(); + rvar tmp = arg.evaluate(); return ltrunc(tmp.item()); } -template -long long lltrunc(const expression &arg) +template +long long lltrunc(const expression &arg) { - rvar tmp = arg.evaluate(); + rvar tmp = arg.evaluate(); return lltrunc(tmp.item()); } -template -sinh_expr sinh(const expression &arg) +template +sinh_expr sinh(const expression &arg) { - return sinh_expr(arg, static_cast(0.0)); + return sinh_expr(arg, static_cast(0.0)); } -template -cosh_expr cosh(const expression &arg) +template +cosh_expr cosh(const expression &arg) { - return cosh_expr(arg, static_cast(0.0)); + return cosh_expr(arg, static_cast(0.0)); } -template -tanh_expr tanh(const expression &arg) +template +tanh_expr tanh(const expression &arg) { - return tanh_expr(arg, static_cast(0.0)); + return tanh_expr(arg, static_cast(0.0)); } -template -log10_expr log10(const expression &arg) +template +log10_expr log10(const expression &arg) { - return log10_expr(arg, static_cast(0.0)); + return log10_expr(arg, static_cast(0.0)); } -template -asinh_expr asinh(const expression &arg) +template +asinh_expr asinh(const expression &arg) { - return asinh_expr(arg, static_cast(0.0)); + return asinh_expr(arg, static_cast(0.0)); } -template -acosh_expr acosh(const expression &arg) +template +acosh_expr acosh(const expression &arg) { - return acosh_expr(arg, static_cast(0.0)); + return acosh_expr(arg, static_cast(0.0)); } -template -atanh_expr atanh(const expression &arg) +template +atanh_expr atanh(const expression &arg) { - return atanh_expr(arg, static_cast(0.0)); + return atanh_expr(arg, static_cast(0.0)); } } // namespace reverse_mode } // namespace differentiation