2
0
mirror of https://github.com/boostorg/math.git synced 2026-01-19 04:22:09 +00:00

modified template variable names, changed copies to references

This commit is contained in:
mzhelyez
2025-08-21 11:26:32 +02:00
parent 602f227398
commit b53cfc25c9
7 changed files with 1049 additions and 862 deletions

View File

@@ -35,40 +35,46 @@ namespace differentiation {
namespace reverse_mode {
/* forward declarations for utility functions */
template<typename T, size_t order, class derived_expression>
template<typename RealType, size_t DerivativeOrder, class DerivedExpression>
struct expression;
template<typename T, size_t order>
template<typename RealType, size_t DerivativeOrder>
class rvar;
template<typename T, size_t order, typename LHS, typename RHS, typename concrete_binary_operation>
template<typename RealType,
size_t DerivativeOrder,
typename LHS,
typename RHS,
typename ConcreteBinaryOperation>
struct abstract_binary_expression;
template<typename T, size_t order, typename ARG, typename concrete_unary_operation>
template<typename RealType, size_t DerivativeOrder, typename ARG, typename ConcreteBinaryOperation>
struct abstract_unary_expression;
template<typename T, size_t order>
template<typename RealType, size_t DerivativeOrder>
class gradient_node; // forward declaration for tape
// manages nodes in computational graph
template<typename T, size_t order, size_t buffer_size = BOOST_MATH_BUFFER_SIZE>
template<typename RealType, size_t DerivativeOrder, size_t buffer_size = BOOST_MATH_BUFFER_SIZE>
class gradient_tape
{
/** @brief tape (graph) management class for autodiff
* holds all the data structures for autodiff */
private:
/* type decays to order - 1 to support higher order derivatives */
using inner_t = rvar_t<T, order - 1>;
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
/* adjoints are the overall derivative, and derivatives are the "local"
* derivative */
detail::flat_linear_allocator<inner_t, buffer_size> adjoints_;
detail::flat_linear_allocator<inner_t, buffer_size> derivatives_;
detail::flat_linear_allocator<gradient_node<T, order>, buffer_size> gradient_nodes_;
detail::flat_linear_allocator<gradient_node<T, order> *, buffer_size> argument_nodes_;
detail::flat_linear_allocator<gradient_node<RealType, DerivativeOrder>, buffer_size>
gradient_nodes_;
detail::flat_linear_allocator<gradient_node<RealType, DerivativeOrder> *, buffer_size>
argument_nodes_;
// compile time check if emplace_back calls on zero
template<size_t n>
gradient_node<T, order> *fill_node_at_compile_time(std::true_type,
gradient_node<T, order> *node_ptr)
gradient_node<RealType, DerivativeOrder> *fill_node_at_compile_time(
std::true_type, gradient_node<RealType, DerivativeOrder> *node_ptr)
{
node_ptr->derivatives_ = derivatives_.template emplace_back_n<n>();
node_ptr->argument_nodes_ = argument_nodes_.template emplace_back_n<n>();
@@ -76,8 +82,8 @@ private:
}
template<size_t n>
gradient_node<T, order> *fill_node_at_compile_time(std::false_type,
gradient_node<T, order> *node_ptr)
gradient_node<RealType, DerivativeOrder> *fill_node_at_compile_time(
std::false_type, gradient_node<RealType, DerivativeOrder> *node_ptr)
{
node_ptr->derivatives_ = nullptr;
node_ptr->argument_adjoints_ = nullptr;
@@ -94,9 +100,11 @@ public:
using derivatives_iterator =
typename detail::flat_linear_allocator<inner_t, buffer_size>::iterator;
using gradient_nodes_iterator =
typename detail::flat_linear_allocator<gradient_node<T, order>, buffer_size>::iterator;
typename detail::flat_linear_allocator<gradient_node<RealType, DerivativeOrder>,
buffer_size>::iterator;
using argument_nodes_iterator =
typename detail::flat_linear_allocator<gradient_node<T, order> *, buffer_size>::iterator;
typename detail::flat_linear_allocator<gradient_node<RealType, DerivativeOrder> *,
buffer_size>::iterator;
gradient_tape() { clear(); };
@@ -114,9 +122,9 @@ public:
}
// no derivatives or arguments
gradient_node<T, order> *emplace_leaf_node()
gradient_node<RealType, DerivativeOrder> *emplace_leaf_node()
{
gradient_node<T, order> *node = &*gradient_nodes_.emplace_back();
gradient_node<RealType, DerivativeOrder> *node = &*gradient_nodes_.emplace_back();
node->adjoint_ = adjoints_.emplace_back();
node->derivatives_ = derivatives_iterator(); // nullptr;
node->argument_nodes_ = argument_nodes_iterator(); // nullptr;
@@ -125,9 +133,9 @@ public:
};
// single argument, single derivative
gradient_node<T, order> *emplace_active_unary_node()
gradient_node<RealType, DerivativeOrder> *emplace_active_unary_node()
{
gradient_node<T, order> *node = &*gradient_nodes_.emplace_back();
gradient_node<RealType, DerivativeOrder> *node = &*gradient_nodes_.emplace_back();
node->n_ = 1;
node->adjoint_ = adjoints_.emplace_back();
node->derivatives_ = derivatives_.emplace_back();
@@ -137,9 +145,9 @@ public:
// arbitrary number of arguments/derivatives (compile time)
template<size_t n>
gradient_node<T, order> *emplace_active_multi_node()
gradient_node<RealType, DerivativeOrder> *emplace_active_multi_node()
{
gradient_node<T, order> *node = &*gradient_nodes_.emplace_back();
gradient_node<RealType, DerivativeOrder> *node = &*gradient_nodes_.emplace_back();
node->n_ = n;
node->adjoint_ = adjoints_.emplace_back();
// emulate if constexpr
@@ -147,9 +155,9 @@ public:
}
// same as above at runtime
gradient_node<T, order> *emplace_active_multi_node(size_t n)
gradient_node<RealType, DerivativeOrder> *emplace_active_multi_node(size_t n)
{
gradient_node<T, order> *node = &*gradient_nodes_.emplace_back();
gradient_node<RealType, DerivativeOrder> *node = &*gradient_nodes_.emplace_back();
node->n_ = n;
node->adjoint_ = adjoints_.emplace_back();
if (n > 0) {
@@ -161,14 +169,17 @@ public:
/* manual reset button for all adjoints */
void zero_grad()
{
const T zero = T(0.0);
const RealType zero = RealType(0.0);
adjoints_.fill(zero);
}
// return type is an iterator
auto begin() { return gradient_nodes_.begin(); }
auto end() { return gradient_nodes_.end(); }
auto find(gradient_node<T, order> *node) { return gradient_nodes_.find(node); };
auto find(gradient_node<RealType, DerivativeOrder> *node)
{
return gradient_nodes_.find(node);
};
void add_checkpoint()
{
gradient_nodes_.add_checkpoint();
@@ -206,11 +217,14 @@ public:
}
// random access
gradient_node<T, order> &operator[](size_t i) { return gradient_nodes_[i]; }
const gradient_node<T, order> &operator[](size_t i) const { return gradient_nodes_[i]; }
gradient_node<RealType, DerivativeOrder> &operator[](size_t i) { return gradient_nodes_[i]; }
const gradient_node<RealType, DerivativeOrder> &operator[](size_t i) const
{
return gradient_nodes_[i];
}
};
// class rvar;
template<typename T, size_t order> // no CRTP, just storage
template<typename RealType, size_t DerivativeOrder> // no CRTP, just storage
class gradient_node
{
/*
@@ -218,13 +232,15 @@ class gradient_node
* adjoints pointers to arguments aren't needed here
* */
public:
using adjoint_iterator = typename gradient_tape<T, order>::adjoint_iterator;
using derivatives_iterator = typename gradient_tape<T, order>::derivatives_iterator;
using argument_nodes_iterator = typename gradient_tape<T, order>::argument_nodes_iterator;
using adjoint_iterator = typename gradient_tape<RealType, DerivativeOrder>::adjoint_iterator;
using derivatives_iterator =
typename gradient_tape<RealType, DerivativeOrder>::derivatives_iterator;
using argument_nodes_iterator =
typename gradient_tape<RealType, DerivativeOrder>::argument_nodes_iterator;
private:
size_t n_;
using inner_t = rvar_t<T, order - 1>;
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
/* these are iterators in case
* flat linear allocator is at capacity, and needs to allocate a new block of
* memory. */
@@ -233,8 +249,8 @@ private:
argument_nodes_iterator argument_nodes_;
public:
friend class gradient_tape<T, order>;
friend class rvar<T, order>;
friend class gradient_tape<RealType, DerivativeOrder>;
friend class rvar<RealType, DerivativeOrder>;
gradient_node() = default;
explicit gradient_node(const size_t n)
@@ -242,7 +258,10 @@ public:
, adjoint_(nullptr)
, derivatives_(nullptr)
{}
explicit gradient_node(const size_t n, T *adjoint, T *derivatives, rvar<T, order> **arguments)
explicit gradient_node(const size_t n,
RealType *adjoint,
RealType *derivatives,
rvar<RealType, DerivativeOrder> **arguments)
: n_(n)
, adjoint_(adjoint)
, derivatives_(derivatives)
@@ -263,7 +282,7 @@ public:
{
argument_nodes_[static_cast<ptrdiff_t>(arg_id)]->update_adjoint_v(value);
};
void update_argument_ptr_at(size_t arg_id, gradient_node<T, order> *node_ptr)
void update_argument_ptr_at(size_t arg_id, gradient_node<RealType, DerivativeOrder> *node_ptr)
{
argument_nodes_[static_cast<ptrdiff_t>(arg_id)] = node_ptr;
}
@@ -275,7 +294,7 @@ public:
using boost::math::differentiation::reverse_mode::fabs;
using std::fabs;
if (!adjoint_ || fabs(*adjoint_) < 2 * std::numeric_limits<T>::epsilon())
if (!adjoint_ || fabs(*adjoint_) < 2 * std::numeric_limits<RealType>::epsilon())
return;
if (!argument_nodes_) // no arguments
@@ -294,21 +313,22 @@ public:
};
/****************************************************************************************************************/
template<typename T, size_t order>
inline gradient_tape<T, order, BOOST_MATH_BUFFER_SIZE> &get_active_tape()
template<typename RealType, size_t DerivativeOrder>
inline gradient_tape<RealType, DerivativeOrder, BOOST_MATH_BUFFER_SIZE> &get_active_tape()
{
static BOOST_MATH_THREAD_LOCAL gradient_tape<T, order, BOOST_MATH_BUFFER_SIZE> tape;
static BOOST_MATH_THREAD_LOCAL gradient_tape<RealType, DerivativeOrder, BOOST_MATH_BUFFER_SIZE>
tape;
return tape;
}
template<typename T, size_t order = 1>
class rvar : public expression<T, order, rvar<T, order>>
template<typename RealType, size_t DerivativeOrder = 1>
class rvar : public expression<RealType, DerivativeOrder, rvar<RealType, DerivativeOrder>>
{
private:
using inner_t = rvar_t<T, order - 1>;
friend class gradient_node<T, order>;
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
friend class gradient_node<RealType, DerivativeOrder>;
inner_t value_;
gradient_node<T, order> *node_ = nullptr;
gradient_node<RealType, DerivativeOrder> *node_ = nullptr;
template<typename, size_t>
friend class rvar;
/*****************************************************************************************/
@@ -322,13 +342,13 @@ private:
/** @return value_ at rvar_t<T,current_order - 1>
*/
static auto &get(rvar<T, current_order> &v)
static auto &get(rvar<RealType, current_order> &v)
{
return get_value_at_impl<target_order, current_order - 1>::get(v.get_value());
}
/** @return const value_ at rvar_t<T,current_order - 1>
*/
static const auto &get(const rvar<T, current_order> &v)
static const auto &get(const rvar<RealType, current_order> &v)
{
return get_value_at_impl<target_order, current_order - 1>::get(v.get_value());
}
@@ -341,65 +361,69 @@ private:
{
/** @return value_ at rvar_t<T,target_order>
*/
static auto &get(rvar<T, target_order> &v) { return v; }
static auto &get(rvar<RealType, target_order> &v) { return v; }
/** @return const value_ at rvar_t<T,target_order>
*/
static const auto &get(const rvar<T, target_order> &v) { return v; }
static const auto &get(const rvar<RealType, target_order> &v) { return v; }
};
/*****************************************************************************************/
void make_leaf_node()
{
gradient_tape<T, order, BOOST_MATH_BUFFER_SIZE> &tape = get_active_tape<T, order>();
gradient_tape<RealType, DerivativeOrder, BOOST_MATH_BUFFER_SIZE> &tape
= get_active_tape<RealType, DerivativeOrder>();
node_ = tape.emplace_leaf_node();
}
void make_unary_node()
{
gradient_tape<T, order, BOOST_MATH_BUFFER_SIZE> &tape = get_active_tape<T, order>();
gradient_tape<RealType, DerivativeOrder, BOOST_MATH_BUFFER_SIZE> &tape
= get_active_tape<RealType, DerivativeOrder>();
node_ = tape.emplace_active_unary_node();
}
void make_multi_node(size_t n)
{
gradient_tape<T, order, BOOST_MATH_BUFFER_SIZE> &tape = get_active_tape<T, order>();
gradient_tape<RealType, DerivativeOrder, BOOST_MATH_BUFFER_SIZE> &tape
= get_active_tape<RealType, DerivativeOrder>();
node_ = tape.emplace_active_multi_node(n);
}
template<size_t n>
void make_multi_node()
{
gradient_tape<T, order, BOOST_MATH_BUFFER_SIZE> &tape = get_active_tape<T, order>();
gradient_tape<RealType, DerivativeOrder, BOOST_MATH_BUFFER_SIZE> &tape
= get_active_tape<RealType, DerivativeOrder>();
node_ = tape.template emplace_active_multi_node<n>();
}
template<typename E>
void make_rvar_from_expr(const expression<T, order, E> &expr)
void make_rvar_from_expr(const expression<RealType, DerivativeOrder, E> &expr)
{
make_multi_node<detail::count_rvars<E, order>>();
make_multi_node<detail::count_rvars<E, DerivativeOrder>>();
expr.template propagatex<0>(node_, inner_t(1.0));
}
T get_item_impl(std::true_type) const
RealType get_item_impl(std::true_type) const
{
return value_.get_item_impl(std::integral_constant<bool, (order - 1 > 1)>{});
return value_.get_item_impl(std::integral_constant<bool, (DerivativeOrder - 1 > 1)>{});
}
T get_item_impl(std::false_type) const { return value_; }
RealType get_item_impl(std::false_type) const { return value_; }
public:
using value_type = T;
static constexpr size_t order_v = order;
using value_type = RealType;
static constexpr size_t DerivativeOrder_v = DerivativeOrder;
rvar()
: value_()
{
make_leaf_node();
}
rvar(const T value)
rvar(const RealType value)
: value_(inner_t{value})
{
make_leaf_node();
}
rvar &operator=(T v)
rvar &operator=(RealType v)
{
value_ = inner_t(v);
if (node_ == nullptr) {
@@ -407,24 +431,24 @@ public:
}
return *this;
}
rvar(const rvar<T, order> &other) = default;
rvar &operator=(const rvar<T, order> &other) = default;
rvar(const rvar<RealType, DerivativeOrder> &other) = default;
rvar &operator=(const rvar<RealType, DerivativeOrder> &other) = default;
template<size_t arg_index>
void propagatex(gradient_node<T, order> *node, inner_t adj) const
void propagatex(gradient_node<RealType, DerivativeOrder> *node, inner_t adj) const
{
node->update_derivative_v(arg_index, adj);
node->update_argument_ptr_at(arg_index, node_);
}
template<class E>
rvar(const expression<T, order, E> &expr)
rvar(const expression<RealType, DerivativeOrder, E> &expr)
{
value_ = expr.evaluate();
make_rvar_from_expr(expr);
}
template<class E>
rvar &operator=(const expression<T, order, E> &expr)
rvar &operator=(const expression<RealType, DerivativeOrder, E> &expr)
{
value_ = expr.evaluate();
make_rvar_from_expr(expr);
@@ -432,52 +456,52 @@ public:
}
/***************************************************************************************************/
template<class E>
rvar<T, order> &operator+=(const expression<T, order, E> &expr)
rvar<RealType, DerivativeOrder> &operator+=(const expression<RealType, DerivativeOrder, E> &expr)
{
*this = *this + expr;
return *this;
}
template<class E>
rvar<T, order> &operator*=(const expression<T, order, E> &expr)
rvar<RealType, DerivativeOrder> &operator*=(const expression<RealType, DerivativeOrder, E> &expr)
{
*this = *this * expr;
return *this;
}
template<class E>
rvar<T, order> &operator-=(const expression<T, order, E> &expr)
rvar<RealType, DerivativeOrder> &operator-=(const expression<RealType, DerivativeOrder, E> &expr)
{
*this = *this - expr;
return *this;
}
template<class E>
rvar<T, order> &operator/=(const expression<T, order, E> &expr)
rvar<RealType, DerivativeOrder> &operator/=(const expression<RealType, DerivativeOrder, E> &expr)
{
*this = *this / expr;
return *this;
}
/***************************************************************************************************/
rvar<T, order> &operator+=(const T &v)
rvar<RealType, DerivativeOrder> &operator+=(const RealType &v)
{
*this = *this + v;
return *this;
}
rvar<T, order> &operator*=(const T &v)
rvar<RealType, DerivativeOrder> &operator*=(const RealType &v)
{
*this = *this * v;
return *this;
}
rvar<T, order> &operator-=(const T &v)
rvar<RealType, DerivativeOrder> &operator-=(const RealType &v)
{
*this = *this - v;
return *this;
}
rvar<T, order> &operator/=(const T &v)
rvar<RealType, DerivativeOrder> &operator/=(const RealType &v)
{
*this = *this / v;
return *this;
@@ -490,7 +514,7 @@ public:
const inner_t &evaluate() const { return value_; };
inner_t &get_value() { return value_; };
explicit operator T() const { return item(); }
explicit operator RealType() const { return item(); }
explicit operator int() const { return static_cast<int>(item()); }
explicit operator long() const { return static_cast<long>(item()); }
@@ -503,23 +527,27 @@ public:
template<size_t N>
auto &get_value_at()
{
static_assert(N <= order, "Requested depth exceeds variable order.");
return get_value_at_impl<N, order>::get(*this);
static_assert(N <= DerivativeOrder, "Requested depth exceeds variable order.");
return get_value_at_impl<N, DerivativeOrder>::get(*this);
}
/** @brief same as above but const
*/
template<size_t N>
const auto &get_value_at() const
{
static_assert(N <= order, "Requested depth exceeds variable order.");
return get_value_at_impl<N, order>::get(*this);
static_assert(N <= DerivativeOrder, "Requested depth exceeds variable order.");
return get_value_at_impl<N, DerivativeOrder>::get(*this);
}
T item() const { return get_item_impl(std::integral_constant<bool, (order > 1)>{}); }
RealType item() const
{
return get_item_impl(std::integral_constant<bool, (DerivativeOrder > 1)>{});
}
void backward()
{
gradient_tape<T, order, BOOST_MATH_BUFFER_SIZE> &tape = get_active_tape<T, order>();
gradient_tape<RealType, DerivativeOrder, BOOST_MATH_BUFFER_SIZE> &tape
= get_active_tape<RealType, DerivativeOrder>();
auto it = tape.find(node_);
it->update_adjoint_v(inner_t(1.0));
while (it != tape.begin()) {
@@ -530,32 +558,32 @@ public:
}
};
template<typename T, size_t order>
std::ostream &operator<<(std::ostream &os, const rvar<T, order> var)
template<typename RealType, size_t DerivativeOrder>
std::ostream &operator<<(std::ostream &os, const rvar<RealType, DerivativeOrder> var)
{
os << "rvar<" << order << ">(" << var.item() << "," << var.adjoint() << ")";
os << "rvar<" << DerivativeOrder << ">(" << var.item() << "," << var.adjoint() << ")";
return os;
}
template<typename T, size_t order, typename E>
std::ostream &operator<<(std::ostream &os, const expression<T, order, E> &expr)
template<typename RealType, size_t DerivativeOrder, typename E>
std::ostream &operator<<(std::ostream &os, const expression<RealType, DerivativeOrder, E> &expr)
{
rvar<T, order> tmp = expr;
os << "rvar<" << order << ">(" << tmp.item() << "," << tmp.adjoint() << ")";
rvar<RealType, DerivativeOrder> tmp = expr;
os << "rvar<" << DerivativeOrder << ">(" << tmp.item() << "," << tmp.adjoint() << ")";
return os;
}
template<typename T, size_t order>
rvar<T, order> make_rvar(const T v)
template<typename RealType, size_t DerivativeOrder>
rvar<RealType, DerivativeOrder> make_rvar(const RealType v)
{
static_assert(order > 0, "rvar order must be >= 1");
return rvar<T, order>(v);
static_assert(DerivativeOrder > 0, "rvar order must be >= 1");
return rvar<RealType, DerivativeOrder>(v);
}
template<typename T, size_t order, typename E>
rvar<T, order> make_rvar(const expression<T, order, E> &expr)
template<typename RealType, size_t DerivativeOrder, typename E>
rvar<RealType, DerivativeOrder> make_rvar(const expression<RealType, DerivativeOrder, E> &expr)
{
static_assert(order > 0, "rvar order must be >= 1");
return rvar<T, order>(expr);
static_assert(DerivativeOrder > 0, "rvar order must be >= 1");
return rvar<RealType, DerivativeOrder>(expr);
}
namespace detail {
@@ -565,37 +593,24 @@ namespace detail {
* specialization for autodiffing through autodiff. i.e. being able to
* compute higher order grads
*/
template<typename T, size_t order>
template<typename RealType, size_t DerivativeOrder>
struct grad_op_impl
{
std::vector<rvar<T, order - 1>> operator()(rvar<T, order> &f, std::vector<rvar<T, order> *> &x)
std::vector<rvar<RealType, DerivativeOrder - 1>> operator()(
rvar<RealType, DerivativeOrder> &f, std::vector<rvar<RealType, DerivativeOrder> *> &x)
{
auto &tape = get_active_tape<T, order>();
auto &tape = get_active_tape<RealType, DerivativeOrder>();
tape.zero_grad();
f.backward();
std::vector<rvar<T, order - 1>> gradient_vector;
std::vector<rvar<RealType, DerivativeOrder - 1>> gradient_vector;
gradient_vector.reserve(x.size());
for (auto xi : x) {
// make a new rvar<T,order-1> holding the adjoint value
for (auto &xi : x) {
gradient_vector.emplace_back(xi->adjoint());
}
return gradient_vector;
}
/*
std::vector<rvar_t<T, order - 1> *> operator()(rvar<T, order> &f,
std::vector<rvar<T, order> *> &x)
{
gradient_tape<T, order, BOOST_MATH_BUFFER_SIZE> &tape = get_active_tape<T, order>();
tape.zero_grad();
f.backward();
std::vector<rvar_t<T, order - 1> *> gradient_vector;
for (auto xi : x) {
gradient_vector.push_back(&(xi->adjoint()));
}
return gradient_vector;
}*/
};
/** @brief helper overload for grad implementation.
* @return vector<T> of gradients of the autodiff graph.
@@ -610,7 +625,7 @@ struct grad_op_impl<T, 1>
tape.zero_grad();
f.backward();
std::vector<T> gradient_vector;
for (auto xi : x) {
for (auto &xi : x) {
gradient_vector.push_back(xi->adjoint());
}
return gradient_vector;
@@ -621,52 +636,51 @@ struct grad_op_impl<T, 1>
* @return nested vector representing N-d tensor of
* higher order derivatives
*/
template<size_t N, typename T, size_t order_1, size_t order_2, typename Enable = void>
template<size_t N,
typename RealType,
size_t DerivativeOrder_1,
size_t DerivativeOrder_2,
typename Enable = void>
struct grad_nd_impl
{
auto operator()(rvar<T, order_1> &f, std::vector<rvar<T, order_2> *> &x)
auto operator()(rvar<RealType, DerivativeOrder_1> &f,
std::vector<rvar<RealType, DerivativeOrder_2> *> &x)
{
static_assert(N > 1, "N must be greater than 1 for this template");
auto current_grad = grad(f, x); // vector<rvar<T,order_1-1>> or vector<T>
auto current_grad = grad(f, x); // vector<rvar<T,DerivativeOrder_1-1>> or vector<T>
std::vector<decltype(grad_nd_impl<N - 1, T, order_1 - 1, order_2>()(current_grad[0], x))>
std::vector<decltype(grad_nd_impl<N - 1, RealType, DerivativeOrder_1 - 1, DerivativeOrder_2>()(
current_grad[0], x))>
result;
result.reserve(current_grad.size());
for (auto &g_i : current_grad) {
result.push_back(grad_nd_impl<N - 1, T, order_1 - 1, order_2>()(g_i, x));
result.push_back(
grad_nd_impl<N - 1, RealType, DerivativeOrder_1 - 1, DerivativeOrder_2>()(g_i, x));
}
return result;
}
/*
auto operator()(rvar<T, order_1> &f, std::vector<rvar<T, order_2> *> &x)
{
static_assert(N > 1, "N must be greater than 1 for this template");
auto current_grad = grad(f, x);
std::vector<decltype(grad_nd_impl<N - 1, T, order_1 - 1, order_2>()(*current_grad[0], x))>
result;
for (auto &g_i : current_grad) {
result.push_back(grad_nd_impl<N - 1, T, order_1 - 1, order_2>()(*g_i, x));
}
return result;
}*/
};
/** @brief specialization for order = 1,
* @return vector<rvar<T,order_1-1>> gradients */
template<typename T, size_t order_1, size_t order_2>
struct grad_nd_impl<1, T, order_1, order_2>
* @return vector<rvar<T,DerivativeOrder_1-1>> gradients */
template<typename RealType, size_t DerivativeOrder_1, size_t DerivativeOrder_2>
struct grad_nd_impl<1, RealType, DerivativeOrder_1, DerivativeOrder_2>
{
auto operator()(rvar<T, order_1> &f, std::vector<rvar<T, order_2> *> &x) { return grad(f, x); }
auto operator()(rvar<RealType, DerivativeOrder_1> &f,
std::vector<rvar<RealType, DerivativeOrder_2> *> &x)
{
return grad(f, x);
}
};
template<typename ptr>
struct rvar_order;
template<typename T, size_t order>
struct rvar_order<rvar<T, order> *>
template<typename RealType, size_t DerivativeOrder>
struct rvar_order<rvar<RealType, DerivativeOrder> *>
{
static constexpr size_t value = order;
static constexpr size_t value = DerivativeOrder;
};
} // namespace detail
@@ -676,52 +690,52 @@ struct rvar_order<rvar<T, order> *>
* @param f -> computational graph
* @param x -> variables gradients to record. Note ALL gradients of the graph
* are computed simultaneously, only the ones w.r.t. x are returned
* @return vector<rvar<T,order_1 - 1> of gradients. in the case of order_1 = 1
* rvar<T,order_1-1> decays to T
* @return vector<rvar<T,DerivativeOrder_1 - 1> of gradients. in the case of DerivativeOrder_1 = 1
* rvar<T,DerivativeOrder_1-1> decays to T
*
* safe to call recursively with grad(grad(grad...
*/
template<typename T, size_t order_1, size_t order_2>
auto grad(rvar<T, order_1> &f, std::vector<rvar<T, order_2> *> &x)
template<typename RealType, size_t DerivativeOrder_1, size_t DerivativeOrder_2>
auto grad(rvar<RealType, DerivativeOrder_1> &f, std::vector<rvar<RealType, DerivativeOrder_2> *> &x)
{
static_assert(order_1 <= order_2,
static_assert(DerivativeOrder_1 <= DerivativeOrder_2,
"variable differentiating w.r.t. must have order >= function order");
std::vector<rvar<T, order_1> *> xx;
std::vector<rvar<RealType, DerivativeOrder_1> *> xx;
for (auto &xi : x)
xx.push_back(&(xi->template get_value_at<order_1>()));
return detail::grad_op_impl<T, order_1>{}(f, xx);
xx.push_back(&(xi->template get_value_at<DerivativeOrder_1>()));
return detail::grad_op_impl<RealType, DerivativeOrder_1>{}(f, xx);
}
/** @brief variadic overload of above
*/
template<typename T, size_t order_1, typename First, typename... Other>
auto grad(rvar<T, order_1> &f, First first, Other... other)
template<typename RealType, size_t DerivativeOrder_1, typename First, typename... Other>
auto grad(rvar<RealType, DerivativeOrder_1> &f, First first, Other... other)
{
constexpr size_t order_2 = detail::rvar_order<First>::value;
static_assert(order_1 <= order_2,
constexpr size_t DerivativeOrder_2 = detail::rvar_order<First>::value;
static_assert(DerivativeOrder_1 <= DerivativeOrder_2,
"variable differentiating w.r.t. must have order >= function order");
std::vector<rvar<T, order_2> *> x_vec = {first, other...};
std::vector<rvar<RealType, DerivativeOrder_2> *> x_vec = {first, other...};
return grad(f, x_vec);
}
/** @brief computes hessian matrix of computational graph w.r.t.
* vector of variables x.
* @return std::vector<std::vector<rvar<T,order_1-2>> hessian matrix
* @return std::vector<std::vector<rvar<T,DerivativeOrder_1-2>> hessian matrix
* rvar<T,2> decays to T
*
* NOT recursion safe, cannot do hess(hess(
*/
template<typename T, size_t order_1, size_t order_2>
auto hess(rvar<T, order_1> &f, std::vector<rvar<T, order_2> *> &x)
template<typename RealType, size_t DerivativeOrder_1, size_t DerivativeOrder_2>
auto hess(rvar<RealType, DerivativeOrder_1> &f, std::vector<rvar<RealType, DerivativeOrder_2> *> &x)
{
return detail::grad_nd_impl<2, T, order_1, order_2>{}(f, x);
return detail::grad_nd_impl<2, RealType, DerivativeOrder_1, DerivativeOrder_2>{}(f, x);
}
/** @brief variadic overload of above
*/
template<typename T, size_t order_1, typename First, typename... Other>
auto hess(rvar<T, order_1> &f, First first, Other... other)
template<typename RealType, size_t DerivativeOrder_1, typename First, typename... Other>
auto hess(rvar<RealType, DerivativeOrder_1> &f, First first, Other... other)
{
constexpr size_t order_2 = detail::rvar_order<First>::value;
std::vector<rvar<T, order_2> *> x_vec = {first, other...};
constexpr size_t DerivativeOrder_2 = detail::rvar_order<First>::value;
std::vector<rvar<RealType, DerivativeOrder_2> *> x_vec = {first, other...};
return hess(f, x_vec);
}
@@ -731,13 +745,15 @@ auto hess(rvar<T, order_1> &f, First first, Other... other)
*
* NOT recursively safe, cannot do grad_nd(grad_nd(... etc...
*/
template<size_t N, typename T, size_t order_1, size_t order_2>
auto grad_nd(rvar<T, order_1> &f, std::vector<rvar<T, order_2> *> &x)
template<size_t N, typename RealType, size_t DerivativeOrder_1, size_t DerivativeOrder_2>
auto grad_nd(rvar<RealType, DerivativeOrder_1> &f,
std::vector<rvar<RealType, DerivativeOrder_2> *> &x)
{
static_assert(order_1 >= N, "Function order must be at least N");
static_assert(order_2 >= order_1, "Variable order must be at least function order");
static_assert(DerivativeOrder_1 >= N, "Function order must be at least N");
static_assert(DerivativeOrder_2 >= DerivativeOrder_1,
"Variable order must be at least function order");
return detail::grad_nd_impl<N, T, order_1, order_2>()(f, x);
return detail::grad_nd_impl<N, RealType, DerivativeOrder_1, DerivativeOrder_2>()(f, x);
}
/** @brief variadic overload of above
@@ -745,11 +761,11 @@ auto grad_nd(rvar<T, order_1> &f, std::vector<rvar<T, order_2> *> &x)
template<size_t N, typename ftype, typename First, typename... Other>
auto grad_nd(ftype &f, First first, Other... other)
{
using T = typename ftype::value_type;
constexpr size_t order_1 = detail::rvar_order<ftype *>::value;
constexpr size_t order_2 = detail::rvar_order<First>::value;
std::vector<rvar<T, order_2> *> x_vec = {first, other...};
return detail::grad_nd_impl<N, T, order_1, order_1>{}(f, x_vec);
using RealType = typename ftype::value_type;
constexpr size_t DerivativeOrder_1 = detail::rvar_order<ftype *>::value;
constexpr size_t DerivativeOrder_2 = detail::rvar_order<First>::value;
std::vector<rvar<RealType, DerivativeOrder_2> *> x_vec = {first, other...};
return detail::grad_nd_impl<N, RealType, DerivativeOrder_1, DerivativeOrder_1>{}(f, x_vec);
}
} // namespace reverse_mode
} // namespace differentiation
@@ -758,10 +774,10 @@ auto grad_nd(ftype &f, First first, Other... other)
namespace std {
// copied from forward mode
template<typename T, size_t order>
class numeric_limits<boost::math::differentiation::reverse_mode::rvar<T, order>>
: public numeric_limits<
typename boost::math::differentiation::reverse_mode::rvar<T, order>::value_type>
template<typename RealType, size_t DerivativeOrder>
class numeric_limits<boost::math::differentiation::reverse_mode::rvar<RealType, DerivativeOrder>>
: public numeric_limits<typename boost::math::differentiation::reverse_mode::
rvar<RealType, DerivativeOrder>::value_type>
{};
} // namespace std
#endif

View File

@@ -12,19 +12,26 @@ namespace math {
namespace differentiation {
namespace reverse_mode {
/****************************************************************************************************************/
template<typename T, size_t order, typename LHS, typename RHS>
struct add_expr
: public abstract_binary_expression<T, order, LHS, RHS, add_expr<T, order, LHS, RHS>>
template<typename RealType, size_t DerivativeOrder, typename LHS, typename RHS>
struct add_expr : public abstract_binary_expression<RealType,
DerivativeOrder,
LHS,
RHS,
add_expr<RealType, DerivativeOrder, LHS, RHS>>
{
/* @brief addition
* rvar+rvar
* */
using inner_t = rvar_t<T, order - 1>;
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
// Explicitly define constructor to forward to base class
explicit add_expr(const expression<T, order, LHS> &left_hand_expr,
const expression<T, order, RHS> &right_hand_expr)
: abstract_binary_expression<T, order, LHS, RHS, add_expr<T, order, LHS, RHS>>(
left_hand_expr, right_hand_expr)
explicit add_expr(const expression<RealType, DerivativeOrder, LHS> &left_hand_expr,
const expression<RealType, DerivativeOrder, RHS> &right_hand_expr)
: abstract_binary_expression<RealType,
DerivativeOrder,
LHS,
RHS,
add_expr<RealType, DerivativeOrder, LHS, RHS>>(left_hand_expr,
right_hand_expr)
{}
inner_t evaluate() const { return this->lhs.evaluate() + this->rhs.evaluate(); }
@@ -41,37 +48,51 @@ struct add_expr
return inner_t(1.0);
}
};
template<typename T, size_t order, typename ARG>
template<typename RealType, size_t DerivativeOrder, typename ARG>
struct add_const_expr
: public abstract_unary_expression<T, order, ARG, add_const_expr<T, order, ARG>>
: public abstract_unary_expression<RealType,
DerivativeOrder,
ARG,
add_const_expr<RealType, DerivativeOrder, ARG>>
{
/* @brief
* rvar+float or float+rvar
* */
using inner_t = rvar_t<T, order - 1>;
explicit add_const_expr(const expression<T, order, ARG> &arg_expr, const T v)
: abstract_unary_expression<T, order, ARG, add_const_expr<T, order, ARG>>(arg_expr, v){};
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
explicit add_const_expr(const expression<RealType, DerivativeOrder, ARG> &arg_expr,
const RealType v)
: abstract_unary_expression<RealType,
DerivativeOrder,
ARG,
add_const_expr<RealType, DerivativeOrder, ARG>>(arg_expr, v){};
inner_t evaluate() const { return this->arg.evaluate() + inner_t(this->constant); }
static const inner_t derivative(const inner_t & /*argv*/,
const inner_t & /*v*/,
const T & /*constant*/)
const RealType & /*constant*/)
{
return inner_t(1.0);
}
};
/****************************************************************************************************************/
template<typename T, size_t order, typename LHS, typename RHS>
struct mult_expr
: public abstract_binary_expression<T, order, LHS, RHS, mult_expr<T, order, LHS, RHS>>
template<typename RealType, size_t DerivativeOrder, typename LHS, typename RHS>
struct mult_expr : public abstract_binary_expression<RealType,
DerivativeOrder,
LHS,
RHS,
mult_expr<RealType, DerivativeOrder, LHS, RHS>>
{
/* @brief multiplication
* rvar * rvar
* */
using inner_t = rvar_t<T, order - 1>;
explicit mult_expr(const expression<T, order, LHS> &left_hand_expr,
const expression<T, order, RHS> &right_hand_expr)
: abstract_binary_expression<T, order, LHS, RHS, mult_expr<T, order, LHS, RHS>>(
left_hand_expr, right_hand_expr)
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
explicit mult_expr(const expression<RealType, DerivativeOrder, LHS> &left_hand_expr,
const expression<RealType, DerivativeOrder, RHS> &right_hand_expr)
: abstract_binary_expression<RealType,
DerivativeOrder,
LHS,
RHS,
mult_expr<RealType, DerivativeOrder, LHS, RHS>>(left_hand_expr,
right_hand_expr)
{}
inner_t evaluate() const { return this->lhs.evaluate() * this->rhs.evaluate(); };
@@ -88,40 +109,54 @@ struct mult_expr
return l;
};
};
template<typename T, size_t order, typename ARG>
template<typename RealType, size_t DerivativeOrder, typename ARG>
struct mult_const_expr
: public abstract_unary_expression<T, order, ARG, mult_const_expr<T, order, ARG>>
: public abstract_unary_expression<RealType,
DerivativeOrder,
ARG,
mult_const_expr<RealType, DerivativeOrder, ARG>>
{
/* @brief
* rvar+float or float+rvar
* */
using inner_t = rvar_t<T, order - 1>;
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
explicit mult_const_expr(const expression<T, order, ARG> &arg_expr, const T v)
: abstract_unary_expression<T, order, ARG, mult_const_expr<T, order, ARG>>(arg_expr, v){};
explicit mult_const_expr(const expression<RealType, DerivativeOrder, ARG> &arg_expr,
const RealType v)
: abstract_unary_expression<RealType,
DerivativeOrder,
ARG,
mult_const_expr<RealType, DerivativeOrder, ARG>>(arg_expr, v){};
inner_t evaluate() const { return this->arg.evaluate() * inner_t(this->constant); }
static const inner_t derivative(const inner_t & /*argv*/,
const inner_t & /*v*/,
const T &constant)
const RealType &constant)
{
return inner_t(constant);
}
};
/****************************************************************************************************************/
template<typename T, size_t order, typename LHS, typename RHS>
struct sub_expr
: public abstract_binary_expression<T, order, LHS, RHS, sub_expr<T, order, LHS, RHS>>
template<typename RealType, size_t DerivativeOrder, typename LHS, typename RHS>
struct sub_expr : public abstract_binary_expression<RealType,
DerivativeOrder,
LHS,
RHS,
sub_expr<RealType, DerivativeOrder, LHS, RHS>>
{
/* @brief subtraction
* rvar-rvar
* */
using inner_t = rvar_t<T, order - 1>;
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
// Explicitly define constructor to forward to base class
explicit sub_expr(const expression<T, order, LHS> &left_hand_expr,
const expression<T, order, RHS> &right_hand_expr)
: abstract_binary_expression<T, order, LHS, RHS, sub_expr<T, order, LHS, RHS>>(
left_hand_expr, right_hand_expr)
explicit sub_expr(const expression<RealType, DerivativeOrder, LHS> &left_hand_expr,
const expression<RealType, DerivativeOrder, RHS> &right_hand_expr)
: abstract_binary_expression<RealType,
DerivativeOrder,
LHS,
RHS,
sub_expr<RealType, DerivativeOrder, LHS, RHS>>(left_hand_expr,
right_hand_expr)
{}
inner_t evaluate() const { return this->lhs.evaluate() - this->rhs.evaluate(); }
@@ -140,19 +175,26 @@ struct sub_expr
};
/****************************************************************************************************************/
template<typename T, size_t order, typename LHS, typename RHS>
struct div_expr
: public abstract_binary_expression<T, order, LHS, RHS, div_expr<T, order, LHS, RHS>>
template<typename RealType, size_t DerivativeOrder, typename LHS, typename RHS>
struct div_expr : public abstract_binary_expression<RealType,
DerivativeOrder,
LHS,
RHS,
div_expr<RealType, DerivativeOrder, LHS, RHS>>
{
/* @brief division
* rvar / rvar
* */
using inner_t = rvar_t<T, order - 1>;
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
// Explicitly define constructor to forward to base class
explicit div_expr(const expression<T, order, LHS> &left_hand_expr,
const expression<T, order, RHS> &right_hand_expr)
: abstract_binary_expression<T, order, LHS, RHS, div_expr<T, order, LHS, RHS>>(
left_hand_expr, right_hand_expr)
explicit div_expr(const expression<RealType, DerivativeOrder, LHS> &left_hand_expr,
const expression<RealType, DerivativeOrder, RHS> &right_hand_expr)
: abstract_binary_expression<RealType,
DerivativeOrder,
LHS,
RHS,
div_expr<RealType, DerivativeOrder, LHS, RHS>>(left_hand_expr,
right_hand_expr)
{}
inner_t evaluate() const { return this->lhs.evaluate() / this->rhs.evaluate(); };
@@ -160,182 +202,212 @@ struct div_expr
const inner_t &r,
const inner_t & /*v*/)
{
return static_cast<T>(1.0) / r;
return static_cast<RealType>(1.0) / r;
};
static const inner_t right_derivative(const inner_t &l, const inner_t &r, const inner_t & /*v*/)
{
return -l / (r * r);
};
};
template<typename T, size_t order, typename ARG>
template<typename RealType, size_t DerivativeOrder, typename ARG>
struct div_by_const_expr
: public abstract_unary_expression<T, order, ARG, div_by_const_expr<T, order, ARG>>
: public abstract_unary_expression<RealType,
DerivativeOrder,
ARG,
div_by_const_expr<RealType, DerivativeOrder, ARG>>
{
/* @brief
* rvar/float
* */
using inner_t = rvar_t<T, order - 1>;
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
explicit div_by_const_expr(const expression<T, order, ARG> &arg_expr, const T v)
: abstract_unary_expression<T, order, ARG, div_by_const_expr<T, order, ARG>>(arg_expr, v){};
explicit div_by_const_expr(const expression<RealType, DerivativeOrder, ARG> &arg_expr,
const RealType v)
: abstract_unary_expression<RealType,
DerivativeOrder,
ARG,
div_by_const_expr<RealType, DerivativeOrder, ARG>>(arg_expr,
v){};
inner_t evaluate() const { return this->arg.evaluate() / inner_t(this->constant); }
static const inner_t derivative(const inner_t & /*argv*/,
const inner_t & /*v*/,
const T &constant)
const RealType &constant)
{
return inner_t(1.0 / constant);
}
};
template<typename T, size_t order, typename ARG>
template<typename RealType, size_t DerivativeOrder, typename ARG>
struct const_div_by_expr
: public abstract_unary_expression<T, order, ARG, const_div_by_expr<T, order, ARG>>
: public abstract_unary_expression<RealType,
DerivativeOrder,
ARG,
const_div_by_expr<RealType, DerivativeOrder, ARG>>
{
/** @brief
* float/rvar
* */
using inner_t = rvar_t<T, order - 1>;
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
explicit const_div_by_expr(const expression<T, order, ARG> &arg_expr, const T v)
: abstract_unary_expression<T, order, ARG, const_div_by_expr<T, order, ARG>>(arg_expr, v){};
explicit const_div_by_expr(const expression<RealType, DerivativeOrder, ARG> &arg_expr,
const RealType v)
: abstract_unary_expression<RealType,
DerivativeOrder,
ARG,
const_div_by_expr<RealType, DerivativeOrder, ARG>>(arg_expr,
v){};
inner_t evaluate() const { return inner_t(this->constant) / this->arg.evaluate(); }
static const inner_t derivative(const inner_t &argv, const inner_t & /*v*/, const T &constant)
static const inner_t derivative(const inner_t &argv,
const inner_t & /*v*/,
const RealType &constant)
{
return -inner_t{constant} / (argv * argv);
}
};
/****************************************************************************************************************/
/** @brief Product of two expressions (rvar * rvar).
 *  @return a mult_expr node built from both operands */
template<typename RealType, size_t DerivativeOrder, typename LHS, typename RHS>
mult_expr<RealType, DerivativeOrder, LHS, RHS> operator*(
    const expression<RealType, DerivativeOrder, LHS> &lhs,
    const expression<RealType, DerivativeOrder, RHS> &rhs)
{
    return mult_expr<RealType, DerivativeOrder, LHS, RHS>(lhs, rhs);
}

/** @brief Product of an expression and a plain scalar (rvar * float).
 *
 * Type promotion is handled by casting the numeric type to the type inside
 * the expression. This avoids converting the entire tape for something like
 * double * rvar<float>. The overload is disabled when RealType2 is itself
 * an expression.
 *
 * @return a mult_const_expr node holding `v` converted to RealType1 */
template<typename RealType2,
         typename RealType1,
         size_t DerivativeOrder,
         typename ARG,
         typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
mult_const_expr<RealType1, DerivativeOrder, ARG> operator*(
    const expression<RealType1, DerivativeOrder, ARG> &arg, const RealType2 &v)
{
    return mult_const_expr<RealType1, DerivativeOrder, ARG>(arg, static_cast<RealType1>(v));
}

/** @brief Product of a plain scalar and an expression (float * rvar).
 *
 * Multiplication commutes, so this builds the same node as rvar * float.
 *
 * @return a mult_const_expr node holding `v` converted to RealType1 */
template<typename RealType2,
         typename RealType1,
         size_t DerivativeOrder,
         typename ARG,
         typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
mult_const_expr<RealType1, DerivativeOrder, ARG> operator*(
    const RealType2 &v, const expression<RealType1, DerivativeOrder, ARG> &arg)
{
    return mult_const_expr<RealType1, DerivativeOrder, ARG>(arg, static_cast<RealType1>(v));
}
/****************************************************************************************************************/
/* + */
/** @brief Sum of two expressions (rvar + rvar).
 *  @return an add_expr node built from both operands */
template<typename RealType, size_t DerivativeOrder, typename LHS, typename RHS>
add_expr<RealType, DerivativeOrder, LHS, RHS> operator+(
    const expression<RealType, DerivativeOrder, LHS> &lhs,
    const expression<RealType, DerivativeOrder, RHS> &rhs)
{
    return add_expr<RealType, DerivativeOrder, LHS, RHS>(lhs, rhs);
}

/** @brief Sum of an expression and a plain scalar (rvar + float).
 *
 * The scalar is converted to the expression's RealType1; the overload is
 * disabled when RealType2 is itself an expression.
 *
 * @return an add_const_expr node holding `v` converted to RealType1 */
template<typename RealType2,
         typename RealType1,
         size_t DerivativeOrder,
         typename ARG,
         typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
add_const_expr<RealType1, DerivativeOrder, ARG> operator+(
    const expression<RealType1, DerivativeOrder, ARG> &arg, const RealType2 &v)
{
    return add_const_expr<RealType1, DerivativeOrder, ARG>(arg, static_cast<RealType1>(v));
}

/** @brief Sum of a plain scalar and an expression (float + rvar).
 *
 * Addition commutes, so this builds the same node as rvar + float.
 *
 * @return an add_const_expr node holding `v` converted to RealType1 */
template<typename RealType2,
         typename RealType1,
         size_t DerivativeOrder,
         typename ARG,
         typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
add_const_expr<RealType1, DerivativeOrder, ARG> operator+(
    const RealType2 &v, const expression<RealType1, DerivativeOrder, ARG> &arg)
{
    return add_const_expr<RealType1, DerivativeOrder, ARG>(arg, static_cast<RealType1>(v));
}
/****************************************************************************************************************/
/* - overload */
/** @brief
* negation (-1.0*rvar) */
template<typename T, size_t order, typename ARG>
mult_const_expr<T, order, ARG> operator-(const expression<T, order, ARG> &arg)
template<typename RealType, size_t DerivativeOrder, typename ARG>
mult_const_expr<RealType, DerivativeOrder, ARG> operator-(
const expression<RealType, DerivativeOrder, ARG> &arg)
{
return mult_const_expr<T, order, ARG>(arg, static_cast<T>(-1.0));
return mult_const_expr<RealType, DerivativeOrder, ARG>(arg, static_cast<RealType>(-1.0));
}
/** @brief
* subtraction rvar-rvar */
template<typename T, size_t order, typename LHS, typename RHS>
sub_expr<T, order, LHS, RHS> operator-(const expression<T, order, LHS> &lhs,
const expression<T, order, RHS> &rhs)
template<typename RealType, size_t DerivativeOrder, typename LHS, typename RHS>
sub_expr<RealType, DerivativeOrder, LHS, RHS> operator-(
const expression<RealType, DerivativeOrder, LHS> &lhs,
const expression<RealType, DerivativeOrder, RHS> &rhs)
{
return sub_expr<T, order, LHS, RHS>(lhs, rhs);
return sub_expr<RealType, DerivativeOrder, LHS, RHS>(lhs, rhs);
}
/** @brief
* subtraction float - rvar */
template<typename U,
typename T,
size_t order,
template<typename RealType2,
typename RealType1,
size_t DerivativeOrder,
typename ARG,
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
add_const_expr<T, order, ARG> operator-(const expression<T, order, ARG> &arg, const U &v)
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
add_const_expr<RealType1, DerivativeOrder, ARG> operator-(
const expression<RealType1, DerivativeOrder, ARG> &arg, const RealType2 &v)
{
/* rvar - float = rvar + (-float) */
return add_const_expr<T, order, ARG>(arg, static_cast<T>(-v));
return add_const_expr<RealType1, DerivativeOrder, ARG>(arg, static_cast<RealType1>(-v));
}
/** @brief
* subtraction float - rvar
* @return add_expr<neg_expr<ARG>>
*/
template<typename U,
typename T,
size_t order,
template<typename RealType2,
typename RealType1,
size_t DerivativeOrder,
typename ARG,
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
auto operator-(const U &v, const expression<T, order, ARG> &arg)
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
auto operator-(const RealType2 &v, const expression<RealType1, DerivativeOrder, ARG> &arg)
{
auto neg = -arg;
return neg + static_cast<T>(v);
return neg + static_cast<RealType1>(v);
}
/****************************************************************************************************************/
/* / */
/** @brief Quotient of two expressions (rvar / rvar).
 *  @return a div_expr node built from both operands */
template<typename RealType, size_t DerivativeOrder, typename LHS, typename RHS>
div_expr<RealType, DerivativeOrder, LHS, RHS> operator/(
    const expression<RealType, DerivativeOrder, LHS> &lhs,
    const expression<RealType, DerivativeOrder, RHS> &rhs)
{
    return div_expr<RealType, DerivativeOrder, LHS, RHS>(lhs, rhs);
}

/** @brief Plain scalar divided by an expression (float / rvar).
 *
 * The scalar is converted to the expression's RealType1; the overload is
 * disabled when RealType2 is itself an expression.
 *
 * @return a const_div_by_expr node holding `v` as the numerator */
template<typename RealType2,
         typename RealType1,
         size_t DerivativeOrder,
         typename ARG,
         typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
const_div_by_expr<RealType1, DerivativeOrder, ARG> operator/(
    const RealType2 &v, const expression<RealType1, DerivativeOrder, ARG> &arg)
{
    return const_div_by_expr<RealType1, DerivativeOrder, ARG>(arg, static_cast<RealType1>(v));
}

/** @brief Expression divided by a plain scalar (rvar / float).
 *
 * The scalar is converted to the expression's RealType1; the overload is
 * disabled when RealType2 is itself an expression.
 *
 * @return a div_by_const_expr node holding `v` as the divisor */
template<typename RealType2,
         typename RealType1,
         size_t DerivativeOrder,
         typename ARG,
         typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
div_by_const_expr<RealType1, DerivativeOrder, ARG> operator/(
    const expression<RealType1, DerivativeOrder, ARG> &arg, const RealType2 &v)
{
    return div_by_const_expr<RealType1, DerivativeOrder, ARG>(arg, static_cast<RealType1>(v));
}
} // namespace reverse_mode

View File

@@ -10,147 +10,153 @@ namespace boost {
namespace math {
namespace differentiation {
namespace reverse_mode {
template<typename T, size_t order_1, size_t order_2, class E, class F>
bool operator==(const expression<T, order_1, E> &lhs, const expression<T, order_2, F> &rhs)
template<typename RealType, size_t DerivativeOrder1, size_t DerivativeOrder2, class LhsExpr, class RhsExpr>
bool operator==(const expression<RealType, DerivativeOrder1, LhsExpr> &lhs,
const expression<RealType, DerivativeOrder2, RhsExpr> &rhs)
{
return lhs.evaluate() == rhs.evaluate();
}
template<typename U,
typename T,
size_t order,
class E,
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
bool operator==(const expression<T, order, E> &lhs, const U &rhs)
template<typename RealType2,
typename RealType1,
size_t DerivativeOrder,
class ArgExpr,
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
bool operator==(const expression<RealType1, DerivativeOrder, ArgExpr> &lhs, const RealType2 &rhs)
{
return lhs.evaluate() == static_cast<T>(rhs);
return lhs.evaluate() == static_cast<RealType1>(rhs);
}
template<typename U,
typename T,
size_t order,
class E,
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
bool operator==(const U &lhs, const expression<T, order, E> &rhs)
template<typename RealType2,
typename RealType1,
size_t DerivativeOrder,
class ArgExpr,
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
bool operator==(const RealType2 &lhs, const expression<RealType1, DerivativeOrder, ArgExpr> &rhs)
{
return lhs == rhs.evaluate();
}
template<typename T, size_t order_1, size_t order_2, class E, class F>
bool operator!=(const expression<T, order_1, E> &lhs, const expression<T, order_2, F> &rhs)
template<typename RealType, size_t DerivativeOrder1, size_t DerivativeOrder2, class LhsExpr, class RhsExpr>
bool operator!=(const expression<RealType, DerivativeOrder1, LhsExpr> &lhs,
const expression<RealType, DerivativeOrder2, RhsExpr> &rhs)
{
return lhs.evaluate() != rhs.evaluate();
}
template<typename U,
typename T,
size_t order,
class E,
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
bool operator!=(const expression<T, order, E> &lhs, const U &rhs)
template<typename RealType2,
typename RealType1,
size_t DerivativeOrder,
class ArgExpr,
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
bool operator!=(const expression<RealType1, DerivativeOrder, ArgExpr> &lhs, const RealType2 &rhs)
{
return lhs.evaluate() != rhs;
}
template<typename U,
typename T,
size_t order,
class E,
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
bool operator!=(const U &lhs, const expression<T, order, E> &rhs)
template<typename RealType2,
typename RealType1,
size_t DerivativeOrder,
class ArgExpr,
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
bool operator!=(const RealType2 &lhs, const expression<RealType1, DerivativeOrder, ArgExpr> &rhs)
{
return lhs != rhs.evaluate();
}
template<typename T, size_t order_1, size_t order_2, class E, class F>
bool operator<(const expression<T, order_1, E> &lhs, const expression<T, order_2, F> &rhs)
template<typename RealType, size_t DerivativeOrder1, size_t DerivativeOrder2, class LhsExpr, class RhsExpr>
bool operator<(const expression<RealType, DerivativeOrder1, LhsExpr> &lhs,
const expression<RealType, DerivativeOrder2, RhsExpr> &rhs)
{
return lhs.evaluate() < rhs.evaluate();
}
template<typename U,
typename T,
size_t order,
class E,
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
bool operator<(const expression<T, order, E> &lhs, const U &rhs)
template<typename RealType2,
typename RealType1,
size_t DerivativeOrder,
class ArgExpr,
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
bool operator<(const expression<RealType1, DerivativeOrder, ArgExpr> &lhs, const RealType2 &rhs)
{
return lhs.evaluate() < rhs;
}
template<typename U,
typename T,
size_t order,
class E,
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
bool operator<(const U &lhs, const expression<T, order, E> &rhs)
template<typename RealType2,
typename RealType1,
size_t DerivativeOrder,
class ArgExpr,
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
bool operator<(const RealType2 &lhs, const expression<RealType1, DerivativeOrder, ArgExpr> &rhs)
{
return lhs < rhs.evaluate();
}
template<typename T, size_t order_1, size_t order_2, class E, class F>
bool operator>(const expression<T, order_1, E> &lhs, const expression<T, order_2, F> &rhs)
template<typename RealType, size_t DerivativeOrder1, size_t DerivativeOrder2, class LhsExpr, class RhsExpr>
bool operator>(const expression<RealType, DerivativeOrder1, LhsExpr> &lhs,
const expression<RealType, DerivativeOrder2, RhsExpr> &rhs)
{
return lhs.evaluate() > rhs.evaluate();
}
template<typename U,
typename T,
size_t order,
class E,
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
bool operator>(const expression<T, order, E> &lhs, const U &rhs)
template<typename RealType2,
typename RealType1,
size_t DerivativeOrder,
class ArgExpr,
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
bool operator>(const expression<RealType1, DerivativeOrder, ArgExpr> &lhs, const RealType2 &rhs)
{
return lhs.evaluate() > rhs;
}
template<typename U,
typename T,
size_t order,
class E,
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
bool operator>(const U &lhs, const expression<T, order, E> &rhs)
template<typename RealType2,
typename RealType1,
size_t DerivativeOrder,
class ArgExpr,
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
bool operator>(const RealType2 &lhs, const expression<RealType1, DerivativeOrder, ArgExpr> &rhs)
{
return lhs > rhs.evaluate();
}
template<typename T, size_t order_1, size_t order_2, class E, class F>
bool operator<=(const expression<T, order_1, E> &lhs, const expression<T, order_2, F> &rhs)
template<typename RealType, size_t DerivativeOrder1, size_t DerivativeOrder2, class LhsExpr, class RhsExpr>
bool operator<=(const expression<RealType, DerivativeOrder1, LhsExpr> &lhs,
const expression<RealType, DerivativeOrder2, RhsExpr> &rhs)
{
return lhs.evaluate() <= rhs.evaluate();
}
template<typename U,
typename T,
size_t order,
class E,
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
bool operator<=(const expression<T, order, E> &lhs, const U &rhs)
template<typename RealType2,
typename RealType1,
size_t DerivativeOrder,
class ArgExpr,
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
bool operator<=(const expression<RealType1, DerivativeOrder, ArgExpr> &lhs, const RealType2 &rhs)
{
return lhs.evaluate() <= rhs;
}
template<typename U,
typename T,
size_t order,
class E,
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
bool operator<=(const U &lhs, const expression<T, order, E> &rhs)
template<typename RealType2,
typename RealType1,
size_t DerivativeOrder,
class ArgExpr,
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
bool operator<=(const RealType2 &lhs, const expression<RealType1, DerivativeOrder, ArgExpr> &rhs)
{
return lhs <= rhs.evaluate();
}
template<typename T, size_t order_1, size_t order_2, class E, class F>
bool operator>=(const expression<T, order_1, E> &lhs, const expression<T, order_2, F> &rhs)
template<typename RealType, size_t DerivativeOrder1, size_t DerivativeOrder2, class LhsExpr, class RhsExpr>
bool operator>=(const expression<RealType, DerivativeOrder1, LhsExpr> &lhs,
const expression<RealType, DerivativeOrder2, RhsExpr> &rhs)
{
return lhs.evaluate() >= rhs.evaluate();
}
template<typename U,
typename T,
size_t order,
class E,
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
bool operator>=(const expression<T, order, E> &lhs, const U &rhs)
template<typename RealType2,
typename RealType1,
size_t DerivativeOrder,
class ArgExpr,
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
bool operator>=(const expression<RealType1, DerivativeOrder, ArgExpr> &lhs, const RealType2 &rhs)
{
return lhs.evaluate() >= rhs;
}
template<typename U,
typename T,
size_t order,
class E,
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
bool operator>=(const U &lhs, const expression<T, order, E> &rhs)
template<typename RealType2,
typename RealType1,
size_t DerivativeOrder,
class ArgExpr,
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
bool operator>=(const RealType2 &lhs, const expression<RealType1, DerivativeOrder, ArgExpr> &rhs)
{
return lhs >= rhs.evaluate();
}

View File

@@ -16,73 +16,84 @@ namespace math {
namespace differentiation {
namespace reverse_mode {
/* Forward declarations of the error-function expression nodes, so that the
 * factory functions below can name them as return types. */
template<typename RealType, size_t DerivativeOrder, typename ARG>
struct erf_expr;

template<typename RealType, size_t DerivativeOrder, typename ARG>
struct erfc_expr;

template<typename RealType, size_t DerivativeOrder, typename ARG>
struct erf_inv_expr;

template<typename RealType, size_t DerivativeOrder, typename ARG>
struct erfc_inv_expr;
/** @brief Error function of an expression.
 *
 * @param arg operand expression
 * @return an erf_expr node wrapping `arg`. The constant slot is filled with
 *         zero; erf_expr::derivative does not read it. */
template<typename RealType, size_t DerivativeOrder, typename ARG>
erf_expr<RealType, DerivativeOrder, ARG> erf(const expression<RealType, DerivativeOrder, ARG> &arg)
{
    // Explicit conversion, so the call does not depend on an implicit
    // double -> RealType conversion being available.
    return erf_expr<RealType, DerivativeOrder, ARG>(arg, static_cast<RealType>(0.0));
}
/** @brief Complementary error function of an expression.
 *
 * @param arg operand expression
 * @return an erfc_expr node wrapping `arg`. The constant slot is filled with
 *         zero; erfc_expr::derivative does not read it. */
template<typename RealType, size_t DerivativeOrder, typename ARG>
erfc_expr<RealType, DerivativeOrder, ARG> erfc(const expression<RealType, DerivativeOrder, ARG> &arg)
{
    // Explicit conversion, so the call does not depend on an implicit
    // double -> RealType conversion being available.
    return erfc_expr<RealType, DerivativeOrder, ARG>(arg, static_cast<RealType>(0.0));
}
/** @brief Inverse error function of an expression.
 *
 * @param arg operand expression
 * @return an erf_inv_expr node wrapping `arg`. The constant slot is filled
 *         with zero; erf_inv_expr::derivative does not read it. */
template<typename RealType, size_t DerivativeOrder, typename ARG>
erf_inv_expr<RealType, DerivativeOrder, ARG> erf_inv(
    const expression<RealType, DerivativeOrder, ARG> &arg)
{
    // Explicit conversion, so the call does not depend on an implicit
    // double -> RealType conversion being available.
    return erf_inv_expr<RealType, DerivativeOrder, ARG>(arg, static_cast<RealType>(0.0));
}
/** @brief Inverse complementary error function of an expression.
 *
 * @param arg operand expression
 * @return an erfc_inv_expr node wrapping `arg`. The constant slot is filled
 *         with zero (presumably unused, as in the sibling erf nodes --
 *         confirm against erfc_inv_expr::derivative). */
template<typename RealType, size_t DerivativeOrder, typename ARG>
erfc_inv_expr<RealType, DerivativeOrder, ARG> erfc_inv(
    const expression<RealType, DerivativeOrder, ARG> &arg)
{
    // Explicit conversion, so the call does not depend on an implicit
    // double -> RealType conversion being available.
    return erfc_inv_expr<RealType, DerivativeOrder, ARG>(arg, static_cast<RealType>(0.0));
}
/**
 * @brief Expression node for erf(x).
 *
 * d/dx erf(x) = 2 * exp(-x^2) / sqrt(pi)
 *
 * @tparam RealType        underlying scalar type
 * @tparam DerivativeOrder derivative order; values are computed in
 *                         rvar_t<RealType, DerivativeOrder - 1>
 * @tparam ARG             concrete type of the operand expression
 */
template<typename RealType, size_t DerivativeOrder, typename ARG>
struct erf_expr : public abstract_unary_expression<RealType,
                                                   DerivativeOrder,
                                                   ARG,
                                                   erf_expr<RealType, DerivativeOrder, ARG>>
{
    using inner_t = rvar_t<RealType, DerivativeOrder - 1>;

    /** @param arg_expr operand expression
     *  @param v        value for the base class's constant slot; not read by
     *                  derivative() */
    explicit erf_expr(const expression<RealType, DerivativeOrder, ARG> &arg_expr, const RealType v)
        : abstract_unary_expression<RealType,
                                    DerivativeOrder,
                                    ARG,
                                    erf_expr<RealType, DerivativeOrder, ARG>>(arg_expr, v)
    {}

    /** @return erf of the operand's value.
     *
     * Two implementations are offered to if_functional_dispatch: the
     * reverse-mode erf and boost::math::erf. Presumably the first is chosen
     * when DerivativeOrder > 1 (the operand value is then itself an rvar) --
     * confirm against detail::if_functional_dispatch. */
    inner_t evaluate() const
    {
        // Neither lambda touches the object, so nothing is captured.
        return detail::if_functional_dispatch<(DerivativeOrder > 1)>(
            [](auto &&x) { return reverse_mode::erf(std::forward<decltype(x)>(x)); },
            [](auto &&x) { return boost::math::erf(std::forward<decltype(x)>(x)); },
            this->arg.evaluate());
    }

    /** @param argv value of the operand
     *  @return 2 * exp(-argv^2) / sqrt(pi) */
    static const inner_t derivative(const inner_t &argv,
                                    const inner_t & /*v*/,
                                    const RealType & /*constant*/)
    {
        BOOST_MATH_STD_USING
        return static_cast<RealType>(2.0) * exp(-argv * argv) / sqrt(constants::pi<RealType>());
    }
};
template<typename T, size_t order, typename ARG>
struct erfc_expr : public abstract_unary_expression<T, order, ARG, erfc_expr<T, order, ARG>>
template<typename RealType, size_t DerivativeOrder, typename ARG>
struct erfc_expr : public abstract_unary_expression<RealType,
DerivativeOrder,
ARG,
erfc_expr<RealType, DerivativeOrder, ARG>>
{
/** @brief erfc(x)
*
@@ -90,29 +101,35 @@ struct erfc_expr : public abstract_unary_expression<T, order, ARG, erfc_expr<T,
*
* */
using inner_t = rvar_t<T, order - 1>;
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
explicit erfc_expr(const expression<T, order, ARG> &arg_expr, const T v)
: abstract_unary_expression<T, order, ARG, erfc_expr<T, order, ARG>>(arg_expr, v){};
explicit erfc_expr(const expression<RealType, DerivativeOrder, ARG> &arg_expr, const RealType v)
: abstract_unary_expression<RealType,
DerivativeOrder,
ARG,
erfc_expr<RealType, DerivativeOrder, ARG>>(arg_expr, v){};
inner_t evaluate() const
{
return detail::if_functional_dispatch<(order > 1)>(
return detail::if_functional_dispatch<((DerivativeOrder > 1))>(
[this](auto &&x) { return reverse_mode::erfc(std::forward<decltype(x)>(x)); },
[this](auto &&x) { return boost::math::erfc(std::forward<decltype(x)>(x)); },
this->arg.evaluate());
}
static const inner_t derivative(const inner_t &argv,
const inner_t & /*v*/,
const T & /*constant*/)
const RealType & /*constant*/)
{
BOOST_MATH_STD_USING
return static_cast<T>(-2.0) * exp(-argv * argv) / sqrt(constants::pi<T>());
return static_cast<RealType>(-2.0) * exp(-argv * argv) / sqrt(constants::pi<RealType>());
}
};
template<typename T, size_t order, typename ARG>
struct erf_inv_expr : public abstract_unary_expression<T, order, ARG, erf_inv_expr<T, order, ARG>>
template<typename RealType, size_t DerivativeOrder, typename ARG>
struct erf_inv_expr : public abstract_unary_expression<RealType,
DerivativeOrder,
ARG,
erf_inv_expr<RealType, DerivativeOrder, ARG>>
{
/** @brief erf_inv(x)
*
@@ -120,39 +137,47 @@ struct erf_inv_expr : public abstract_unary_expression<T, order, ARG, erf_inv_ex
*
* */
using inner_t = rvar_t<T, order - 1>;
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
explicit erf_inv_expr(const expression<T, order, ARG> &arg_expr, const T v)
: abstract_unary_expression<T, order, ARG, erf_inv_expr<T, order, ARG>>(arg_expr, v){};
explicit erf_inv_expr(const expression<RealType, DerivativeOrder, ARG> &arg_expr,
const RealType v)
: abstract_unary_expression<RealType,
DerivativeOrder,
ARG,
erf_inv_expr<RealType, DerivativeOrder, ARG>>(arg_expr, v){};
inner_t evaluate() const
{
return detail::if_functional_dispatch<(order > 1)>(
return detail::if_functional_dispatch<((DerivativeOrder > 1))>(
[this](auto &&x) { return reverse_mode::erf_inv(std::forward<decltype(x)>(x)); },
[this](auto &&x) { return boost::math::erf_inv(std::forward<decltype(x)>(x)); },
this->arg.evaluate());
}
static const inner_t derivative(const inner_t &argv,
const inner_t & /*v*/,
const T & /*constant*/)
const RealType & /*constant*/)
{
BOOST_MATH_STD_USING
return detail::if_functional_dispatch<(order > 1)>(
return detail::if_functional_dispatch<((DerivativeOrder > 1))>(
[](auto &&x) {
return static_cast<T>(0.5) * sqrt(constants::pi<T>())
return static_cast<RealType>(0.5) * sqrt(constants::pi<RealType>())
* reverse_mode::exp(
reverse_mode::pow(reverse_mode::erf_inv(x), static_cast<T>(2.0)));
reverse_mode::pow(reverse_mode::erf_inv(x), static_cast<RealType>(2.0)));
},
[](auto &&x) {
return static_cast<T>(0.5) * sqrt(constants::pi<T>())
* exp(pow(boost::math::erf_inv(x), static_cast<T>(2.0)));
return static_cast<RealType>(0.5) * sqrt(constants::pi<RealType>())
* exp(pow(boost::math::erf_inv(x), static_cast<RealType>(2.0)));
},
argv);
}
};
template<typename T, size_t order, typename ARG>
struct erfc_inv_expr : public abstract_unary_expression<T, order, ARG, erfc_inv_expr<T, order, ARG>>
template<typename RealType, size_t DerivativeOrder, typename ARG>
struct erfc_inv_expr
: public abstract_unary_expression<RealType,
DerivativeOrder,
ARG,
erfc_inv_expr<RealType, DerivativeOrder, ARG>>
{
/** @brief erfc_inv(x)
*
@@ -160,32 +185,36 @@ struct erfc_inv_expr : public abstract_unary_expression<T, order, ARG, erfc_inv_
*
* */
using inner_t = rvar_t<T, order - 1>;
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
explicit erfc_inv_expr(const expression<T, order, ARG> &arg_expr, const T v)
: abstract_unary_expression<T, order, ARG, erfc_inv_expr<T, order, ARG>>(arg_expr, v){};
explicit erfc_inv_expr(const expression<RealType, DerivativeOrder, ARG> &arg_expr,
const RealType v)
: abstract_unary_expression<RealType,
DerivativeOrder,
ARG,
erfc_inv_expr<RealType, DerivativeOrder, ARG>>(arg_expr, v){};
inner_t evaluate() const
{
return detail::if_functional_dispatch<(order > 1)>(
return detail::if_functional_dispatch<((DerivativeOrder > 1))>(
[this](auto &&x) { return reverse_mode::erfc_inv(std::forward<decltype(x)>(x)); },
[this](auto &&x) { return boost::math::erfc_inv(std::forward<decltype(x)>(x)); },
this->arg.evaluate());
}
static const inner_t derivative(const inner_t &argv,
const inner_t & /*v*/,
const T & /*constant*/)
const RealType & /*constant*/)
{
BOOST_MATH_STD_USING
return detail::if_functional_dispatch<(order > 1)>(
return detail::if_functional_dispatch<((DerivativeOrder > 1))>(
[](auto &&x) {
return static_cast<T>(-0.5) * sqrt(constants::pi<T>())
* reverse_mode::exp(
reverse_mode::pow(reverse_mode::erfc_inv(x), static_cast<T>(2.0)));
return static_cast<RealType>(-0.5) * sqrt(constants::pi<RealType>())
* reverse_mode::exp(reverse_mode::pow(reverse_mode::erfc_inv(x),
static_cast<RealType>(2.0)));
},
[](auto &&x) {
return static_cast<T>(-0.5) * sqrt(constants::pi<T>())
* exp(pow(boost::math::erfc_inv(x), static_cast<T>(2.0)));
return static_cast<RealType>(-0.5) * sqrt(constants::pi<RealType>())
* exp(pow(boost::math::erfc_inv(x), static_cast<RealType>(2.0)));
},
argv);
}

View File

@@ -16,19 +16,24 @@ namespace reverse_mode {
struct expression_base
{};
template<typename T, size_t order, class derived_expression>
template<typename RealType, size_t DerivativeOrder, class DerivedExpression>
struct expression;
template<typename T, size_t order>
template<typename RealType, size_t DerivativeOrder>
class rvar;
template<typename T, size_t order, typename LHS, typename RHS, typename concrete_binary_operation>
template<typename RealType,
size_t DerivativeOrder,
typename LHS,
typename RHS,
typename ConcreteBinaryOperation>
struct abstract_binary_expression;
template<typename T, size_t order, typename ARG, typename concrete_unary_operation>
template<typename RealType, size_t DerivativeOrder, typename ARG, typename ConcreteUnaryOperation>
struct abstract_unary_expression;
template<typename T, size_t order>
template<typename RealType, size_t DerivativeOrder>
class gradient_node; // forward declaration for tape
namespace detail {
@@ -59,50 +64,54 @@ struct count_rvar_impl
{
static constexpr std::size_t value = 0;
};
// Leaf case: an rvar whose order equals the queried order contributes 1.
template<typename U, size_t order>
struct count_rvar_impl<rvar<U, order>, order>
template<typename RealType, size_t DerivativeOrder>
struct count_rvar_impl<rvar<RealType, DerivativeOrder>, DerivativeOrder>
{
static constexpr std::size_t value = 1;
};
// Binary-expression case: enabled for types that expose binary sub-types,
// are not themselves an rvar of this order, and expose no unary sub-type.
// The count is the sum of the counts of lhs_type and rhs_type.
// NOTE(review): the first template parameter is an expression type here (it
// is queried for lhs_type / rhs_type / value_type), so the new name
// `RealType` is misleading in this specialization.
template<typename T, std::size_t order>
struct count_rvar_impl<T,
order,
std::enable_if_t<has_binary_sub_types<T>::value
&& !std::is_same<T, rvar<typename T::value_type, order>>::value
&& !has_unary_sub_type<T>::value>>
template<typename RealType, std::size_t DerivativeOrder>
struct count_rvar_impl<
RealType,
DerivativeOrder,
std::enable_if_t<has_binary_sub_types<RealType>::value
&& !std::is_same<RealType, rvar<typename RealType::value_type, DerivativeOrder>>::value
&& !has_unary_sub_type<RealType>::value>>
{
static constexpr std::size_t value = count_rvar_impl<typename T::lhs_type, order>::value
+ count_rvar_impl<typename T::rhs_type, order>::value;
static constexpr std::size_t value
= count_rvar_impl<typename RealType::lhs_type, DerivativeOrder>::value
+ count_rvar_impl<typename RealType::rhs_type, DerivativeOrder>::value;
};
// Unary-expression case: enabled for types that expose a unary sub-type, are
// not themselves an rvar of this order, and expose no binary sub-types.
// The count is that of arg_type.
// NOTE(review): as above, the first template parameter is an expression type
// (queried for arg_type / value_type), not a real number type.
template<typename T, size_t order>
template<typename RealType, size_t DerivativeOrder>
struct count_rvar_impl<
T,
order,
typename std::enable_if_t<has_unary_sub_type<T>::value
&& !std::is_same<T, rvar<typename T::value_type, order>>::value
&& !has_binary_sub_types<T>::value>>
RealType,
DerivativeOrder,
typename std::enable_if_t<
has_unary_sub_type<RealType>::value
&& !std::is_same<RealType, rvar<typename RealType::value_type, DerivativeOrder>>::value
&& !has_binary_sub_types<RealType>::value>>
{
static constexpr std::size_t value = count_rvar_impl<typename T::arg_type, order>::value;
static constexpr std::size_t value
= count_rvar_impl<typename RealType::arg_type, DerivativeOrder>::value;
};
// Variable-template shorthand for count_rvar_impl<...>::value.
template<typename T, size_t order>
constexpr std::size_t count_rvars = detail::count_rvar_impl<T, order>::value;
template<typename RealType, size_t DerivativeOrder>
constexpr std::size_t count_rvars = detail::count_rvar_impl<RealType, DerivativeOrder>::value;
template<typename T>
struct is_expression : std::is_base_of<expression_base, typename std::decay<T>::type>
{};
template<typename T, size_t N>
template<typename RealType, size_t N>
struct rvar_type_impl
{
using type = rvar<T, N>;
using type = rvar<RealType, N>;
};
template<typename T>
struct rvar_type_impl<T, 0>
template<typename RealType>
struct rvar_type_impl<RealType, 0>
{
using type = T;
using type = RealType;
};
} // namespace detail
@@ -110,63 +119,69 @@ struct rvar_type_impl<T, 0>
template<typename T, size_t N>
using rvar_t = typename detail::rvar_type_impl<T, N>::type;
template<typename T, size_t order, class derived_expression>
template<typename RealType, size_t DerivativeOrder, class DerivedExpression>
struct expression : expression_base
{
/* @brief
* CRTP base expression class: evaluate() and propagatex() forward to the
* derived expression type through a static_cast of `this`.
* */
using value_type = T;
static constexpr size_t order_v = order;
using derived_type = derived_expression;
using value_type = RealType;
static constexpr size_t order_v = DerivativeOrder;
using derived_type = DerivedExpression;
static constexpr size_t num_literals = 0;
// inner_t is rvar_t one order below this expression's order.
using inner_t = rvar_t<T, order - 1>;
inner_t evaluate() const { return static_cast<const derived_expression *>(this)->evaluate(); }
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
inner_t evaluate() const { return static_cast<const DerivedExpression *>(this)->evaluate(); }
// Forwards the node pointer and the value `adj` to the derived expression's
// propagatex, keeping the same arg_index.
template<size_t arg_index>
void propagatex(gradient_node<T, order> *node, inner_t adj) const
void propagatex(gradient_node<RealType, DerivativeOrder> *node, inner_t adj) const
{
return static_cast<const derived_expression *>(this)->template propagatex<arg_index>(node,
adj);
}
return static_cast<const DerivedExpression *>(this)->template propagatex<arg_index>(node,
adj);
// NOTE(review): the new version closes the member function with `};` where
// the old one used `}`; the extra semicolon is redundant (harmless, but
// flagged by -Wextra-semi).
};
};
template<typename T, size_t order, typename LHS, typename RHS, typename concrete_binary_operation>
template<typename RealType,
size_t DerivativeOrder,
typename LHS,
typename RHS,
typename ConcreteBinaryOperation>
struct abstract_binary_expression
: public expression<T,
order,
abstract_binary_expression<T, order, LHS, RHS, concrete_binary_operation>>
: public expression<
RealType,
DerivativeOrder,
abstract_binary_expression<RealType, DerivativeOrder, LHS, RHS, ConcreteBinaryOperation>>
{
using lhs_type = LHS;
using rhs_type = RHS;
using value_type = T;
using inner_t = rvar_t<T, order - 1>;
using value_type = RealType;
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
const lhs_type lhs;
const rhs_type rhs;
explicit abstract_binary_expression(const expression<T, order, LHS> &left_hand_expr,
const expression<T, order, RHS> &right_hand_expr)
explicit abstract_binary_expression(
const expression<RealType, DerivativeOrder, LHS> &left_hand_expr,
const expression<RealType, DerivativeOrder, RHS> &right_hand_expr)
: lhs(static_cast<const LHS &>(left_hand_expr))
, rhs(static_cast<const RHS &>(right_hand_expr)){};
inner_t evaluate() const
{
return static_cast<const concrete_binary_operation *>(this)->evaluate();
return static_cast<const ConcreteBinaryOperation *>(this)->evaluate();
};
template<size_t arg_index>
void propagatex(gradient_node<T, order> *node, inner_t adj) const
void propagatex(gradient_node<RealType, DerivativeOrder> *node, inner_t adj) const
{
const inner_t lv = lhs.evaluate();
const inner_t rv = rhs.evaluate();
const inner_t v = evaluate();
const inner_t partial_l = concrete_binary_operation::left_derivative(lv, rv, v);
const inner_t partial_r = concrete_binary_operation::right_derivative(lv, rv, v);
inner_t lv = lhs.evaluate();
inner_t rv = rhs.evaluate();
inner_t v = evaluate();
inner_t partial_l = ConcreteBinaryOperation::left_derivative(lv, rv, v);
inner_t partial_r = ConcreteBinaryOperation::right_derivative(lv, rv, v);
constexpr size_t num_lhs_args = detail::count_rvars<LHS, order>;
constexpr size_t num_rhs_args = detail::count_rvars<RHS, order>;
constexpr size_t num_lhs_args = detail::count_rvars<LHS, DerivativeOrder>;
constexpr size_t num_rhs_args = detail::count_rvars<RHS, DerivativeOrder>;
propagate_lhs<num_lhs_args, arg_index>(node, adj * partial_l);
propagate_rhs<num_rhs_args, arg_index + num_lhs_args>(node, adj * partial_r);
@@ -178,7 +193,7 @@ private:
template<std::size_t num_args,
std::size_t arg_index_,
typename std::enable_if<(num_args > 0), int>::type = 0>
void propagate_lhs(gradient_node<T, order> *node, inner_t adj) const
void propagate_lhs(gradient_node<RealType, DerivativeOrder> *node, inner_t adj) const
{
lhs.template propagatex<arg_index_>(node, adj);
}
@@ -186,13 +201,13 @@ private:
template<std::size_t num_args,
std::size_t arg_index_,
typename std::enable_if<(num_args == 0), int>::type = 0>
void propagate_lhs(gradient_node<T, order> *, inner_t) const
void propagate_lhs(gradient_node<RealType, DerivativeOrder> *, inner_t) const
{}
template<std::size_t num_args,
std::size_t arg_index_,
typename std::enable_if<(num_args > 0), int>::type = 0>
void propagate_rhs(gradient_node<T, order> *node, inner_t adj) const
void propagate_rhs(gradient_node<RealType, DerivativeOrder> *node, inner_t adj) const
{
rhs.template propagatex<arg_index_>(node, adj);
}
@@ -200,32 +215,37 @@ private:
template<std::size_t num_args,
std::size_t arg_index_,
typename std::enable_if<(num_args == 0), int>::type = 0>
void propagate_rhs(gradient_node<T, order> *, inner_t) const
void propagate_rhs(gradient_node<RealType, DerivativeOrder> *, inner_t) const
{}
};
template<typename T, size_t order, typename ARG, typename concrete_unary_operation>
template<typename RealType, size_t DerivativeOrder, typename ARG, typename ConcreteUnaryOperation>
struct abstract_unary_expression
: public expression<T, order, abstract_unary_expression<T, order, ARG, concrete_unary_operation>>
: public expression<
RealType,
DerivativeOrder,
abstract_unary_expression<RealType, DerivativeOrder, ARG, ConcreteUnaryOperation>>
{
using arg_type = ARG;
using value_type = T;
using inner_t = rvar_t<T, order - 1>;
using value_type = RealType;
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
const arg_type arg;
const T constant;
explicit abstract_unary_expression(const expression<T, order, ARG> &arg_expr, const T &constant)
const RealType constant;
explicit abstract_unary_expression(const expression<RealType, DerivativeOrder, ARG> &arg_expr,
const RealType &constant)
: arg(static_cast<const ARG &>(arg_expr))
, constant(constant){};
inner_t evaluate() const
{
return static_cast<const concrete_unary_operation *>(this)->evaluate();
return static_cast<const ConcreteUnaryOperation *>(this)->evaluate();
};
template<size_t arg_index>
void propagatex(gradient_node<T, order> *node, inner_t adj) const
void propagatex(gradient_node<RealType, DerivativeOrder> *node, inner_t adj) const
{
inner_t argv = arg.evaluate();
inner_t v = evaluate();
inner_t partial_arg = concrete_unary_operation::derivative(argv, v, constant);
inner_t partial_arg = ConcreteUnaryOperation::derivative(argv, v, constant);
arg.template propagatex<arg_index>(node, adj * partial_arg);
}

View File

@@ -188,7 +188,7 @@ public:
bool operator!() const noexcept { return storage_ == nullptr; }
};
/* memory management helpers for tape */
template<typename T, size_t buffer_size>
template<typename RealType, size_t buffer_size>
class flat_linear_allocator
{
/** @brief basically a vector<unique_ptr<array<value_type, buffer_size>>>
@@ -198,8 +198,8 @@ class flat_linear_allocator
public:
// store vector of unique pointers to arrays
// to avoid vector reference invalidation
using buffer_type = std::array<T, buffer_size>;
using buffer_ptr = std::unique_ptr<std::array<T, buffer_size>>;
using buffer_type = std::array<RealType, buffer_size>;
using buffer_ptr = std::unique_ptr<std::array<RealType, buffer_size>>;
private:
std::vector<buffer_ptr> data_;
@@ -207,14 +207,16 @@ private:
std::vector<size_t> checkpoints_; //{0};
public:
friend class flat_linear_allocator_iterator<flat_linear_allocator<T, buffer_size>, buffer_size>;
friend class flat_linear_allocator_iterator<const flat_linear_allocator<T, buffer_size>,
friend class flat_linear_allocator_iterator<flat_linear_allocator<RealType, buffer_size>,
buffer_size>;
using value_type = T;
friend class flat_linear_allocator_iterator<const flat_linear_allocator<RealType, buffer_size>,
buffer_size>;
using value_type = RealType;
using iterator
= flat_linear_allocator_iterator<flat_linear_allocator<T, buffer_size>, buffer_size>;
= flat_linear_allocator_iterator<flat_linear_allocator<RealType, buffer_size>, buffer_size>;
using const_iterator
= flat_linear_allocator_iterator<const flat_linear_allocator<T, buffer_size>, buffer_size>;
= flat_linear_allocator_iterator<const flat_linear_allocator<RealType, buffer_size>,
buffer_size>;
size_t buffer_id() const noexcept { return total_size_ / buffer_size; }
size_t item_id() const noexcept { return total_size_ % buffer_size; }
@@ -242,7 +244,7 @@ public:
for (size_t i = 0; i < total_size_; ++i) {
size_t bid = i / buffer_size;
size_t iid = i % buffer_size;
(*data_[bid])[iid].~T();
(*data_[bid])[iid].~RealType();
}
}
/** @brief
@@ -277,7 +279,7 @@ public:
void rewind_to_last_checkpoint() { total_size_ = checkpoints_.back(); }
void rewind_to_checkpoint_at(size_t index) { total_size_ = checkpoints_[index]; }
void fill(const T &val)
void fill(const RealType &val)
{
for (size_t i = 0; i < total_size_; ++i) {
size_t bid = i / buffer_size;
@@ -296,8 +298,8 @@ public:
size_t bid = buffer_id();
size_t iid = item_id();
T *ptr = &(*data_[bid])[iid];
new (ptr) T();
RealType *ptr = &(*data_[bid])[iid];
new (ptr) RealType();
++total_size_;
return iterator(this, total_size_ - 1);
};
@@ -312,8 +314,8 @@ public:
}
BOOST_MATH_ASSERT(buffer_id() < data_.size());
BOOST_MATH_ASSERT(item_id() < buffer_size);
T *ptr = &(*data_[buffer_id()])[item_id()];
new (ptr) T(std::forward<Args>(args)...);
RealType *ptr = &(*data_[buffer_id()])[item_id()];
new (ptr) RealType(std::forward<Args>(args)...);
++total_size_;
return iterator(this, total_size_ - 1);
}
@@ -325,27 +327,27 @@ public:
size_t bid = buffer_id();
size_t iid = item_id();
if (iid + n < buffer_size) {
T *ptr = &(*data_[bid])[iid];
RealType *ptr = &(*data_[bid])[iid];
for (size_t i = 0; i < n; ++i) {
new (ptr + i) T();
new (ptr + i) RealType();
}
total_size_ += n;
return iterator(this, total_size_ - n, total_size_ - n, total_size_);
} else {
size_t allocs_in_curr_buffer = buffer_size - iid;
size_t allocs_in_next_buffer = n - (buffer_size - iid);
T *ptr = &(*data_[bid])[iid];
RealType *ptr = &(*data_[bid])[iid];
for (size_t i = 0; i < allocs_in_curr_buffer; ++i) {
new (ptr + i) T();
new (ptr + i) RealType();
}
allocate_buffer();
bid = data_.size() - 1;
iid = 0;
total_size_ += n;
T *ptr2 = &(*data_[bid])[iid];
RealType *ptr2 = &(*data_[bid])[iid];
for (size_t i = 0; i < allocs_in_next_buffer; i++) {
new (ptr2 + i) T();
new (ptr2 + i) RealType();
}
return iterator(this, total_size_ - n, total_size_ - n, total_size_);
}
@@ -358,27 +360,27 @@ public:
size_t bid = buffer_id();
size_t iid = item_id();
if (iid + n < buffer_size) {
T *ptr = &(*data_[bid])[iid];
RealType *ptr = &(*data_[bid])[iid];
for (size_t i = 0; i < n; ++i) {
new (ptr + i) T();
new (ptr + i) RealType();
}
total_size_ += n;
return iterator(this, total_size_ - n, total_size_ - n, total_size_);
} else {
size_t allocs_in_curr_buffer = buffer_size - iid;
size_t allocs_in_next_buffer = n - (buffer_size - iid);
T *ptr = &(*data_[bid])[iid];
RealType *ptr = &(*data_[bid])[iid];
for (size_t i = 0; i < allocs_in_curr_buffer; ++i) {
new (ptr + i) T();
new (ptr + i) RealType();
}
allocate_buffer();
bid = data_.size() - 1;
iid = 0;
total_size_ += n;
T *ptr2 = &(*data_[bid])[iid];
RealType *ptr2 = &(*data_[bid])[iid];
for (size_t i = 0; i < allocs_in_next_buffer; i++) {
new (ptr2 + i) T();
new (ptr2 + i) RealType();
}
return iterator(this, total_size_ - n, total_size_ - n, total_size_);
}
@@ -405,19 +407,19 @@ public:
/** @brief searches for item in allocator
* only used to find gradient nodes for propagation */
iterator find(const T *const item)
iterator find(const RealType *const item)
{
return std::find_if(begin(), end(), [&](const T &val) { return &val == item; });
return std::find_if(begin(), end(), [&](const RealType &val) { return &val == item; });
}
/** @brief vector like access,
* currently unused anywhere but very useful for debugging
*/
T &operator[](std::size_t i)
RealType &operator[](std::size_t i)
{
BOOST_MATH_ASSERT(i < total_size_);
return (*data_[i / buffer_size])[i % buffer_size];
}
const T &operator[](std::size_t i) const
const RealType &operator[](std::size_t i) const
{
BOOST_MATH_ASSERT(i < total_size_);
return (*data_[i / buffer_size])[i % buffer_size];