mirror of
https://github.com/boostorg/math.git
synced 2026-01-19 04:22:09 +00:00
modified template variable names, changed copies to references
This commit is contained in:
@@ -35,40 +35,46 @@ namespace differentiation {
|
||||
namespace reverse_mode {
|
||||
|
||||
/* forward declarations for utitlity functions */
|
||||
template<typename T, size_t order, class derived_expression>
|
||||
template<typename RealType, size_t DerivativeOrder, class DerivedExpression>
|
||||
struct expression;
|
||||
|
||||
template<typename T, size_t order>
|
||||
template<typename RealType, size_t DerivativeOrder>
|
||||
class rvar;
|
||||
|
||||
template<typename T, size_t order, typename LHS, typename RHS, typename concrete_binary_operation>
|
||||
template<typename RealType,
|
||||
size_t DerivativeOrder,
|
||||
typename LHS,
|
||||
typename RHS,
|
||||
typename ConcreteBinaryOperation>
|
||||
struct abstract_binary_expression;
|
||||
|
||||
template<typename T, size_t order, typename ARG, typename concrete_unary_operation>
|
||||
template<typename RealType, size_t DerivativeOrder, typename ARG, typename ConcreteBinaryOperation>
|
||||
struct abstract_unary_expression;
|
||||
|
||||
template<typename T, size_t order>
|
||||
template<typename RealType, size_t DerivativeOrder>
|
||||
class gradient_node; // forward declaration for tape
|
||||
// manages nodes in computational graph
|
||||
template<typename T, size_t order, size_t buffer_size = BOOST_MATH_BUFFER_SIZE>
|
||||
template<typename RealType, size_t DerivativeOrder, size_t buffer_size = BOOST_MATH_BUFFER_SIZE>
|
||||
class gradient_tape
|
||||
{
|
||||
/** @brief tape (graph) management class for autodiff
|
||||
* holds all the data structures for autodiff */
|
||||
private:
|
||||
/* type decays to order - 1 to support higher order derivatives */
|
||||
using inner_t = rvar_t<T, order - 1>;
|
||||
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
|
||||
/* adjoints are the overall derivative, and derivatives are the "local"
|
||||
* derivative */
|
||||
detail::flat_linear_allocator<inner_t, buffer_size> adjoints_;
|
||||
detail::flat_linear_allocator<inner_t, buffer_size> derivatives_;
|
||||
detail::flat_linear_allocator<gradient_node<T, order>, buffer_size> gradient_nodes_;
|
||||
detail::flat_linear_allocator<gradient_node<T, order> *, buffer_size> argument_nodes_;
|
||||
detail::flat_linear_allocator<gradient_node<RealType, DerivativeOrder>, buffer_size>
|
||||
gradient_nodes_;
|
||||
detail::flat_linear_allocator<gradient_node<RealType, DerivativeOrder> *, buffer_size>
|
||||
argument_nodes_;
|
||||
|
||||
// compile time check if emplace_back calls on zero
|
||||
template<size_t n>
|
||||
gradient_node<T, order> *fill_node_at_compile_time(std::true_type,
|
||||
gradient_node<T, order> *node_ptr)
|
||||
gradient_node<RealType, DerivativeOrder> *fill_node_at_compile_time(
|
||||
std::true_type, gradient_node<RealType, DerivativeOrder> *node_ptr)
|
||||
{
|
||||
node_ptr->derivatives_ = derivatives_.template emplace_back_n<n>();
|
||||
node_ptr->argument_nodes_ = argument_nodes_.template emplace_back_n<n>();
|
||||
@@ -76,8 +82,8 @@ private:
|
||||
}
|
||||
|
||||
template<size_t n>
|
||||
gradient_node<T, order> *fill_node_at_compile_time(std::false_type,
|
||||
gradient_node<T, order> *node_ptr)
|
||||
gradient_node<RealType, DerivativeOrder> *fill_node_at_compile_time(
|
||||
std::false_type, gradient_node<RealType, DerivativeOrder> *node_ptr)
|
||||
{
|
||||
node_ptr->derivatives_ = nullptr;
|
||||
node_ptr->argument_adjoints_ = nullptr;
|
||||
@@ -94,9 +100,11 @@ public:
|
||||
using derivatives_iterator =
|
||||
typename detail::flat_linear_allocator<inner_t, buffer_size>::iterator;
|
||||
using gradient_nodes_iterator =
|
||||
typename detail::flat_linear_allocator<gradient_node<T, order>, buffer_size>::iterator;
|
||||
typename detail::flat_linear_allocator<gradient_node<RealType, DerivativeOrder>,
|
||||
buffer_size>::iterator;
|
||||
using argument_nodes_iterator =
|
||||
typename detail::flat_linear_allocator<gradient_node<T, order> *, buffer_size>::iterator;
|
||||
typename detail::flat_linear_allocator<gradient_node<RealType, DerivativeOrder> *,
|
||||
buffer_size>::iterator;
|
||||
|
||||
gradient_tape() { clear(); };
|
||||
|
||||
@@ -114,9 +122,9 @@ public:
|
||||
}
|
||||
|
||||
// no derivatives or arguments
|
||||
gradient_node<T, order> *emplace_leaf_node()
|
||||
gradient_node<RealType, DerivativeOrder> *emplace_leaf_node()
|
||||
{
|
||||
gradient_node<T, order> *node = &*gradient_nodes_.emplace_back();
|
||||
gradient_node<RealType, DerivativeOrder> *node = &*gradient_nodes_.emplace_back();
|
||||
node->adjoint_ = adjoints_.emplace_back();
|
||||
node->derivatives_ = derivatives_iterator(); // nullptr;
|
||||
node->argument_nodes_ = argument_nodes_iterator(); // nullptr;
|
||||
@@ -125,9 +133,9 @@ public:
|
||||
};
|
||||
|
||||
// single argument, single derivative
|
||||
gradient_node<T, order> *emplace_active_unary_node()
|
||||
gradient_node<RealType, DerivativeOrder> *emplace_active_unary_node()
|
||||
{
|
||||
gradient_node<T, order> *node = &*gradient_nodes_.emplace_back();
|
||||
gradient_node<RealType, DerivativeOrder> *node = &*gradient_nodes_.emplace_back();
|
||||
node->n_ = 1;
|
||||
node->adjoint_ = adjoints_.emplace_back();
|
||||
node->derivatives_ = derivatives_.emplace_back();
|
||||
@@ -137,9 +145,9 @@ public:
|
||||
|
||||
// arbitrary number of arguments/derivatives (compile time)
|
||||
template<size_t n>
|
||||
gradient_node<T, order> *emplace_active_multi_node()
|
||||
gradient_node<RealType, DerivativeOrder> *emplace_active_multi_node()
|
||||
{
|
||||
gradient_node<T, order> *node = &*gradient_nodes_.emplace_back();
|
||||
gradient_node<RealType, DerivativeOrder> *node = &*gradient_nodes_.emplace_back();
|
||||
node->n_ = n;
|
||||
node->adjoint_ = adjoints_.emplace_back();
|
||||
// emulate if constexpr
|
||||
@@ -147,9 +155,9 @@ public:
|
||||
}
|
||||
|
||||
// same as above at runtime
|
||||
gradient_node<T, order> *emplace_active_multi_node(size_t n)
|
||||
gradient_node<RealType, DerivativeOrder> *emplace_active_multi_node(size_t n)
|
||||
{
|
||||
gradient_node<T, order> *node = &*gradient_nodes_.emplace_back();
|
||||
gradient_node<RealType, DerivativeOrder> *node = &*gradient_nodes_.emplace_back();
|
||||
node->n_ = n;
|
||||
node->adjoint_ = adjoints_.emplace_back();
|
||||
if (n > 0) {
|
||||
@@ -161,14 +169,17 @@ public:
|
||||
/* manual reset button for all adjoints */
|
||||
void zero_grad()
|
||||
{
|
||||
const T zero = T(0.0);
|
||||
const RealType zero = RealType(0.0);
|
||||
adjoints_.fill(zero);
|
||||
}
|
||||
|
||||
// return type is an iterator
|
||||
auto begin() { return gradient_nodes_.begin(); }
|
||||
auto end() { return gradient_nodes_.end(); }
|
||||
auto find(gradient_node<T, order> *node) { return gradient_nodes_.find(node); };
|
||||
auto find(gradient_node<RealType, DerivativeOrder> *node)
|
||||
{
|
||||
return gradient_nodes_.find(node);
|
||||
};
|
||||
void add_checkpoint()
|
||||
{
|
||||
gradient_nodes_.add_checkpoint();
|
||||
@@ -206,11 +217,14 @@ public:
|
||||
}
|
||||
|
||||
// random acces
|
||||
gradient_node<T, order> &operator[](size_t i) { return gradient_nodes_[i]; }
|
||||
const gradient_node<T, order> &operator[](size_t i) const { return gradient_nodes_[i]; }
|
||||
gradient_node<RealType, DerivativeOrder> &operator[](size_t i) { return gradient_nodes_[i]; }
|
||||
const gradient_node<RealType, DerivativeOrder> &operator[](size_t i) const
|
||||
{
|
||||
return gradient_nodes_[i];
|
||||
}
|
||||
};
|
||||
// class rvar;
|
||||
template<typename T, size_t order> // no CRTP, just storage
|
||||
template<typename RealType, size_t DerivativeOrder> // no CRTP, just storage
|
||||
class gradient_node
|
||||
{
|
||||
/*
|
||||
@@ -218,13 +232,15 @@ class gradient_node
|
||||
* adjoints pointers to arguments aren't needed here
|
||||
* */
|
||||
public:
|
||||
using adjoint_iterator = typename gradient_tape<T, order>::adjoint_iterator;
|
||||
using derivatives_iterator = typename gradient_tape<T, order>::derivatives_iterator;
|
||||
using argument_nodes_iterator = typename gradient_tape<T, order>::argument_nodes_iterator;
|
||||
using adjoint_iterator = typename gradient_tape<RealType, DerivativeOrder>::adjoint_iterator;
|
||||
using derivatives_iterator =
|
||||
typename gradient_tape<RealType, DerivativeOrder>::derivatives_iterator;
|
||||
using argument_nodes_iterator =
|
||||
typename gradient_tape<RealType, DerivativeOrder>::argument_nodes_iterator;
|
||||
|
||||
private:
|
||||
size_t n_;
|
||||
using inner_t = rvar_t<T, order - 1>;
|
||||
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
|
||||
/* these are iterators in case
|
||||
* flat linear allocator is at capacity, and needs to allocate a new block of
|
||||
* memory. */
|
||||
@@ -233,8 +249,8 @@ private:
|
||||
argument_nodes_iterator argument_nodes_;
|
||||
|
||||
public:
|
||||
friend class gradient_tape<T, order>;
|
||||
friend class rvar<T, order>;
|
||||
friend class gradient_tape<RealType, DerivativeOrder>;
|
||||
friend class rvar<RealType, DerivativeOrder>;
|
||||
|
||||
gradient_node() = default;
|
||||
explicit gradient_node(const size_t n)
|
||||
@@ -242,7 +258,10 @@ public:
|
||||
, adjoint_(nullptr)
|
||||
, derivatives_(nullptr)
|
||||
{}
|
||||
explicit gradient_node(const size_t n, T *adjoint, T *derivatives, rvar<T, order> **arguments)
|
||||
explicit gradient_node(const size_t n,
|
||||
RealType *adjoint,
|
||||
RealType *derivatives,
|
||||
rvar<RealType, DerivativeOrder> **arguments)
|
||||
: n_(n)
|
||||
, adjoint_(adjoint)
|
||||
, derivatives_(derivatives)
|
||||
@@ -263,7 +282,7 @@ public:
|
||||
{
|
||||
argument_nodes_[static_cast<ptrdiff_t>(arg_id)]->update_adjoint_v(value);
|
||||
};
|
||||
void update_argument_ptr_at(size_t arg_id, gradient_node<T, order> *node_ptr)
|
||||
void update_argument_ptr_at(size_t arg_id, gradient_node<RealType, DerivativeOrder> *node_ptr)
|
||||
{
|
||||
argument_nodes_[static_cast<ptrdiff_t>(arg_id)] = node_ptr;
|
||||
}
|
||||
@@ -275,7 +294,7 @@ public:
|
||||
|
||||
using boost::math::differentiation::reverse_mode::fabs;
|
||||
using std::fabs;
|
||||
if (!adjoint_ || fabs(*adjoint_) < 2 * std::numeric_limits<T>::epsilon())
|
||||
if (!adjoint_ || fabs(*adjoint_) < 2 * std::numeric_limits<RealType>::epsilon())
|
||||
return;
|
||||
|
||||
if (!argument_nodes_) // no arguments
|
||||
@@ -294,21 +313,22 @@ public:
|
||||
};
|
||||
|
||||
/****************************************************************************************************************/
|
||||
template<typename T, size_t order>
|
||||
inline gradient_tape<T, order, BOOST_MATH_BUFFER_SIZE> &get_active_tape()
|
||||
template<typename RealType, size_t DerivativeOrder>
|
||||
inline gradient_tape<RealType, DerivativeOrder, BOOST_MATH_BUFFER_SIZE> &get_active_tape()
|
||||
{
|
||||
static BOOST_MATH_THREAD_LOCAL gradient_tape<T, order, BOOST_MATH_BUFFER_SIZE> tape;
|
||||
static BOOST_MATH_THREAD_LOCAL gradient_tape<RealType, DerivativeOrder, BOOST_MATH_BUFFER_SIZE>
|
||||
tape;
|
||||
return tape;
|
||||
}
|
||||
|
||||
template<typename T, size_t order = 1>
|
||||
class rvar : public expression<T, order, rvar<T, order>>
|
||||
template<typename RealType, size_t DerivativeOrder = 1>
|
||||
class rvar : public expression<RealType, DerivativeOrder, rvar<RealType, DerivativeOrder>>
|
||||
{
|
||||
private:
|
||||
using inner_t = rvar_t<T, order - 1>;
|
||||
friend class gradient_node<T, order>;
|
||||
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
|
||||
friend class gradient_node<RealType, DerivativeOrder>;
|
||||
inner_t value_;
|
||||
gradient_node<T, order> *node_ = nullptr;
|
||||
gradient_node<RealType, DerivativeOrder> *node_ = nullptr;
|
||||
template<typename, size_t>
|
||||
friend class rvar;
|
||||
/*****************************************************************************************/
|
||||
@@ -322,13 +342,13 @@ private:
|
||||
|
||||
/** @return value_ at rvar_t<T,current_order - 1>
|
||||
*/
|
||||
static auto &get(rvar<T, current_order> &v)
|
||||
static auto &get(rvar<RealType, current_order> &v)
|
||||
{
|
||||
return get_value_at_impl<target_order, current_order - 1>::get(v.get_value());
|
||||
}
|
||||
/** @return const value_ at rvar_t<T,current_order - 1>
|
||||
*/
|
||||
static const auto &get(const rvar<T, current_order> &v)
|
||||
static const auto &get(const rvar<RealType, current_order> &v)
|
||||
{
|
||||
return get_value_at_impl<target_order, current_order - 1>::get(v.get_value());
|
||||
}
|
||||
@@ -341,65 +361,69 @@ private:
|
||||
{
|
||||
/** @return value_ at rvar_t<T,target_order>
|
||||
*/
|
||||
static auto &get(rvar<T, target_order> &v) { return v; }
|
||||
static auto &get(rvar<RealType, target_order> &v) { return v; }
|
||||
/** @return const value_ at rvar_t<T,target_order>
|
||||
*/
|
||||
static const auto &get(const rvar<T, target_order> &v) { return v; }
|
||||
static const auto &get(const rvar<RealType, target_order> &v) { return v; }
|
||||
};
|
||||
/*****************************************************************************************/
|
||||
void make_leaf_node()
|
||||
{
|
||||
gradient_tape<T, order, BOOST_MATH_BUFFER_SIZE> &tape = get_active_tape<T, order>();
|
||||
gradient_tape<RealType, DerivativeOrder, BOOST_MATH_BUFFER_SIZE> &tape
|
||||
= get_active_tape<RealType, DerivativeOrder>();
|
||||
node_ = tape.emplace_leaf_node();
|
||||
}
|
||||
|
||||
void make_unary_node()
|
||||
{
|
||||
gradient_tape<T, order, BOOST_MATH_BUFFER_SIZE> &tape = get_active_tape<T, order>();
|
||||
gradient_tape<RealType, DerivativeOrder, BOOST_MATH_BUFFER_SIZE> &tape
|
||||
= get_active_tape<RealType, DerivativeOrder>();
|
||||
node_ = tape.emplace_active_unary_node();
|
||||
}
|
||||
|
||||
void make_multi_node(size_t n)
|
||||
{
|
||||
gradient_tape<T, order, BOOST_MATH_BUFFER_SIZE> &tape = get_active_tape<T, order>();
|
||||
gradient_tape<RealType, DerivativeOrder, BOOST_MATH_BUFFER_SIZE> &tape
|
||||
= get_active_tape<RealType, DerivativeOrder>();
|
||||
node_ = tape.emplace_active_multi_node(n);
|
||||
}
|
||||
|
||||
template<size_t n>
|
||||
void make_multi_node()
|
||||
{
|
||||
gradient_tape<T, order, BOOST_MATH_BUFFER_SIZE> &tape = get_active_tape<T, order>();
|
||||
gradient_tape<RealType, DerivativeOrder, BOOST_MATH_BUFFER_SIZE> &tape
|
||||
= get_active_tape<RealType, DerivativeOrder>();
|
||||
node_ = tape.template emplace_active_multi_node<n>();
|
||||
}
|
||||
|
||||
template<typename E>
|
||||
void make_rvar_from_expr(const expression<T, order, E> &expr)
|
||||
void make_rvar_from_expr(const expression<RealType, DerivativeOrder, E> &expr)
|
||||
{
|
||||
make_multi_node<detail::count_rvars<E, order>>();
|
||||
make_multi_node<detail::count_rvars<E, DerivativeOrder>>();
|
||||
expr.template propagatex<0>(node_, inner_t(1.0));
|
||||
}
|
||||
T get_item_impl(std::true_type) const
|
||||
RealType get_item_impl(std::true_type) const
|
||||
{
|
||||
return value_.get_item_impl(std::integral_constant<bool, (order - 1 > 1)>{});
|
||||
return value_.get_item_impl(std::integral_constant<bool, (DerivativeOrder - 1 > 1)>{});
|
||||
}
|
||||
|
||||
T get_item_impl(std::false_type) const { return value_; }
|
||||
RealType get_item_impl(std::false_type) const { return value_; }
|
||||
|
||||
public:
|
||||
using value_type = T;
|
||||
static constexpr size_t order_v = order;
|
||||
using value_type = RealType;
|
||||
static constexpr size_t DerivativeOrder_v = DerivativeOrder;
|
||||
rvar()
|
||||
: value_()
|
||||
{
|
||||
make_leaf_node();
|
||||
}
|
||||
rvar(const T value)
|
||||
rvar(const RealType value)
|
||||
: value_(inner_t{value})
|
||||
{
|
||||
make_leaf_node();
|
||||
}
|
||||
|
||||
rvar &operator=(T v)
|
||||
rvar &operator=(RealType v)
|
||||
{
|
||||
value_ = inner_t(v);
|
||||
if (node_ == nullptr) {
|
||||
@@ -407,24 +431,24 @@ public:
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
rvar(const rvar<T, order> &other) = default;
|
||||
rvar &operator=(const rvar<T, order> &other) = default;
|
||||
rvar(const rvar<RealType, DerivativeOrder> &other) = default;
|
||||
rvar &operator=(const rvar<RealType, DerivativeOrder> &other) = default;
|
||||
|
||||
template<size_t arg_index>
|
||||
void propagatex(gradient_node<T, order> *node, inner_t adj) const
|
||||
void propagatex(gradient_node<RealType, DerivativeOrder> *node, inner_t adj) const
|
||||
{
|
||||
node->update_derivative_v(arg_index, adj);
|
||||
node->update_argument_ptr_at(arg_index, node_);
|
||||
}
|
||||
|
||||
template<class E>
|
||||
rvar(const expression<T, order, E> &expr)
|
||||
rvar(const expression<RealType, DerivativeOrder, E> &expr)
|
||||
{
|
||||
value_ = expr.evaluate();
|
||||
make_rvar_from_expr(expr);
|
||||
}
|
||||
template<class E>
|
||||
rvar &operator=(const expression<T, order, E> &expr)
|
||||
rvar &operator=(const expression<RealType, DerivativeOrder, E> &expr)
|
||||
{
|
||||
value_ = expr.evaluate();
|
||||
make_rvar_from_expr(expr);
|
||||
@@ -432,52 +456,52 @@ public:
|
||||
}
|
||||
/***************************************************************************************************/
|
||||
template<class E>
|
||||
rvar<T, order> &operator+=(const expression<T, order, E> &expr)
|
||||
rvar<RealType, DerivativeOrder> &operator+=(const expression<RealType, DerivativeOrder, E> &expr)
|
||||
{
|
||||
*this = *this + expr;
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<class E>
|
||||
rvar<T, order> &operator*=(const expression<T, order, E> &expr)
|
||||
rvar<RealType, DerivativeOrder> &operator*=(const expression<RealType, DerivativeOrder, E> &expr)
|
||||
{
|
||||
*this = *this * expr;
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<class E>
|
||||
rvar<T, order> &operator-=(const expression<T, order, E> &expr)
|
||||
rvar<RealType, DerivativeOrder> &operator-=(const expression<RealType, DerivativeOrder, E> &expr)
|
||||
{
|
||||
*this = *this - expr;
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<class E>
|
||||
rvar<T, order> &operator/=(const expression<T, order, E> &expr)
|
||||
rvar<RealType, DerivativeOrder> &operator/=(const expression<RealType, DerivativeOrder, E> &expr)
|
||||
{
|
||||
*this = *this / expr;
|
||||
return *this;
|
||||
}
|
||||
/***************************************************************************************************/
|
||||
rvar<T, order> &operator+=(const T &v)
|
||||
rvar<RealType, DerivativeOrder> &operator+=(const RealType &v)
|
||||
{
|
||||
*this = *this + v;
|
||||
return *this;
|
||||
}
|
||||
|
||||
rvar<T, order> &operator*=(const T &v)
|
||||
rvar<RealType, DerivativeOrder> &operator*=(const RealType &v)
|
||||
{
|
||||
*this = *this * v;
|
||||
return *this;
|
||||
}
|
||||
|
||||
rvar<T, order> &operator-=(const T &v)
|
||||
rvar<RealType, DerivativeOrder> &operator-=(const RealType &v)
|
||||
{
|
||||
*this = *this - v;
|
||||
return *this;
|
||||
}
|
||||
|
||||
rvar<T, order> &operator/=(const T &v)
|
||||
rvar<RealType, DerivativeOrder> &operator/=(const RealType &v)
|
||||
{
|
||||
*this = *this / v;
|
||||
return *this;
|
||||
@@ -490,7 +514,7 @@ public:
|
||||
const inner_t &evaluate() const { return value_; };
|
||||
inner_t &get_value() { return value_; };
|
||||
|
||||
explicit operator T() const { return item(); }
|
||||
explicit operator RealType() const { return item(); }
|
||||
|
||||
explicit operator int() const { return static_cast<int>(item()); }
|
||||
explicit operator long() const { return static_cast<long>(item()); }
|
||||
@@ -503,23 +527,27 @@ public:
|
||||
template<size_t N>
|
||||
auto &get_value_at()
|
||||
{
|
||||
static_assert(N <= order, "Requested depth exceeds variable order.");
|
||||
return get_value_at_impl<N, order>::get(*this);
|
||||
static_assert(N <= DerivativeOrder, "Requested depth exceeds variable order.");
|
||||
return get_value_at_impl<N, DerivativeOrder>::get(*this);
|
||||
}
|
||||
/** @brief same as above but const
|
||||
*/
|
||||
template<size_t N>
|
||||
const auto &get_value_at() const
|
||||
{
|
||||
static_assert(N <= order, "Requested depth exceeds variable order.");
|
||||
return get_value_at_impl<N, order>::get(*this);
|
||||
static_assert(N <= DerivativeOrder, "Requested depth exceeds variable order.");
|
||||
return get_value_at_impl<N, DerivativeOrder>::get(*this);
|
||||
}
|
||||
|
||||
T item() const { return get_item_impl(std::integral_constant<bool, (order > 1)>{}); }
|
||||
RealType item() const
|
||||
{
|
||||
return get_item_impl(std::integral_constant<bool, (DerivativeOrder > 1)>{});
|
||||
}
|
||||
|
||||
void backward()
|
||||
{
|
||||
gradient_tape<T, order, BOOST_MATH_BUFFER_SIZE> &tape = get_active_tape<T, order>();
|
||||
gradient_tape<RealType, DerivativeOrder, BOOST_MATH_BUFFER_SIZE> &tape
|
||||
= get_active_tape<RealType, DerivativeOrder>();
|
||||
auto it = tape.find(node_);
|
||||
it->update_adjoint_v(inner_t(1.0));
|
||||
while (it != tape.begin()) {
|
||||
@@ -530,32 +558,32 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, size_t order>
|
||||
std::ostream &operator<<(std::ostream &os, const rvar<T, order> var)
|
||||
template<typename RealType, size_t DerivativeOrder>
|
||||
std::ostream &operator<<(std::ostream &os, const rvar<RealType, DerivativeOrder> var)
|
||||
{
|
||||
os << "rvar<" << order << ">(" << var.item() << "," << var.adjoint() << ")";
|
||||
os << "rvar<" << DerivativeOrder << ">(" << var.item() << "," << var.adjoint() << ")";
|
||||
return os;
|
||||
}
|
||||
|
||||
template<typename T, size_t order, typename E>
|
||||
std::ostream &operator<<(std::ostream &os, const expression<T, order, E> &expr)
|
||||
template<typename RealType, size_t DerivativeOrder, typename E>
|
||||
std::ostream &operator<<(std::ostream &os, const expression<RealType, DerivativeOrder, E> &expr)
|
||||
{
|
||||
rvar<T, order> tmp = expr;
|
||||
os << "rvar<" << order << ">(" << tmp.item() << "," << tmp.adjoint() << ")";
|
||||
rvar<RealType, DerivativeOrder> tmp = expr;
|
||||
os << "rvar<" << DerivativeOrder << ">(" << tmp.item() << "," << tmp.adjoint() << ")";
|
||||
return os;
|
||||
}
|
||||
|
||||
template<typename T, size_t order>
|
||||
rvar<T, order> make_rvar(const T v)
|
||||
template<typename RealType, size_t DerivativeOrder>
|
||||
rvar<RealType, DerivativeOrder> make_rvar(const RealType v)
|
||||
{
|
||||
static_assert(order > 0, "rvar order must be >= 1");
|
||||
return rvar<T, order>(v);
|
||||
static_assert(DerivativeOrder > 0, "rvar order must be >= 1");
|
||||
return rvar<RealType, DerivativeOrder>(v);
|
||||
}
|
||||
template<typename T, size_t order, typename E>
|
||||
rvar<T, order> make_rvar(const expression<T, order, E> &expr)
|
||||
template<typename RealType, size_t DerivativeOrder, typename E>
|
||||
rvar<RealType, DerivativeOrder> make_rvar(const expression<RealType, DerivativeOrder, E> &expr)
|
||||
{
|
||||
static_assert(order > 0, "rvar order must be >= 1");
|
||||
return rvar<T, order>(expr);
|
||||
static_assert(DerivativeOrder > 0, "rvar order must be >= 1");
|
||||
return rvar<RealType, DerivativeOrder>(expr);
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
@@ -565,37 +593,24 @@ namespace detail {
|
||||
* specialization for autodiffing through autodiff. i.e. being able to
|
||||
* compute higher order grads
|
||||
*/
|
||||
template<typename T, size_t order>
|
||||
template<typename RealType, size_t DerivativeOrder>
|
||||
struct grad_op_impl
|
||||
{
|
||||
std::vector<rvar<T, order - 1>> operator()(rvar<T, order> &f, std::vector<rvar<T, order> *> &x)
|
||||
std::vector<rvar<RealType, DerivativeOrder - 1>> operator()(
|
||||
rvar<RealType, DerivativeOrder> &f, std::vector<rvar<RealType, DerivativeOrder> *> &x)
|
||||
{
|
||||
auto &tape = get_active_tape<T, order>();
|
||||
auto &tape = get_active_tape<RealType, DerivativeOrder>();
|
||||
tape.zero_grad();
|
||||
f.backward();
|
||||
|
||||
std::vector<rvar<T, order - 1>> gradient_vector;
|
||||
std::vector<rvar<RealType, DerivativeOrder - 1>> gradient_vector;
|
||||
gradient_vector.reserve(x.size());
|
||||
|
||||
for (auto xi : x) {
|
||||
// make a new rvar<T,order-1> holding the adjoint value
|
||||
for (auto &xi : x) {
|
||||
gradient_vector.emplace_back(xi->adjoint());
|
||||
}
|
||||
return gradient_vector;
|
||||
}
|
||||
/*
|
||||
std::vector<rvar_t<T, order - 1> *> operator()(rvar<T, order> &f,
|
||||
std::vector<rvar<T, order> *> &x)
|
||||
{
|
||||
gradient_tape<T, order, BOOST_MATH_BUFFER_SIZE> &tape = get_active_tape<T, order>();
|
||||
tape.zero_grad();
|
||||
f.backward();
|
||||
std::vector<rvar_t<T, order - 1> *> gradient_vector;
|
||||
for (auto xi : x) {
|
||||
gradient_vector.push_back(&(xi->adjoint()));
|
||||
}
|
||||
return gradient_vector;
|
||||
}*/
|
||||
};
|
||||
/** @brief helper overload for grad implementation.
|
||||
* @return vector<T> of gradients of the autodiff graph.
|
||||
@@ -610,7 +625,7 @@ struct grad_op_impl<T, 1>
|
||||
tape.zero_grad();
|
||||
f.backward();
|
||||
std::vector<T> gradient_vector;
|
||||
for (auto xi : x) {
|
||||
for (auto &xi : x) {
|
||||
gradient_vector.push_back(xi->adjoint());
|
||||
}
|
||||
return gradient_vector;
|
||||
@@ -621,52 +636,51 @@ struct grad_op_impl<T, 1>
|
||||
* @return nested vector representing N-d tensor of
|
||||
* higher order derivatives
|
||||
*/
|
||||
template<size_t N, typename T, size_t order_1, size_t order_2, typename Enable = void>
|
||||
template<size_t N,
|
||||
typename RealType,
|
||||
size_t DerivativeOrder_1,
|
||||
size_t DerivativeOrder_2,
|
||||
typename Enable = void>
|
||||
struct grad_nd_impl
|
||||
{
|
||||
auto operator()(rvar<T, order_1> &f, std::vector<rvar<T, order_2> *> &x)
|
||||
auto operator()(rvar<RealType, DerivativeOrder_1> &f,
|
||||
std::vector<rvar<RealType, DerivativeOrder_2> *> &x)
|
||||
{
|
||||
static_assert(N > 1, "N must be greater than 1 for this template");
|
||||
|
||||
auto current_grad = grad(f, x); // vector<rvar<T,order_1-1>> or vector<T>
|
||||
auto current_grad = grad(f, x); // vector<rvar<T,DerivativeOrder_1-1>> or vector<T>
|
||||
|
||||
std::vector<decltype(grad_nd_impl<N - 1, T, order_1 - 1, order_2>()(current_grad[0], x))>
|
||||
std::vector<decltype(grad_nd_impl<N - 1, RealType, DerivativeOrder_1 - 1, DerivativeOrder_2>()(
|
||||
current_grad[0], x))>
|
||||
result;
|
||||
result.reserve(current_grad.size());
|
||||
|
||||
for (auto &g_i : current_grad) {
|
||||
result.push_back(grad_nd_impl<N - 1, T, order_1 - 1, order_2>()(g_i, x));
|
||||
result.push_back(
|
||||
grad_nd_impl<N - 1, RealType, DerivativeOrder_1 - 1, DerivativeOrder_2>()(g_i, x));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
/*
|
||||
auto operator()(rvar<T, order_1> &f, std::vector<rvar<T, order_2> *> &x)
|
||||
{
|
||||
static_assert(N > 1, "N must be greater than 1 for this template");
|
||||
auto current_grad = grad(f, x);
|
||||
std::vector<decltype(grad_nd_impl<N - 1, T, order_1 - 1, order_2>()(*current_grad[0], x))>
|
||||
result;
|
||||
for (auto &g_i : current_grad) {
|
||||
result.push_back(grad_nd_impl<N - 1, T, order_1 - 1, order_2>()(*g_i, x));
|
||||
}
|
||||
return result;
|
||||
}*/
|
||||
};
|
||||
/** @brief spcialization for order = 1,
|
||||
* @return vector<rvar<T,order_1-1>> gradients */
|
||||
template<typename T, size_t order_1, size_t order_2>
|
||||
struct grad_nd_impl<1, T, order_1, order_2>
|
||||
* @return vector<rvar<T,DerivativeOrder_1-1>> gradients */
|
||||
template<typename RealType, size_t DerivativeOrder_1, size_t DerivativeOrder_2>
|
||||
struct grad_nd_impl<1, RealType, DerivativeOrder_1, DerivativeOrder_2>
|
||||
{
|
||||
auto operator()(rvar<T, order_1> &f, std::vector<rvar<T, order_2> *> &x) { return grad(f, x); }
|
||||
auto operator()(rvar<RealType, DerivativeOrder_1> &f,
|
||||
std::vector<rvar<RealType, DerivativeOrder_2> *> &x)
|
||||
{
|
||||
return grad(f, x);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename ptr>
|
||||
struct rvar_order;
|
||||
|
||||
template<typename T, size_t order>
|
||||
struct rvar_order<rvar<T, order> *>
|
||||
template<typename RealType, size_t DerivativeOrder>
|
||||
struct rvar_order<rvar<RealType, DerivativeOrder> *>
|
||||
{
|
||||
static constexpr size_t value = order;
|
||||
static constexpr size_t value = DerivativeOrder;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
@@ -676,52 +690,52 @@ struct rvar_order<rvar<T, order> *>
|
||||
* @param f -> computational graph
|
||||
* @param x -> variables gradients to record. Note ALL gradients of the graph
|
||||
* are computed simultaneously, only the ones w.r.t. x are returned
|
||||
* @return vector<rvar<T,order_1 - 1> of gradients. in the case of order_1 = 1
|
||||
* rvar<T,order_1-1> decays to T
|
||||
* @return vector<rvar<T,DerivativeOrder_1 - 1> of gradients. in the case of DerivativeOrder_1 = 1
|
||||
* rvar<T,DerivativeOrder_1-1> decays to T
|
||||
*
|
||||
* safe to call recursively with grad(grad(grad...
|
||||
*/
|
||||
template<typename T, size_t order_1, size_t order_2>
|
||||
auto grad(rvar<T, order_1> &f, std::vector<rvar<T, order_2> *> &x)
|
||||
template<typename RealType, size_t DerivativeOrder_1, size_t DerivativeOrder_2>
|
||||
auto grad(rvar<RealType, DerivativeOrder_1> &f, std::vector<rvar<RealType, DerivativeOrder_2> *> &x)
|
||||
{
|
||||
static_assert(order_1 <= order_2,
|
||||
static_assert(DerivativeOrder_1 <= DerivativeOrder_2,
|
||||
"variable differentiating w.r.t. must have order >= function order");
|
||||
std::vector<rvar<T, order_1> *> xx;
|
||||
std::vector<rvar<RealType, DerivativeOrder_1> *> xx;
|
||||
for (auto &xi : x)
|
||||
xx.push_back(&(xi->template get_value_at<order_1>()));
|
||||
return detail::grad_op_impl<T, order_1>{}(f, xx);
|
||||
xx.push_back(&(xi->template get_value_at<DerivativeOrder_1>()));
|
||||
return detail::grad_op_impl<RealType, DerivativeOrder_1>{}(f, xx);
|
||||
}
|
||||
/** @brief variadic overload of above
|
||||
*/
|
||||
template<typename T, size_t order_1, typename First, typename... Other>
|
||||
auto grad(rvar<T, order_1> &f, First first, Other... other)
|
||||
template<typename RealType, size_t DerivativeOrder_1, typename First, typename... Other>
|
||||
auto grad(rvar<RealType, DerivativeOrder_1> &f, First first, Other... other)
|
||||
{
|
||||
constexpr size_t order_2 = detail::rvar_order<First>::value;
|
||||
static_assert(order_1 <= order_2,
|
||||
constexpr size_t DerivativeOrder_2 = detail::rvar_order<First>::value;
|
||||
static_assert(DerivativeOrder_1 <= DerivativeOrder_2,
|
||||
"variable differentiating w.r.t. must have order >= function order");
|
||||
std::vector<rvar<T, order_2> *> x_vec = {first, other...};
|
||||
std::vector<rvar<RealType, DerivativeOrder_2> *> x_vec = {first, other...};
|
||||
return grad(f, x_vec);
|
||||
}
|
||||
|
||||
/** @brief computes hessian matrix of computational graph w.r.t.
|
||||
* vector of variables x.
|
||||
* @return std::vector<std::vector<rvar<T,order_1-2>> hessian matrix
|
||||
* @return std::vector<std::vector<rvar<T,DerivativeOrder_1-2>> hessian matrix
|
||||
* rvar<T,2> decays to T
|
||||
*
|
||||
* NOT recursion safe, cannot do hess(hess(
|
||||
*/
|
||||
template<typename T, size_t order_1, size_t order_2>
|
||||
auto hess(rvar<T, order_1> &f, std::vector<rvar<T, order_2> *> &x)
|
||||
template<typename RealType, size_t DerivativeOrder_1, size_t DerivativeOrder_2>
|
||||
auto hess(rvar<RealType, DerivativeOrder_1> &f, std::vector<rvar<RealType, DerivativeOrder_2> *> &x)
|
||||
{
|
||||
return detail::grad_nd_impl<2, T, order_1, order_2>{}(f, x);
|
||||
return detail::grad_nd_impl<2, RealType, DerivativeOrder_1, DerivativeOrder_2>{}(f, x);
|
||||
}
|
||||
/** @brief variadic overload of above
|
||||
*/
|
||||
template<typename T, size_t order_1, typename First, typename... Other>
|
||||
auto hess(rvar<T, order_1> &f, First first, Other... other)
|
||||
template<typename RealType, size_t DerivativeOrder_1, typename First, typename... Other>
|
||||
auto hess(rvar<RealType, DerivativeOrder_1> &f, First first, Other... other)
|
||||
{
|
||||
constexpr size_t order_2 = detail::rvar_order<First>::value;
|
||||
std::vector<rvar<T, order_2> *> x_vec = {first, other...};
|
||||
constexpr size_t DerivativeOrder_2 = detail::rvar_order<First>::value;
|
||||
std::vector<rvar<RealType, DerivativeOrder_2> *> x_vec = {first, other...};
|
||||
return hess(f, x_vec);
|
||||
}
|
||||
|
||||
@@ -731,13 +745,15 @@ auto hess(rvar<T, order_1> &f, First first, Other... other)
|
||||
*
|
||||
* NOT recursively safe, cannot do grad_nd(grad_nd(... etc...
|
||||
*/
|
||||
template<size_t N, typename T, size_t order_1, size_t order_2>
|
||||
auto grad_nd(rvar<T, order_1> &f, std::vector<rvar<T, order_2> *> &x)
|
||||
template<size_t N, typename RealType, size_t DerivativeOrder_1, size_t DerivativeOrder_2>
|
||||
auto grad_nd(rvar<RealType, DerivativeOrder_1> &f,
|
||||
std::vector<rvar<RealType, DerivativeOrder_2> *> &x)
|
||||
{
|
||||
static_assert(order_1 >= N, "Function order must be at least N");
|
||||
static_assert(order_2 >= order_1, "Variable order must be at least function order");
|
||||
static_assert(DerivativeOrder_1 >= N, "Function order must be at least N");
|
||||
static_assert(DerivativeOrder_2 >= DerivativeOrder_1,
|
||||
"Variable order must be at least function order");
|
||||
|
||||
return detail::grad_nd_impl<N, T, order_1, order_2>()(f, x);
|
||||
return detail::grad_nd_impl<N, RealType, DerivativeOrder_1, DerivativeOrder_2>()(f, x);
|
||||
}
|
||||
|
||||
/** @brief variadic overload of above
|
||||
@@ -745,11 +761,11 @@ auto grad_nd(rvar<T, order_1> &f, std::vector<rvar<T, order_2> *> &x)
|
||||
template<size_t N, typename ftype, typename First, typename... Other>
|
||||
auto grad_nd(ftype &f, First first, Other... other)
|
||||
{
|
||||
using T = typename ftype::value_type;
|
||||
constexpr size_t order_1 = detail::rvar_order<ftype *>::value;
|
||||
constexpr size_t order_2 = detail::rvar_order<First>::value;
|
||||
std::vector<rvar<T, order_2> *> x_vec = {first, other...};
|
||||
return detail::grad_nd_impl<N, T, order_1, order_1>{}(f, x_vec);
|
||||
using RealType = typename ftype::value_type;
|
||||
constexpr size_t DerivativeOrder_1 = detail::rvar_order<ftype *>::value;
|
||||
constexpr size_t DerivativeOrder_2 = detail::rvar_order<First>::value;
|
||||
std::vector<rvar<RealType, DerivativeOrder_2> *> x_vec = {first, other...};
|
||||
return detail::grad_nd_impl<N, RealType, DerivativeOrder_1, DerivativeOrder_1>{}(f, x_vec);
|
||||
}
|
||||
} // namespace reverse_mode
|
||||
} // namespace differentiation
|
||||
@@ -758,10 +774,10 @@ auto grad_nd(ftype &f, First first, Other... other)
|
||||
namespace std {
|
||||
|
||||
// copied from forward mode
|
||||
template<typename T, size_t order>
|
||||
class numeric_limits<boost::math::differentiation::reverse_mode::rvar<T, order>>
|
||||
: public numeric_limits<
|
||||
typename boost::math::differentiation::reverse_mode::rvar<T, order>::value_type>
|
||||
template<typename RealType, size_t DerivativeOrder>
|
||||
class numeric_limits<boost::math::differentiation::reverse_mode::rvar<RealType, DerivativeOrder>>
|
||||
: public numeric_limits<typename boost::math::differentiation::reverse_mode::
|
||||
rvar<RealType, DerivativeOrder>::value_type>
|
||||
{};
|
||||
} // namespace std
|
||||
#endif
|
||||
|
||||
@@ -12,19 +12,26 @@ namespace math {
|
||||
namespace differentiation {
|
||||
namespace reverse_mode {
|
||||
/****************************************************************************************************************/
|
||||
template<typename T, size_t order, typename LHS, typename RHS>
|
||||
struct add_expr
|
||||
: public abstract_binary_expression<T, order, LHS, RHS, add_expr<T, order, LHS, RHS>>
|
||||
template<typename RealType, size_t DerivativeOrder, typename LHS, typename RHS>
|
||||
struct add_expr : public abstract_binary_expression<RealType,
|
||||
DerivativeOrder,
|
||||
LHS,
|
||||
RHS,
|
||||
add_expr<RealType, DerivativeOrder, LHS, RHS>>
|
||||
{
|
||||
/* @brief addition
|
||||
* rvar+rvar
|
||||
* */
|
||||
using inner_t = rvar_t<T, order - 1>;
|
||||
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
|
||||
// Explicitly define constructor to forward to base class
|
||||
explicit add_expr(const expression<T, order, LHS> &left_hand_expr,
|
||||
const expression<T, order, RHS> &right_hand_expr)
|
||||
: abstract_binary_expression<T, order, LHS, RHS, add_expr<T, order, LHS, RHS>>(
|
||||
left_hand_expr, right_hand_expr)
|
||||
explicit add_expr(const expression<RealType, DerivativeOrder, LHS> &left_hand_expr,
|
||||
const expression<RealType, DerivativeOrder, RHS> &right_hand_expr)
|
||||
: abstract_binary_expression<RealType,
|
||||
DerivativeOrder,
|
||||
LHS,
|
||||
RHS,
|
||||
add_expr<RealType, DerivativeOrder, LHS, RHS>>(left_hand_expr,
|
||||
right_hand_expr)
|
||||
{}
|
||||
|
||||
inner_t evaluate() const { return this->lhs.evaluate() + this->rhs.evaluate(); }
|
||||
@@ -41,37 +48,51 @@ struct add_expr
|
||||
return inner_t(1.0);
|
||||
}
|
||||
};
|
||||
template<typename T, size_t order, typename ARG>
|
||||
template<typename RealType, size_t DerivativeOrder, typename ARG>
|
||||
struct add_const_expr
|
||||
: public abstract_unary_expression<T, order, ARG, add_const_expr<T, order, ARG>>
|
||||
: public abstract_unary_expression<RealType,
|
||||
DerivativeOrder,
|
||||
ARG,
|
||||
add_const_expr<RealType, DerivativeOrder, ARG>>
|
||||
{
|
||||
/* @brief
|
||||
* rvar+float or float+rvar
|
||||
* */
|
||||
using inner_t = rvar_t<T, order - 1>;
|
||||
explicit add_const_expr(const expression<T, order, ARG> &arg_expr, const T v)
|
||||
: abstract_unary_expression<T, order, ARG, add_const_expr<T, order, ARG>>(arg_expr, v){};
|
||||
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
|
||||
explicit add_const_expr(const expression<RealType, DerivativeOrder, ARG> &arg_expr,
|
||||
const RealType v)
|
||||
: abstract_unary_expression<RealType,
|
||||
DerivativeOrder,
|
||||
ARG,
|
||||
add_const_expr<RealType, DerivativeOrder, ARG>>(arg_expr, v){};
|
||||
inner_t evaluate() const { return this->arg.evaluate() + inner_t(this->constant); }
|
||||
static const inner_t derivative(const inner_t & /*argv*/,
|
||||
const inner_t & /*v*/,
|
||||
const T & /*constant*/)
|
||||
const RealType & /*constant*/)
|
||||
{
|
||||
return inner_t(1.0);
|
||||
}
|
||||
};
|
||||
/****************************************************************************************************************/
|
||||
template<typename T, size_t order, typename LHS, typename RHS>
|
||||
struct mult_expr
|
||||
: public abstract_binary_expression<T, order, LHS, RHS, mult_expr<T, order, LHS, RHS>>
|
||||
template<typename RealType, size_t DerivativeOrder, typename LHS, typename RHS>
|
||||
struct mult_expr : public abstract_binary_expression<RealType,
|
||||
DerivativeOrder,
|
||||
LHS,
|
||||
RHS,
|
||||
mult_expr<RealType, DerivativeOrder, LHS, RHS>>
|
||||
{
|
||||
/* @brief multiplication
|
||||
* rvar * rvar
|
||||
* */
|
||||
using inner_t = rvar_t<T, order - 1>;
|
||||
explicit mult_expr(const expression<T, order, LHS> &left_hand_expr,
|
||||
const expression<T, order, RHS> &right_hand_expr)
|
||||
: abstract_binary_expression<T, order, LHS, RHS, mult_expr<T, order, LHS, RHS>>(
|
||||
left_hand_expr, right_hand_expr)
|
||||
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
|
||||
explicit mult_expr(const expression<RealType, DerivativeOrder, LHS> &left_hand_expr,
|
||||
const expression<RealType, DerivativeOrder, RHS> &right_hand_expr)
|
||||
: abstract_binary_expression<RealType,
|
||||
DerivativeOrder,
|
||||
LHS,
|
||||
RHS,
|
||||
mult_expr<RealType, DerivativeOrder, LHS, RHS>>(left_hand_expr,
|
||||
right_hand_expr)
|
||||
{}
|
||||
|
||||
inner_t evaluate() const { return this->lhs.evaluate() * this->rhs.evaluate(); };
|
||||
@@ -88,40 +109,54 @@ struct mult_expr
|
||||
return l;
|
||||
};
|
||||
};
|
||||
template<typename T, size_t order, typename ARG>
|
||||
template<typename RealType, size_t DerivativeOrder, typename ARG>
|
||||
struct mult_const_expr
|
||||
: public abstract_unary_expression<T, order, ARG, mult_const_expr<T, order, ARG>>
|
||||
: public abstract_unary_expression<RealType,
|
||||
DerivativeOrder,
|
||||
ARG,
|
||||
mult_const_expr<RealType, DerivativeOrder, ARG>>
|
||||
{
|
||||
/* @brief
|
||||
* rvar+float or float+rvar
|
||||
* */
|
||||
using inner_t = rvar_t<T, order - 1>;
|
||||
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
|
||||
|
||||
explicit mult_const_expr(const expression<T, order, ARG> &arg_expr, const T v)
|
||||
: abstract_unary_expression<T, order, ARG, mult_const_expr<T, order, ARG>>(arg_expr, v){};
|
||||
explicit mult_const_expr(const expression<RealType, DerivativeOrder, ARG> &arg_expr,
|
||||
const RealType v)
|
||||
: abstract_unary_expression<RealType,
|
||||
DerivativeOrder,
|
||||
ARG,
|
||||
mult_const_expr<RealType, DerivativeOrder, ARG>>(arg_expr, v){};
|
||||
|
||||
inner_t evaluate() const { return this->arg.evaluate() * inner_t(this->constant); }
|
||||
static const inner_t derivative(const inner_t & /*argv*/,
|
||||
const inner_t & /*v*/,
|
||||
const T &constant)
|
||||
const RealType &constant)
|
||||
{
|
||||
return inner_t(constant);
|
||||
}
|
||||
};
|
||||
/****************************************************************************************************************/
|
||||
template<typename T, size_t order, typename LHS, typename RHS>
|
||||
struct sub_expr
|
||||
: public abstract_binary_expression<T, order, LHS, RHS, sub_expr<T, order, LHS, RHS>>
|
||||
template<typename RealType, size_t DerivativeOrder, typename LHS, typename RHS>
|
||||
struct sub_expr : public abstract_binary_expression<RealType,
|
||||
DerivativeOrder,
|
||||
LHS,
|
||||
RHS,
|
||||
sub_expr<RealType, DerivativeOrder, LHS, RHS>>
|
||||
{
|
||||
/* @brief addition
|
||||
* rvar-rvar
|
||||
* */
|
||||
using inner_t = rvar_t<T, order - 1>;
|
||||
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
|
||||
// Explicitly define constructor to forward to base class
|
||||
explicit sub_expr(const expression<T, order, LHS> &left_hand_expr,
|
||||
const expression<T, order, RHS> &right_hand_expr)
|
||||
: abstract_binary_expression<T, order, LHS, RHS, sub_expr<T, order, LHS, RHS>>(
|
||||
left_hand_expr, right_hand_expr)
|
||||
explicit sub_expr(const expression<RealType, DerivativeOrder, LHS> &left_hand_expr,
|
||||
const expression<RealType, DerivativeOrder, RHS> &right_hand_expr)
|
||||
: abstract_binary_expression<RealType,
|
||||
DerivativeOrder,
|
||||
LHS,
|
||||
RHS,
|
||||
sub_expr<RealType, DerivativeOrder, LHS, RHS>>(left_hand_expr,
|
||||
right_hand_expr)
|
||||
{}
|
||||
|
||||
inner_t evaluate() const { return this->lhs.evaluate() - this->rhs.evaluate(); }
|
||||
@@ -140,19 +175,26 @@ struct sub_expr
|
||||
};
|
||||
|
||||
/****************************************************************************************************************/
|
||||
template<typename T, size_t order, typename LHS, typename RHS>
|
||||
struct div_expr
|
||||
: public abstract_binary_expression<T, order, LHS, RHS, div_expr<T, order, LHS, RHS>>
|
||||
template<typename RealType, size_t DerivativeOrder, typename LHS, typename RHS>
|
||||
struct div_expr : public abstract_binary_expression<RealType,
|
||||
DerivativeOrder,
|
||||
LHS,
|
||||
RHS,
|
||||
div_expr<RealType, DerivativeOrder, LHS, RHS>>
|
||||
{
|
||||
/* @brief multiplication
|
||||
* rvar / rvar
|
||||
* */
|
||||
using inner_t = rvar_t<T, order - 1>;
|
||||
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
|
||||
// Explicitly define constructor to forward to base class
|
||||
explicit div_expr(const expression<T, order, LHS> &left_hand_expr,
|
||||
const expression<T, order, RHS> &right_hand_expr)
|
||||
: abstract_binary_expression<T, order, LHS, RHS, div_expr<T, order, LHS, RHS>>(
|
||||
left_hand_expr, right_hand_expr)
|
||||
explicit div_expr(const expression<RealType, DerivativeOrder, LHS> &left_hand_expr,
|
||||
const expression<RealType, DerivativeOrder, RHS> &right_hand_expr)
|
||||
: abstract_binary_expression<RealType,
|
||||
DerivativeOrder,
|
||||
LHS,
|
||||
RHS,
|
||||
div_expr<RealType, DerivativeOrder, LHS, RHS>>(left_hand_expr,
|
||||
right_hand_expr)
|
||||
{}
|
||||
|
||||
inner_t evaluate() const { return this->lhs.evaluate() / this->rhs.evaluate(); };
|
||||
@@ -160,182 +202,212 @@ struct div_expr
|
||||
const inner_t &r,
|
||||
const inner_t & /*v*/)
|
||||
{
|
||||
return static_cast<T>(1.0) / r;
|
||||
return static_cast<RealType>(1.0) / r;
|
||||
};
|
||||
static const inner_t right_derivative(const inner_t &l, const inner_t &r, const inner_t & /*v*/)
|
||||
{
|
||||
return -l / (r * r);
|
||||
};
|
||||
};
|
||||
template<typename T, size_t order, typename ARG>
|
||||
template<typename RealType, size_t DerivativeOrder, typename ARG>
|
||||
struct div_by_const_expr
|
||||
: public abstract_unary_expression<T, order, ARG, div_by_const_expr<T, order, ARG>>
|
||||
: public abstract_unary_expression<RealType,
|
||||
DerivativeOrder,
|
||||
ARG,
|
||||
div_by_const_expr<RealType, DerivativeOrder, ARG>>
|
||||
{
|
||||
/* @brief
|
||||
* rvar/float
|
||||
* */
|
||||
using inner_t = rvar_t<T, order - 1>;
|
||||
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
|
||||
|
||||
explicit div_by_const_expr(const expression<T, order, ARG> &arg_expr, const T v)
|
||||
: abstract_unary_expression<T, order, ARG, div_by_const_expr<T, order, ARG>>(arg_expr, v){};
|
||||
explicit div_by_const_expr(const expression<RealType, DerivativeOrder, ARG> &arg_expr,
|
||||
const RealType v)
|
||||
: abstract_unary_expression<RealType,
|
||||
DerivativeOrder,
|
||||
ARG,
|
||||
div_by_const_expr<RealType, DerivativeOrder, ARG>>(arg_expr,
|
||||
v){};
|
||||
|
||||
inner_t evaluate() const { return this->arg.evaluate() / inner_t(this->constant); }
|
||||
static const inner_t derivative(const inner_t & /*argv*/,
|
||||
const inner_t & /*v*/,
|
||||
const T &constant)
|
||||
const RealType &constant)
|
||||
{
|
||||
return inner_t(1.0 / constant);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, size_t order, typename ARG>
|
||||
template<typename RealType, size_t DerivativeOrder, typename ARG>
|
||||
struct const_div_by_expr
|
||||
: public abstract_unary_expression<T, order, ARG, const_div_by_expr<T, order, ARG>>
|
||||
: public abstract_unary_expression<RealType,
|
||||
DerivativeOrder,
|
||||
ARG,
|
||||
const_div_by_expr<RealType, DerivativeOrder, ARG>>
|
||||
{
|
||||
/** @brief
|
||||
* float/rvar
|
||||
* */
|
||||
using inner_t = rvar_t<T, order - 1>;
|
||||
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
|
||||
|
||||
explicit const_div_by_expr(const expression<T, order, ARG> &arg_expr, const T v)
|
||||
: abstract_unary_expression<T, order, ARG, const_div_by_expr<T, order, ARG>>(arg_expr, v){};
|
||||
explicit const_div_by_expr(const expression<RealType, DerivativeOrder, ARG> &arg_expr,
|
||||
const RealType v)
|
||||
: abstract_unary_expression<RealType,
|
||||
DerivativeOrder,
|
||||
ARG,
|
||||
const_div_by_expr<RealType, DerivativeOrder, ARG>>(arg_expr,
|
||||
v){};
|
||||
|
||||
inner_t evaluate() const { return inner_t(this->constant) / this->arg.evaluate(); }
|
||||
static const inner_t derivative(const inner_t &argv, const inner_t & /*v*/, const T &constant)
|
||||
static const inner_t derivative(const inner_t &argv,
|
||||
const inner_t & /*v*/,
|
||||
const RealType &constant)
|
||||
{
|
||||
return -inner_t{constant} / (argv * argv);
|
||||
}
|
||||
};
|
||||
/****************************************************************************************************************/
|
||||
|
||||
template<typename T, size_t order, typename LHS, typename RHS>
|
||||
mult_expr<T, order, LHS, RHS> operator*(const expression<T, order, LHS> &lhs,
|
||||
const expression<T, order, RHS> &rhs)
|
||||
template<typename RealType, size_t DerivativeOrder, typename LHS, typename RHS>
|
||||
mult_expr<RealType, DerivativeOrder, LHS, RHS> operator*(
|
||||
const expression<RealType, DerivativeOrder, LHS> &lhs,
|
||||
const expression<RealType, DerivativeOrder, RHS> &rhs)
|
||||
{
|
||||
return mult_expr<T, order, LHS, RHS>(lhs, rhs);
|
||||
return mult_expr<RealType, DerivativeOrder, LHS, RHS>(lhs, rhs);
|
||||
}
|
||||
|
||||
/** @brief type promotion is handled by casting the numeric type to
|
||||
* the type inside expression. This is to avoid converting the
|
||||
* entire tape in case you have something like double * rvar<float>
|
||||
* */
|
||||
template<typename U,
|
||||
typename T,
|
||||
size_t order,
|
||||
template<typename RealType2,
|
||||
typename RealType1,
|
||||
size_t DerivativeOrder,
|
||||
typename ARG,
|
||||
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
|
||||
mult_const_expr<T, order, ARG> operator*(const expression<T, order, ARG> &arg, const U &v)
|
||||
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
|
||||
mult_const_expr<RealType1, DerivativeOrder, ARG> operator*(
|
||||
const expression<RealType1, DerivativeOrder, ARG> &arg, const RealType2 &v)
|
||||
{
|
||||
return mult_const_expr<T, order, ARG>(arg, static_cast<T>(v));
|
||||
return mult_const_expr<RealType1, DerivativeOrder, ARG>(arg, static_cast<RealType1>(v));
|
||||
}
|
||||
template<typename U,
|
||||
typename T,
|
||||
size_t order,
|
||||
template<typename RealType2,
|
||||
typename RealType1,
|
||||
size_t DerivativeOrder,
|
||||
typename ARG,
|
||||
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
|
||||
mult_const_expr<T, order, ARG> operator*(const U &v, const expression<T, order, ARG> &arg)
|
||||
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
|
||||
mult_const_expr<RealType1, DerivativeOrder, ARG> operator*(
|
||||
const RealType2 &v, const expression<RealType1, DerivativeOrder, ARG> &arg)
|
||||
{
|
||||
return mult_const_expr<T, order, ARG>(arg, static_cast<T>(v));
|
||||
return mult_const_expr<RealType1, DerivativeOrder, ARG>(arg, static_cast<RealType1>(v));
|
||||
}
|
||||
/****************************************************************************************************************/
|
||||
/* + */
|
||||
template<typename T, size_t order, typename LHS, typename RHS>
|
||||
add_expr<T, order, LHS, RHS> operator+(const expression<T, order, LHS> &lhs,
|
||||
const expression<T, order, RHS> &rhs)
|
||||
template<typename RealType, size_t DerivativeOrder, typename LHS, typename RHS>
|
||||
add_expr<RealType, DerivativeOrder, LHS, RHS> operator+(
|
||||
const expression<RealType, DerivativeOrder, LHS> &lhs,
|
||||
const expression<RealType, DerivativeOrder, RHS> &rhs)
|
||||
{
|
||||
return add_expr<T, order, LHS, RHS>(lhs, rhs);
|
||||
return add_expr<RealType, DerivativeOrder, LHS, RHS>(lhs, rhs);
|
||||
}
|
||||
template<typename U,
|
||||
typename T,
|
||||
size_t order,
|
||||
template<typename RealType2,
|
||||
typename RealType1,
|
||||
size_t DerivativeOrder,
|
||||
typename ARG,
|
||||
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
|
||||
add_const_expr<T, order, ARG> operator+(const expression<T, order, ARG> &arg, const U &v)
|
||||
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
|
||||
add_const_expr<RealType1, DerivativeOrder, ARG> operator+(
|
||||
const expression<RealType1, DerivativeOrder, ARG> &arg, const RealType2 &v)
|
||||
{
|
||||
return add_const_expr<T, order, ARG>(arg, static_cast<T>(v));
|
||||
return add_const_expr<RealType1, DerivativeOrder, ARG>(arg, static_cast<RealType1>(v));
|
||||
}
|
||||
template<typename U,
|
||||
typename T,
|
||||
size_t order,
|
||||
template<typename RealType2,
|
||||
typename RealType1,
|
||||
size_t DerivativeOrder,
|
||||
typename ARG,
|
||||
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
|
||||
add_const_expr<T, order, ARG> operator+(const U &v, const expression<T, order, ARG> &arg)
|
||||
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
|
||||
add_const_expr<RealType1, DerivativeOrder, ARG> operator+(
|
||||
const RealType2 &v, const expression<RealType1, DerivativeOrder, ARG> &arg)
|
||||
{
|
||||
return add_const_expr<T, order, ARG>(arg, static_cast<T>(v));
|
||||
return add_const_expr<RealType1, DerivativeOrder, ARG>(arg, static_cast<RealType1>(v));
|
||||
}
|
||||
/****************************************************************************************************************/
|
||||
/* - overload */
|
||||
/** @brief
|
||||
* negation (-1.0*rvar) */
|
||||
template<typename T, size_t order, typename ARG>
|
||||
mult_const_expr<T, order, ARG> operator-(const expression<T, order, ARG> &arg)
|
||||
template<typename RealType, size_t DerivativeOrder, typename ARG>
|
||||
mult_const_expr<RealType, DerivativeOrder, ARG> operator-(
|
||||
const expression<RealType, DerivativeOrder, ARG> &arg)
|
||||
{
|
||||
return mult_const_expr<T, order, ARG>(arg, static_cast<T>(-1.0));
|
||||
return mult_const_expr<RealType, DerivativeOrder, ARG>(arg, static_cast<RealType>(-1.0));
|
||||
}
|
||||
|
||||
/** @brief
|
||||
* subtraction rvar-rvar */
|
||||
template<typename T, size_t order, typename LHS, typename RHS>
|
||||
sub_expr<T, order, LHS, RHS> operator-(const expression<T, order, LHS> &lhs,
|
||||
const expression<T, order, RHS> &rhs)
|
||||
template<typename RealType, size_t DerivativeOrder, typename LHS, typename RHS>
|
||||
sub_expr<RealType, DerivativeOrder, LHS, RHS> operator-(
|
||||
const expression<RealType, DerivativeOrder, LHS> &lhs,
|
||||
const expression<RealType, DerivativeOrder, RHS> &rhs)
|
||||
{
|
||||
return sub_expr<T, order, LHS, RHS>(lhs, rhs);
|
||||
return sub_expr<RealType, DerivativeOrder, LHS, RHS>(lhs, rhs);
|
||||
}
|
||||
|
||||
/** @brief
|
||||
* subtraction float - rvar */
|
||||
template<typename U,
|
||||
typename T,
|
||||
size_t order,
|
||||
template<typename RealType2,
|
||||
typename RealType1,
|
||||
size_t DerivativeOrder,
|
||||
typename ARG,
|
||||
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
|
||||
add_const_expr<T, order, ARG> operator-(const expression<T, order, ARG> &arg, const U &v)
|
||||
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
|
||||
add_const_expr<RealType1, DerivativeOrder, ARG> operator-(
|
||||
const expression<RealType1, DerivativeOrder, ARG> &arg, const RealType2 &v)
|
||||
{
|
||||
/* rvar - float = rvar + (-float) */
|
||||
return add_const_expr<T, order, ARG>(arg, static_cast<T>(-v));
|
||||
return add_const_expr<RealType1, DerivativeOrder, ARG>(arg, static_cast<RealType1>(-v));
|
||||
}
|
||||
|
||||
/** @brief
|
||||
* subtraction float - rvar
|
||||
* @return add_expr<neg_expr<ARG>>
|
||||
*/
|
||||
template<typename U,
|
||||
typename T,
|
||||
size_t order,
|
||||
template<typename RealType2,
|
||||
typename RealType1,
|
||||
size_t DerivativeOrder,
|
||||
typename ARG,
|
||||
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
|
||||
auto operator-(const U &v, const expression<T, order, ARG> &arg)
|
||||
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
|
||||
auto operator-(const RealType2 &v, const expression<RealType1, DerivativeOrder, ARG> &arg)
|
||||
{
|
||||
auto neg = -arg;
|
||||
return neg + static_cast<T>(v);
|
||||
return neg + static_cast<RealType1>(v);
|
||||
}
|
||||
/****************************************************************************************************************/
|
||||
/* / */
|
||||
template<typename T, size_t order, typename LHS, typename RHS>
|
||||
div_expr<T, order, LHS, RHS> operator/(const expression<T, order, LHS> &lhs,
|
||||
const expression<T, order, RHS> &rhs)
|
||||
template<typename RealType, size_t DerivativeOrder, typename LHS, typename RHS>
|
||||
div_expr<RealType, DerivativeOrder, LHS, RHS> operator/(
|
||||
const expression<RealType, DerivativeOrder, LHS> &lhs,
|
||||
const expression<RealType, DerivativeOrder, RHS> &rhs)
|
||||
{
|
||||
return div_expr<T, order, LHS, RHS>(lhs, rhs);
|
||||
return div_expr<RealType, DerivativeOrder, LHS, RHS>(lhs, rhs);
|
||||
}
|
||||
|
||||
template<typename U,
|
||||
typename T,
|
||||
size_t order,
|
||||
template<typename RealType2,
|
||||
typename RealType1,
|
||||
size_t DerivativeOrder,
|
||||
typename ARG,
|
||||
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
|
||||
const_div_by_expr<T, order, ARG> operator/(const U &v, const expression<T, order, ARG> &arg)
|
||||
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
|
||||
const_div_by_expr<RealType1, DerivativeOrder, ARG> operator/(
|
||||
const RealType2 &v, const expression<RealType1, DerivativeOrder, ARG> &arg)
|
||||
{
|
||||
return const_div_by_expr<T, order, ARG>(arg, static_cast<T>(v));
|
||||
return const_div_by_expr<RealType1, DerivativeOrder, ARG>(arg, static_cast<RealType1>(v));
|
||||
}
|
||||
|
||||
template<typename U,
|
||||
typename T,
|
||||
size_t order,
|
||||
template<typename RealType2,
|
||||
typename RealType1,
|
||||
size_t DerivativeOrder,
|
||||
typename ARG,
|
||||
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
|
||||
div_by_const_expr<T, order, ARG> operator/(const expression<T, order, ARG> &arg, const U &v)
|
||||
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
|
||||
div_by_const_expr<RealType1, DerivativeOrder, ARG> operator/(
|
||||
const expression<RealType1, DerivativeOrder, ARG> &arg, const RealType2 &v)
|
||||
{
|
||||
return div_by_const_expr<T, order, ARG>(arg, static_cast<T>(v));
|
||||
return div_by_const_expr<RealType1, DerivativeOrder, ARG>(arg, static_cast<RealType1>(v));
|
||||
}
|
||||
|
||||
} // namespace reverse_mode
|
||||
|
||||
@@ -10,147 +10,153 @@ namespace boost {
|
||||
namespace math {
|
||||
namespace differentiation {
|
||||
namespace reverse_mode {
|
||||
template<typename T, size_t order_1, size_t order_2, class E, class F>
|
||||
bool operator==(const expression<T, order_1, E> &lhs, const expression<T, order_2, F> &rhs)
|
||||
template<typename RealType, size_t DerivativeOrder1, size_t DerivativeOrder2, class LhsExpr, class RhsExpr>
|
||||
bool operator==(const expression<RealType, DerivativeOrder1, LhsExpr> &lhs,
|
||||
const expression<RealType, DerivativeOrder2, RhsExpr> &rhs)
|
||||
{
|
||||
return lhs.evaluate() == rhs.evaluate();
|
||||
}
|
||||
|
||||
template<typename U,
|
||||
typename T,
|
||||
size_t order,
|
||||
class E,
|
||||
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
|
||||
bool operator==(const expression<T, order, E> &lhs, const U &rhs)
|
||||
template<typename RealType2,
|
||||
typename RealType1,
|
||||
size_t DerivativeOrder,
|
||||
class ArgExpr,
|
||||
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
|
||||
bool operator==(const expression<RealType1, DerivativeOrder, ArgExpr> &lhs, const RealType2 &rhs)
|
||||
{
|
||||
return lhs.evaluate() == static_cast<T>(rhs);
|
||||
return lhs.evaluate() == static_cast<RealType1>(rhs);
|
||||
}
|
||||
template<typename U,
|
||||
typename T,
|
||||
size_t order,
|
||||
class E,
|
||||
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
|
||||
bool operator==(const U &lhs, const expression<T, order, E> &rhs)
|
||||
template<typename RealType2,
|
||||
typename RealType1,
|
||||
size_t DerivativeOrder,
|
||||
class ArgExpr,
|
||||
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
|
||||
bool operator==(const RealType2 &lhs, const expression<RealType1, DerivativeOrder, ArgExpr> &rhs)
|
||||
{
|
||||
return lhs == rhs.evaluate();
|
||||
}
|
||||
|
||||
template<typename T, size_t order_1, size_t order_2, class E, class F>
|
||||
bool operator!=(const expression<T, order_1, E> &lhs, const expression<T, order_2, F> &rhs)
|
||||
template<typename RealType, size_t DerivativeOrder1, size_t DerivativeOrder2, class LhsExpr, class RhsExpr>
|
||||
bool operator!=(const expression<RealType, DerivativeOrder1, LhsExpr> &lhs,
|
||||
const expression<RealType, DerivativeOrder2, RhsExpr> &rhs)
|
||||
{
|
||||
return lhs.evaluate() != rhs.evaluate();
|
||||
}
|
||||
template<typename U,
|
||||
typename T,
|
||||
size_t order,
|
||||
class E,
|
||||
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
|
||||
bool operator!=(const expression<T, order, E> &lhs, const U &rhs)
|
||||
template<typename RealType2,
|
||||
typename RealType1,
|
||||
size_t DerivativeOrder,
|
||||
class ArgExpr,
|
||||
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
|
||||
bool operator!=(const expression<RealType1, DerivativeOrder, ArgExpr> &lhs, const RealType2 &rhs)
|
||||
{
|
||||
return lhs.evaluate() != rhs;
|
||||
}
|
||||
template<typename U,
|
||||
typename T,
|
||||
size_t order,
|
||||
class E,
|
||||
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
|
||||
bool operator!=(const U &lhs, const expression<T, order, E> &rhs)
|
||||
template<typename RealType2,
|
||||
typename RealType1,
|
||||
size_t DerivativeOrder,
|
||||
class ArgExpr,
|
||||
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
|
||||
bool operator!=(const RealType2 &lhs, const expression<RealType1, DerivativeOrder, ArgExpr> &rhs)
|
||||
{
|
||||
return lhs != rhs.evaluate();
|
||||
}
|
||||
|
||||
template<typename T, size_t order_1, size_t order_2, class E, class F>
|
||||
bool operator<(const expression<T, order_1, E> &lhs, const expression<T, order_2, F> &rhs)
|
||||
template<typename RealType, size_t DerivativeOrder1, size_t DerivativeOrder2, class LhsExpr, class RhsExpr>
|
||||
bool operator<(const expression<RealType, DerivativeOrder1, LhsExpr> &lhs,
|
||||
const expression<RealType, DerivativeOrder2, RhsExpr> &rhs)
|
||||
{
|
||||
return lhs.evaluate() < rhs.evaluate();
|
||||
}
|
||||
template<typename U,
|
||||
typename T,
|
||||
size_t order,
|
||||
class E,
|
||||
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
|
||||
bool operator<(const expression<T, order, E> &lhs, const U &rhs)
|
||||
template<typename RealType2,
|
||||
typename RealType1,
|
||||
size_t DerivativeOrder,
|
||||
class ArgExpr,
|
||||
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
|
||||
bool operator<(const expression<RealType1, DerivativeOrder, ArgExpr> &lhs, const RealType2 &rhs)
|
||||
{
|
||||
return lhs.evaluate() < rhs;
|
||||
}
|
||||
template<typename U,
|
||||
typename T,
|
||||
size_t order,
|
||||
class E,
|
||||
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
|
||||
bool operator<(const U &lhs, const expression<T, order, E> &rhs)
|
||||
template<typename RealType2,
|
||||
typename RealType1,
|
||||
size_t DerivativeOrder,
|
||||
class ArgExpr,
|
||||
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
|
||||
bool operator<(const RealType2 &lhs, const expression<RealType1, DerivativeOrder, ArgExpr> &rhs)
|
||||
{
|
||||
return lhs < rhs.evaluate();
|
||||
}
|
||||
|
||||
template<typename T, size_t order_1, size_t order_2, class E, class F>
|
||||
bool operator>(const expression<T, order_1, E> &lhs, const expression<T, order_2, F> &rhs)
|
||||
template<typename RealType, size_t DerivativeOrder1, size_t DerivativeOrder2, class LhsExpr, class RhsExpr>
|
||||
bool operator>(const expression<RealType, DerivativeOrder1, LhsExpr> &lhs,
|
||||
const expression<RealType, DerivativeOrder2, RhsExpr> &rhs)
|
||||
{
|
||||
return lhs.evaluate() > rhs.evaluate();
|
||||
}
|
||||
template<typename U,
|
||||
typename T,
|
||||
size_t order,
|
||||
class E,
|
||||
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
|
||||
bool operator>(const expression<T, order, E> &lhs, const U &rhs)
|
||||
template<typename RealType2,
|
||||
typename RealType1,
|
||||
size_t DerivativeOrder,
|
||||
class ArgExpr,
|
||||
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
|
||||
bool operator>(const expression<RealType1, DerivativeOrder, ArgExpr> &lhs, const RealType2 &rhs)
|
||||
{
|
||||
return lhs.evaluate() > rhs;
|
||||
}
|
||||
template<typename U,
|
||||
typename T,
|
||||
size_t order,
|
||||
class E,
|
||||
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
|
||||
bool operator>(const U &lhs, const expression<T, order, E> &rhs)
|
||||
template<typename RealType2,
|
||||
typename RealType1,
|
||||
size_t DerivativeOrder,
|
||||
class ArgExpr,
|
||||
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
|
||||
bool operator>(const RealType2 &lhs, const expression<RealType1, DerivativeOrder, ArgExpr> &rhs)
|
||||
{
|
||||
return lhs > rhs.evaluate();
|
||||
}
|
||||
|
||||
template<typename T, size_t order_1, size_t order_2, class E, class F>
|
||||
bool operator<=(const expression<T, order_1, E> &lhs, const expression<T, order_2, F> &rhs)
|
||||
template<typename RealType, size_t DerivativeOrder1, size_t DerivativeOrder2, class LhsExpr, class RhsExpr>
|
||||
bool operator<=(const expression<RealType, DerivativeOrder1, LhsExpr> &lhs,
|
||||
const expression<RealType, DerivativeOrder2, RhsExpr> &rhs)
|
||||
{
|
||||
return lhs.evaluate() <= rhs.evaluate();
|
||||
}
|
||||
template<typename U,
|
||||
typename T,
|
||||
size_t order,
|
||||
class E,
|
||||
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
|
||||
bool operator<=(const expression<T, order, E> &lhs, const U &rhs)
|
||||
template<typename RealType2,
|
||||
typename RealType1,
|
||||
size_t DerivativeOrder,
|
||||
class ArgExpr,
|
||||
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
|
||||
bool operator<=(const expression<RealType1, DerivativeOrder, ArgExpr> &lhs, const RealType2 &rhs)
|
||||
{
|
||||
return lhs.evaluate() <= rhs;
|
||||
}
|
||||
template<typename U,
|
||||
typename T,
|
||||
size_t order,
|
||||
class E,
|
||||
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
|
||||
bool operator<=(const U &lhs, const expression<T, order, E> &rhs)
|
||||
template<typename RealType2,
|
||||
typename RealType1,
|
||||
size_t DerivativeOrder,
|
||||
class ArgExpr,
|
||||
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
|
||||
bool operator<=(const RealType2 &lhs, const expression<RealType1, DerivativeOrder, ArgExpr> &rhs)
|
||||
{
|
||||
return lhs <= rhs.evaluate();
|
||||
}
|
||||
|
||||
template<typename T, size_t order_1, size_t order_2, class E, class F>
|
||||
bool operator>=(const expression<T, order_1, E> &lhs, const expression<T, order_2, F> &rhs)
|
||||
template<typename RealType, size_t DerivativeOrder1, size_t DerivativeOrder2, class LhsExpr, class RhsExpr>
|
||||
bool operator>=(const expression<RealType, DerivativeOrder1, LhsExpr> &lhs,
|
||||
const expression<RealType, DerivativeOrder2, RhsExpr> &rhs)
|
||||
{
|
||||
return lhs.evaluate() >= rhs.evaluate();
|
||||
}
|
||||
template<typename U,
|
||||
typename T,
|
||||
size_t order,
|
||||
class E,
|
||||
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
|
||||
bool operator>=(const expression<T, order, E> &lhs, const U &rhs)
|
||||
template<typename RealType2,
|
||||
typename RealType1,
|
||||
size_t DerivativeOrder,
|
||||
class ArgExpr,
|
||||
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
|
||||
bool operator>=(const expression<RealType1, DerivativeOrder, ArgExpr> &lhs, const RealType2 &rhs)
|
||||
{
|
||||
return lhs.evaluate() >= rhs;
|
||||
}
|
||||
template<typename U,
|
||||
typename T,
|
||||
size_t order,
|
||||
class E,
|
||||
typename = typename std::enable_if<!detail::is_expression<U>::value>::type>
|
||||
bool operator>=(const U &lhs, const expression<T, order, E> &rhs)
|
||||
template<typename RealType2,
|
||||
typename RealType1,
|
||||
size_t DerivativeOrder,
|
||||
class ArgExpr,
|
||||
typename = typename std::enable_if<!detail::is_expression<RealType2>::value>::type>
|
||||
bool operator>=(const RealType2 &lhs, const expression<RealType1, DerivativeOrder, ArgExpr> &rhs)
|
||||
{
|
||||
return lhs >= rhs.evaluate();
|
||||
}
|
||||
|
||||
@@ -16,73 +16,84 @@ namespace math {
|
||||
namespace differentiation {
|
||||
namespace reverse_mode {
|
||||
|
||||
template<typename T, size_t order, typename ARG>
|
||||
template<typename RealType, size_t DerivativeOrder, typename ARG>
|
||||
struct erf_expr;
|
||||
|
||||
template<typename T, size_t order, typename ARG>
|
||||
template<typename RealType, size_t DerivativeOrder, typename ARG>
|
||||
struct erfc_expr;
|
||||
|
||||
template<typename T, size_t order, typename ARG>
|
||||
template<typename RealType, size_t DerivativeOrder, typename ARG>
|
||||
struct erf_inv_expr;
|
||||
|
||||
template<typename T, size_t order, typename ARG>
|
||||
template<typename RealType, size_t DerivativeOrder, typename ARG>
|
||||
struct erfc_inv_expr;
|
||||
|
||||
template<typename T, size_t order, typename ARG>
|
||||
erf_expr<T, order, ARG> erf(const expression<T, order, ARG> &arg)
|
||||
template<typename RealType, size_t DerivativeOrder, typename ARG>
|
||||
erf_expr<RealType, DerivativeOrder, ARG> erf(const expression<RealType, DerivativeOrder, ARG> &arg)
|
||||
{
|
||||
return erf_expr<T, order, ARG>(arg, 0.0);
|
||||
return erf_expr<RealType, DerivativeOrder, ARG>(arg, 0.0);
|
||||
}
|
||||
|
||||
template<typename T, size_t order, typename ARG>
|
||||
erfc_expr<T, order, ARG> erfc(const expression<T, order, ARG> &arg)
|
||||
template<typename RealType, size_t DerivativeOrder, typename ARG>
|
||||
erfc_expr<RealType, DerivativeOrder, ARG> erfc(const expression<RealType, DerivativeOrder, ARG> &arg)
|
||||
{
|
||||
return erfc_expr<T, order, ARG>(arg, 0.0);
|
||||
return erfc_expr<RealType, DerivativeOrder, ARG>(arg, 0.0);
|
||||
}
|
||||
|
||||
template<typename T, size_t order, typename ARG>
|
||||
erf_inv_expr<T, order, ARG> erf_inv(const expression<T, order, ARG> &arg)
|
||||
template<typename RealType, size_t DerivativeOrder, typename ARG>
|
||||
erf_inv_expr<RealType, DerivativeOrder, ARG> erf_inv(
|
||||
const expression<RealType, DerivativeOrder, ARG> &arg)
|
||||
{
|
||||
return erf_inv_expr<T, order, ARG>(arg, 0.0);
|
||||
return erf_inv_expr<RealType, DerivativeOrder, ARG>(arg, 0.0);
|
||||
}
|
||||
|
||||
template<typename T, size_t order, typename ARG>
|
||||
erfc_inv_expr<T, order, ARG> erfc_inv(const expression<T, order, ARG> &arg)
|
||||
template<typename RealType, size_t DerivativeOrder, typename ARG>
|
||||
erfc_inv_expr<RealType, DerivativeOrder, ARG> erfc_inv(
|
||||
const expression<RealType, DerivativeOrder, ARG> &arg)
|
||||
{
|
||||
return erfc_inv_expr<T, order, ARG>(arg, 0.0);
|
||||
return erfc_inv_expr<RealType, DerivativeOrder, ARG>(arg, 0.0);
|
||||
}
|
||||
|
||||
template<typename T, size_t order, typename ARG>
|
||||
struct erf_expr : public abstract_unary_expression<T, order, ARG, erf_expr<T, order, ARG>>
|
||||
template<typename RealType, size_t DerivativeOrder, typename ARG>
|
||||
struct erf_expr : public abstract_unary_expression<RealType,
|
||||
DerivativeOrder,
|
||||
ARG,
|
||||
erf_expr<RealType, DerivativeOrder, ARG>>
|
||||
{
|
||||
/** @brief erf(x)
|
||||
*
|
||||
* d/dx erf(x) = 2*exp(x^2)/sqrt(pi)
|
||||
*
|
||||
* */
|
||||
using inner_t = rvar_t<T, order - 1>;
|
||||
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
|
||||
|
||||
explicit erf_expr(const expression<T, order, ARG> &arg_expr, const T v)
|
||||
: abstract_unary_expression<T, order, ARG, erf_expr<T, order, ARG>>(arg_expr, v){};
|
||||
explicit erf_expr(const expression<RealType, DerivativeOrder, ARG> &arg_expr, const RealType v)
|
||||
: abstract_unary_expression<RealType,
|
||||
DerivativeOrder,
|
||||
ARG,
|
||||
erf_expr<RealType, DerivativeOrder, ARG>>(arg_expr, v){};
|
||||
|
||||
inner_t evaluate() const
|
||||
{
|
||||
return detail::if_functional_dispatch<(order > 1)>(
|
||||
return detail::if_functional_dispatch<(DerivativeOrder > 1)>(
|
||||
[this](auto &&x) { return reverse_mode::erf(std::forward<decltype(x)>(x)); },
|
||||
[this](auto &&x) { return boost::math::erf(std::forward<decltype(x)>(x)); },
|
||||
this->arg.evaluate());
|
||||
}
|
||||
static const inner_t derivative(const inner_t &argv,
|
||||
const inner_t & /*v*/,
|
||||
const T & /*constant*/)
|
||||
const RealType & /*constant*/)
|
||||
{
|
||||
BOOST_MATH_STD_USING
|
||||
return static_cast<T>(2.0) * exp(-argv * argv) / sqrt(constants::pi<T>());
|
||||
return static_cast<RealType>(2.0) * exp(-argv * argv) / sqrt(constants::pi<RealType>());
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, size_t order, typename ARG>
|
||||
struct erfc_expr : public abstract_unary_expression<T, order, ARG, erfc_expr<T, order, ARG>>
|
||||
template<typename RealType, size_t DerivativeOrder, typename ARG>
|
||||
struct erfc_expr : public abstract_unary_expression<RealType,
|
||||
DerivativeOrder,
|
||||
ARG,
|
||||
erfc_expr<RealType, DerivativeOrder, ARG>>
|
||||
{
|
||||
/** @brief erfc(x)
|
||||
*
|
||||
@@ -90,29 +101,35 @@ struct erfc_expr : public abstract_unary_expression<T, order, ARG, erfc_expr<T,
|
||||
*
|
||||
* */
|
||||
|
||||
using inner_t = rvar_t<T, order - 1>;
|
||||
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
|
||||
|
||||
explicit erfc_expr(const expression<T, order, ARG> &arg_expr, const T v)
|
||||
: abstract_unary_expression<T, order, ARG, erfc_expr<T, order, ARG>>(arg_expr, v){};
|
||||
explicit erfc_expr(const expression<RealType, DerivativeOrder, ARG> &arg_expr, const RealType v)
|
||||
: abstract_unary_expression<RealType,
|
||||
DerivativeOrder,
|
||||
ARG,
|
||||
erfc_expr<RealType, DerivativeOrder, ARG>>(arg_expr, v){};
|
||||
|
||||
inner_t evaluate() const
|
||||
{
|
||||
return detail::if_functional_dispatch<(order > 1)>(
|
||||
return detail::if_functional_dispatch<((DerivativeOrder > 1))>(
|
||||
[this](auto &&x) { return reverse_mode::erfc(std::forward<decltype(x)>(x)); },
|
||||
[this](auto &&x) { return boost::math::erfc(std::forward<decltype(x)>(x)); },
|
||||
this->arg.evaluate());
|
||||
}
|
||||
static const inner_t derivative(const inner_t &argv,
|
||||
const inner_t & /*v*/,
|
||||
const T & /*constant*/)
|
||||
const RealType & /*constant*/)
|
||||
{
|
||||
BOOST_MATH_STD_USING
|
||||
return static_cast<T>(-2.0) * exp(-argv * argv) / sqrt(constants::pi<T>());
|
||||
return static_cast<RealType>(-2.0) * exp(-argv * argv) / sqrt(constants::pi<RealType>());
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, size_t order, typename ARG>
|
||||
struct erf_inv_expr : public abstract_unary_expression<T, order, ARG, erf_inv_expr<T, order, ARG>>
|
||||
template<typename RealType, size_t DerivativeOrder, typename ARG>
|
||||
struct erf_inv_expr : public abstract_unary_expression<RealType,
|
||||
DerivativeOrder,
|
||||
ARG,
|
||||
erf_inv_expr<RealType, DerivativeOrder, ARG>>
|
||||
{
|
||||
/** @brief erf(x)
|
||||
*
|
||||
@@ -120,39 +137,47 @@ struct erf_inv_expr : public abstract_unary_expression<T, order, ARG, erf_inv_ex
|
||||
*
|
||||
* */
|
||||
|
||||
using inner_t = rvar_t<T, order - 1>;
|
||||
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
|
||||
|
||||
explicit erf_inv_expr(const expression<T, order, ARG> &arg_expr, const T v)
|
||||
: abstract_unary_expression<T, order, ARG, erf_inv_expr<T, order, ARG>>(arg_expr, v){};
|
||||
explicit erf_inv_expr(const expression<RealType, DerivativeOrder, ARG> &arg_expr,
|
||||
const RealType v)
|
||||
: abstract_unary_expression<RealType,
|
||||
DerivativeOrder,
|
||||
ARG,
|
||||
erf_inv_expr<RealType, DerivativeOrder, ARG>>(arg_expr, v){};
|
||||
|
||||
inner_t evaluate() const
|
||||
{
|
||||
return detail::if_functional_dispatch<(order > 1)>(
|
||||
return detail::if_functional_dispatch<((DerivativeOrder > 1))>(
|
||||
[this](auto &&x) { return reverse_mode::erf_inv(std::forward<decltype(x)>(x)); },
|
||||
[this](auto &&x) { return boost::math::erf_inv(std::forward<decltype(x)>(x)); },
|
||||
this->arg.evaluate());
|
||||
}
|
||||
static const inner_t derivative(const inner_t &argv,
|
||||
const inner_t & /*v*/,
|
||||
const T & /*constant*/)
|
||||
const RealType & /*constant*/)
|
||||
{
|
||||
BOOST_MATH_STD_USING
|
||||
return detail::if_functional_dispatch<(order > 1)>(
|
||||
return detail::if_functional_dispatch<((DerivativeOrder > 1))>(
|
||||
[](auto &&x) {
|
||||
return static_cast<T>(0.5) * sqrt(constants::pi<T>())
|
||||
return static_cast<RealType>(0.5) * sqrt(constants::pi<RealType>())
|
||||
* reverse_mode::exp(
|
||||
reverse_mode::pow(reverse_mode::erf_inv(x), static_cast<T>(2.0)));
|
||||
reverse_mode::pow(reverse_mode::erf_inv(x), static_cast<RealType>(2.0)));
|
||||
},
|
||||
[](auto &&x) {
|
||||
return static_cast<T>(0.5) * sqrt(constants::pi<T>())
|
||||
* exp(pow(boost::math::erf_inv(x), static_cast<T>(2.0)));
|
||||
return static_cast<RealType>(0.5) * sqrt(constants::pi<RealType>())
|
||||
* exp(pow(boost::math::erf_inv(x), static_cast<RealType>(2.0)));
|
||||
},
|
||||
argv);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, size_t order, typename ARG>
|
||||
struct erfc_inv_expr : public abstract_unary_expression<T, order, ARG, erfc_inv_expr<T, order, ARG>>
|
||||
template<typename RealType, size_t DerivativeOrder, typename ARG>
|
||||
struct erfc_inv_expr
|
||||
: public abstract_unary_expression<RealType,
|
||||
DerivativeOrder,
|
||||
ARG,
|
||||
erfc_inv_expr<RealType, DerivativeOrder, ARG>>
|
||||
{
|
||||
/** @brief erfc(x)
|
||||
*
|
||||
@@ -160,32 +185,36 @@ struct erfc_inv_expr : public abstract_unary_expression<T, order, ARG, erfc_inv_
|
||||
*
|
||||
* */
|
||||
|
||||
using inner_t = rvar_t<T, order - 1>;
|
||||
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
|
||||
|
||||
explicit erfc_inv_expr(const expression<T, order, ARG> &arg_expr, const T v)
|
||||
: abstract_unary_expression<T, order, ARG, erfc_inv_expr<T, order, ARG>>(arg_expr, v){};
|
||||
explicit erfc_inv_expr(const expression<RealType, DerivativeOrder, ARG> &arg_expr,
|
||||
const RealType v)
|
||||
: abstract_unary_expression<RealType,
|
||||
DerivativeOrder,
|
||||
ARG,
|
||||
erfc_inv_expr<RealType, DerivativeOrder, ARG>>(arg_expr, v){};
|
||||
|
||||
inner_t evaluate() const
|
||||
{
|
||||
return detail::if_functional_dispatch<(order > 1)>(
|
||||
return detail::if_functional_dispatch<((DerivativeOrder > 1))>(
|
||||
[this](auto &&x) { return reverse_mode::erfc_inv(std::forward<decltype(x)>(x)); },
|
||||
[this](auto &&x) { return boost::math::erfc_inv(std::forward<decltype(x)>(x)); },
|
||||
this->arg.evaluate());
|
||||
}
|
||||
static const inner_t derivative(const inner_t &argv,
|
||||
const inner_t & /*v*/,
|
||||
const T & /*constant*/)
|
||||
const RealType & /*constant*/)
|
||||
{
|
||||
BOOST_MATH_STD_USING
|
||||
return detail::if_functional_dispatch<(order > 1)>(
|
||||
return detail::if_functional_dispatch<((DerivativeOrder > 1))>(
|
||||
[](auto &&x) {
|
||||
return static_cast<T>(-0.5) * sqrt(constants::pi<T>())
|
||||
* reverse_mode::exp(
|
||||
reverse_mode::pow(reverse_mode::erfc_inv(x), static_cast<T>(2.0)));
|
||||
return static_cast<RealType>(-0.5) * sqrt(constants::pi<RealType>())
|
||||
* reverse_mode::exp(reverse_mode::pow(reverse_mode::erfc_inv(x),
|
||||
static_cast<RealType>(2.0)));
|
||||
},
|
||||
[](auto &&x) {
|
||||
return static_cast<T>(-0.5) * sqrt(constants::pi<T>())
|
||||
* exp(pow(boost::math::erfc_inv(x), static_cast<T>(2.0)));
|
||||
return static_cast<RealType>(-0.5) * sqrt(constants::pi<RealType>())
|
||||
* exp(pow(boost::math::erfc_inv(x), static_cast<RealType>(2.0)));
|
||||
},
|
||||
argv);
|
||||
}
|
||||
|
||||
@@ -16,19 +16,24 @@ namespace reverse_mode {
|
||||
struct expression_base
|
||||
{};
|
||||
|
||||
template<typename T, size_t order, class derived_expression>
|
||||
template<typename RealType, size_t DerivativeOrder, class DerivedExpression>
|
||||
struct expression;
|
||||
|
||||
template<typename T, size_t order>
|
||||
template<typename RealType, size_t DerivativeOrder>
|
||||
class rvar;
|
||||
|
||||
template<typename T, size_t order, typename LHS, typename RHS, typename concrete_binary_operation>
|
||||
template<typename RealType,
|
||||
size_t DerivativeOrder,
|
||||
typename LHS,
|
||||
typename RHS,
|
||||
typename ConcreteBinaryOperation>
|
||||
struct abstract_binary_expression;
|
||||
|
||||
template<typename T, size_t order, typename ARG, typename concrete_unary_operation>
|
||||
template<typename RealType, size_t DerivativeOrder, typename ARG, typename ConcreteUnaryOperation>
|
||||
|
||||
struct abstract_unary_expression;
|
||||
|
||||
template<typename T, size_t order>
|
||||
template<typename RealType, size_t DerivativeOrder>
|
||||
class gradient_node; // forward declaration for tape
|
||||
|
||||
namespace detail {
|
||||
@@ -59,50 +64,54 @@ struct count_rvar_impl
|
||||
{
|
||||
static constexpr std::size_t value = 0;
|
||||
};
|
||||
template<typename U, size_t order>
|
||||
struct count_rvar_impl<rvar<U, order>, order>
|
||||
template<typename RealType, size_t DerivativeOrder>
|
||||
struct count_rvar_impl<rvar<RealType, DerivativeOrder>, DerivativeOrder>
|
||||
{
|
||||
static constexpr std::size_t value = 1;
|
||||
};
|
||||
|
||||
template<typename T, std::size_t order>
|
||||
struct count_rvar_impl<T,
|
||||
order,
|
||||
std::enable_if_t<has_binary_sub_types<T>::value
|
||||
&& !std::is_same<T, rvar<typename T::value_type, order>>::value
|
||||
&& !has_unary_sub_type<T>::value>>
|
||||
template<typename RealType, std::size_t DerivativeOrder>
|
||||
struct count_rvar_impl<
|
||||
RealType,
|
||||
DerivativeOrder,
|
||||
std::enable_if_t<has_binary_sub_types<RealType>::value
|
||||
&& !std::is_same<RealType, rvar<typename RealType::value_type, DerivativeOrder>>::value
|
||||
&& !has_unary_sub_type<RealType>::value>>
|
||||
{
|
||||
static constexpr std::size_t value = count_rvar_impl<typename T::lhs_type, order>::value
|
||||
+ count_rvar_impl<typename T::rhs_type, order>::value;
|
||||
static constexpr std::size_t value
|
||||
= count_rvar_impl<typename RealType::lhs_type, DerivativeOrder>::value
|
||||
+ count_rvar_impl<typename RealType::rhs_type, DerivativeOrder>::value;
|
||||
};
|
||||
|
||||
template<typename T, size_t order>
|
||||
template<typename RealType, size_t DerivativeOrder>
|
||||
struct count_rvar_impl<
|
||||
T,
|
||||
order,
|
||||
typename std::enable_if_t<has_unary_sub_type<T>::value
|
||||
&& !std::is_same<T, rvar<typename T::value_type, order>>::value
|
||||
&& !has_binary_sub_types<T>::value>>
|
||||
RealType,
|
||||
DerivativeOrder,
|
||||
typename std::enable_if_t<
|
||||
has_unary_sub_type<RealType>::value
|
||||
&& !std::is_same<RealType, rvar<typename RealType::value_type, DerivativeOrder>>::value
|
||||
&& !has_binary_sub_types<RealType>::value>>
|
||||
{
|
||||
static constexpr std::size_t value = count_rvar_impl<typename T::arg_type, order>::value;
|
||||
static constexpr std::size_t value
|
||||
= count_rvar_impl<typename RealType::arg_type, DerivativeOrder>::value;
|
||||
};
|
||||
template<typename T, size_t order>
|
||||
constexpr std::size_t count_rvars = detail::count_rvar_impl<T, order>::value;
|
||||
template<typename RealType, size_t DerivativeOrder>
|
||||
constexpr std::size_t count_rvars = detail::count_rvar_impl<RealType, DerivativeOrder>::value;
|
||||
|
||||
template<typename T>
|
||||
struct is_expression : std::is_base_of<expression_base, typename std::decay<T>::type>
|
||||
{};
|
||||
|
||||
template<typename T, size_t N>
|
||||
template<typename RealType, size_t N>
|
||||
struct rvar_type_impl
|
||||
{
|
||||
using type = rvar<T, N>;
|
||||
using type = rvar<RealType, N>;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct rvar_type_impl<T, 0>
|
||||
template<typename RealType>
|
||||
struct rvar_type_impl<RealType, 0>
|
||||
{
|
||||
using type = T;
|
||||
using type = RealType;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
@@ -110,63 +119,69 @@ struct rvar_type_impl<T, 0>
|
||||
template<typename T, size_t N>
|
||||
using rvar_t = typename detail::rvar_type_impl<T, N>::type;
|
||||
|
||||
template<typename T, size_t order, class derived_expression>
|
||||
template<typename RealType, size_t DerivativeOrder, class DerivedExpression>
|
||||
struct expression : expression_base
|
||||
{
|
||||
/* @brief
|
||||
* base expression class
|
||||
* */
|
||||
|
||||
using value_type = T;
|
||||
static constexpr size_t order_v = order;
|
||||
using derived_type = derived_expression;
|
||||
using value_type = RealType;
|
||||
static constexpr size_t order_v = DerivativeOrder;
|
||||
using derived_type = DerivedExpression;
|
||||
|
||||
static constexpr size_t num_literals = 0;
|
||||
using inner_t = rvar_t<T, order - 1>;
|
||||
inner_t evaluate() const { return static_cast<const derived_expression *>(this)->evaluate(); }
|
||||
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
|
||||
inner_t evaluate() const { return static_cast<const DerivedExpression *>(this)->evaluate(); }
|
||||
|
||||
template<size_t arg_index>
|
||||
void propagatex(gradient_node<T, order> *node, inner_t adj) const
|
||||
void propagatex(gradient_node<RealType, DerivativeOrder> *node, inner_t adj) const
|
||||
{
|
||||
return static_cast<const derived_expression *>(this)->template propagatex<arg_index>(node,
|
||||
adj);
|
||||
}
|
||||
return static_cast<const DerivedExpression *>(this)->template propagatex<arg_index>(node,
|
||||
adj);
|
||||
};
|
||||
};
|
||||
|
||||
template<typename T, size_t order, typename LHS, typename RHS, typename concrete_binary_operation>
|
||||
template<typename RealType,
|
||||
size_t DerivativeOrder,
|
||||
typename LHS,
|
||||
typename RHS,
|
||||
typename ConcreteBinaryOperation>
|
||||
struct abstract_binary_expression
|
||||
: public expression<T,
|
||||
order,
|
||||
abstract_binary_expression<T, order, LHS, RHS, concrete_binary_operation>>
|
||||
: public expression<
|
||||
RealType,
|
||||
DerivativeOrder,
|
||||
abstract_binary_expression<RealType, DerivativeOrder, LHS, RHS, ConcreteBinaryOperation>>
|
||||
{
|
||||
using lhs_type = LHS;
|
||||
using rhs_type = RHS;
|
||||
using value_type = T;
|
||||
using inner_t = rvar_t<T, order - 1>;
|
||||
using value_type = RealType;
|
||||
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
|
||||
const lhs_type lhs;
|
||||
const rhs_type rhs;
|
||||
|
||||
explicit abstract_binary_expression(const expression<T, order, LHS> &left_hand_expr,
|
||||
const expression<T, order, RHS> &right_hand_expr)
|
||||
explicit abstract_binary_expression(
|
||||
const expression<RealType, DerivativeOrder, LHS> &left_hand_expr,
|
||||
const expression<RealType, DerivativeOrder, RHS> &right_hand_expr)
|
||||
: lhs(static_cast<const LHS &>(left_hand_expr))
|
||||
, rhs(static_cast<const RHS &>(right_hand_expr)){};
|
||||
|
||||
inner_t evaluate() const
|
||||
{
|
||||
return static_cast<const concrete_binary_operation *>(this)->evaluate();
|
||||
return static_cast<const ConcreteBinaryOperation *>(this)->evaluate();
|
||||
};
|
||||
|
||||
template<size_t arg_index>
|
||||
void propagatex(gradient_node<T, order> *node, inner_t adj) const
|
||||
void propagatex(gradient_node<RealType, DerivativeOrder> *node, inner_t adj) const
|
||||
{
|
||||
const inner_t lv = lhs.evaluate();
|
||||
const inner_t rv = rhs.evaluate();
|
||||
const inner_t v = evaluate();
|
||||
const inner_t partial_l = concrete_binary_operation::left_derivative(lv, rv, v);
|
||||
const inner_t partial_r = concrete_binary_operation::right_derivative(lv, rv, v);
|
||||
inner_t lv = lhs.evaluate();
|
||||
inner_t rv = rhs.evaluate();
|
||||
inner_t v = evaluate();
|
||||
inner_t partial_l = ConcreteBinaryOperation::left_derivative(lv, rv, v);
|
||||
inner_t partial_r = ConcreteBinaryOperation::right_derivative(lv, rv, v);
|
||||
|
||||
constexpr size_t num_lhs_args = detail::count_rvars<LHS, order>;
|
||||
constexpr size_t num_rhs_args = detail::count_rvars<RHS, order>;
|
||||
constexpr size_t num_lhs_args = detail::count_rvars<LHS, DerivativeOrder>;
|
||||
constexpr size_t num_rhs_args = detail::count_rvars<RHS, DerivativeOrder>;
|
||||
|
||||
propagate_lhs<num_lhs_args, arg_index>(node, adj * partial_l);
|
||||
propagate_rhs<num_rhs_args, arg_index + num_lhs_args>(node, adj * partial_r);
|
||||
@@ -178,7 +193,7 @@ private:
|
||||
template<std::size_t num_args,
|
||||
std::size_t arg_index_,
|
||||
typename std::enable_if<(num_args > 0), int>::type = 0>
|
||||
void propagate_lhs(gradient_node<T, order> *node, inner_t adj) const
|
||||
void propagate_lhs(gradient_node<RealType, DerivativeOrder> *node, inner_t adj) const
|
||||
{
|
||||
lhs.template propagatex<arg_index_>(node, adj);
|
||||
}
|
||||
@@ -186,13 +201,13 @@ private:
|
||||
template<std::size_t num_args,
|
||||
std::size_t arg_index_,
|
||||
typename std::enable_if<(num_args == 0), int>::type = 0>
|
||||
void propagate_lhs(gradient_node<T, order> *, inner_t) const
|
||||
void propagate_lhs(gradient_node<RealType, DerivativeOrder> *, inner_t) const
|
||||
{}
|
||||
|
||||
template<std::size_t num_args,
|
||||
std::size_t arg_index_,
|
||||
typename std::enable_if<(num_args > 0), int>::type = 0>
|
||||
void propagate_rhs(gradient_node<T, order> *node, inner_t adj) const
|
||||
void propagate_rhs(gradient_node<RealType, DerivativeOrder> *node, inner_t adj) const
|
||||
{
|
||||
rhs.template propagatex<arg_index_>(node, adj);
|
||||
}
|
||||
@@ -200,32 +215,37 @@ private:
|
||||
template<std::size_t num_args,
|
||||
std::size_t arg_index_,
|
||||
typename std::enable_if<(num_args == 0), int>::type = 0>
|
||||
void propagate_rhs(gradient_node<T, order> *, inner_t) const
|
||||
void propagate_rhs(gradient_node<RealType, DerivativeOrder> *, inner_t) const
|
||||
{}
|
||||
};
|
||||
template<typename T, size_t order, typename ARG, typename concrete_unary_operation>
|
||||
template<typename RealType, size_t DerivativeOrder, typename ARG, typename ConcreteUnaryOperation>
|
||||
|
||||
struct abstract_unary_expression
|
||||
: public expression<T, order, abstract_unary_expression<T, order, ARG, concrete_unary_operation>>
|
||||
: public expression<
|
||||
RealType,
|
||||
DerivativeOrder,
|
||||
abstract_unary_expression<RealType, DerivativeOrder, ARG, ConcreteUnaryOperation>>
|
||||
{
|
||||
using arg_type = ARG;
|
||||
using value_type = T;
|
||||
using inner_t = rvar_t<T, order - 1>;
|
||||
using value_type = RealType;
|
||||
using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
|
||||
const arg_type arg;
|
||||
const T constant;
|
||||
explicit abstract_unary_expression(const expression<T, order, ARG> &arg_expr, const T &constant)
|
||||
const RealType constant;
|
||||
explicit abstract_unary_expression(const expression<RealType, DerivativeOrder, ARG> &arg_expr,
|
||||
const RealType &constant)
|
||||
: arg(static_cast<const ARG &>(arg_expr))
|
||||
, constant(constant){};
|
||||
inner_t evaluate() const
|
||||
{
|
||||
return static_cast<const concrete_unary_operation *>(this)->evaluate();
|
||||
return static_cast<const ConcreteUnaryOperation *>(this)->evaluate();
|
||||
};
|
||||
|
||||
template<size_t arg_index>
|
||||
void propagatex(gradient_node<T, order> *node, inner_t adj) const
|
||||
void propagatex(gradient_node<RealType, DerivativeOrder> *node, inner_t adj) const
|
||||
{
|
||||
inner_t argv = arg.evaluate();
|
||||
inner_t v = evaluate();
|
||||
inner_t partial_arg = concrete_unary_operation::derivative(argv, v, constant);
|
||||
inner_t partial_arg = ConcreteUnaryOperation::derivative(argv, v, constant);
|
||||
|
||||
arg.template propagatex<arg_index>(node, adj * partial_arg);
|
||||
}
|
||||
|
||||
@@ -188,7 +188,7 @@ public:
|
||||
bool operator!() const noexcept { return storage_ == nullptr; }
|
||||
};
|
||||
/* memory management helps for tape */
|
||||
template<typename T, size_t buffer_size>
|
||||
template<typename RealType, size_t buffer_size>
|
||||
class flat_linear_allocator
|
||||
{
|
||||
/** @brief basically a vector<array<T*, size>>
|
||||
@@ -198,8 +198,8 @@ class flat_linear_allocator
|
||||
public:
|
||||
// store vector of unique pointers to arrays
|
||||
// to avoid vector reference invalidation
|
||||
using buffer_type = std::array<T, buffer_size>;
|
||||
using buffer_ptr = std::unique_ptr<std::array<T, buffer_size>>;
|
||||
using buffer_type = std::array<RealType, buffer_size>;
|
||||
using buffer_ptr = std::unique_ptr<std::array<RealType, buffer_size>>;
|
||||
|
||||
private:
|
||||
std::vector<buffer_ptr> data_;
|
||||
@@ -207,14 +207,16 @@ private:
|
||||
std::vector<size_t> checkpoints_; //{0};
|
||||
|
||||
public:
|
||||
friend class flat_linear_allocator_iterator<flat_linear_allocator<T, buffer_size>, buffer_size>;
|
||||
friend class flat_linear_allocator_iterator<const flat_linear_allocator<T, buffer_size>,
|
||||
friend class flat_linear_allocator_iterator<flat_linear_allocator<RealType, buffer_size>,
|
||||
buffer_size>;
|
||||
using value_type = T;
|
||||
friend class flat_linear_allocator_iterator<const flat_linear_allocator<RealType, buffer_size>,
|
||||
buffer_size>;
|
||||
using value_type = RealType;
|
||||
using iterator
|
||||
= flat_linear_allocator_iterator<flat_linear_allocator<T, buffer_size>, buffer_size>;
|
||||
= flat_linear_allocator_iterator<flat_linear_allocator<RealType, buffer_size>, buffer_size>;
|
||||
using const_iterator
|
||||
= flat_linear_allocator_iterator<const flat_linear_allocator<T, buffer_size>, buffer_size>;
|
||||
= flat_linear_allocator_iterator<const flat_linear_allocator<RealType, buffer_size>,
|
||||
buffer_size>;
|
||||
|
||||
size_t buffer_id() const noexcept { return total_size_ / buffer_size; }
|
||||
size_t item_id() const noexcept { return total_size_ % buffer_size; }
|
||||
@@ -242,7 +244,7 @@ public:
|
||||
for (size_t i = 0; i < total_size_; ++i) {
|
||||
size_t bid = i / buffer_size;
|
||||
size_t iid = i % buffer_size;
|
||||
(*data_[bid])[iid].~T();
|
||||
(*data_[bid])[iid].~RealType();
|
||||
}
|
||||
}
|
||||
/** @brief
|
||||
@@ -277,7 +279,7 @@ public:
|
||||
void rewind_to_last_checkpoint() { total_size_ = checkpoints_.back(); }
|
||||
void rewind_to_checkpoint_at(size_t index) { total_size_ = checkpoints_[index]; }
|
||||
|
||||
void fill(const T &val)
|
||||
void fill(const RealType &val)
|
||||
{
|
||||
for (size_t i = 0; i < total_size_; ++i) {
|
||||
size_t bid = i / buffer_size;
|
||||
@@ -296,8 +298,8 @@ public:
|
||||
size_t bid = buffer_id();
|
||||
size_t iid = item_id();
|
||||
|
||||
T *ptr = &(*data_[bid])[iid];
|
||||
new (ptr) T();
|
||||
RealType *ptr = &(*data_[bid])[iid];
|
||||
new (ptr) RealType();
|
||||
++total_size_;
|
||||
return iterator(this, total_size_ - 1);
|
||||
};
|
||||
@@ -312,8 +314,8 @@ public:
|
||||
}
|
||||
BOOST_MATH_ASSERT(buffer_id() < data_.size());
|
||||
BOOST_MATH_ASSERT(item_id() < buffer_size);
|
||||
T *ptr = &(*data_[buffer_id()])[item_id()];
|
||||
new (ptr) T(std::forward<Args>(args)...);
|
||||
RealType *ptr = &(*data_[buffer_id()])[item_id()];
|
||||
new (ptr) RealType(std::forward<Args>(args)...);
|
||||
++total_size_;
|
||||
return iterator(this, total_size_ - 1);
|
||||
}
|
||||
@@ -325,27 +327,27 @@ public:
|
||||
size_t bid = buffer_id();
|
||||
size_t iid = item_id();
|
||||
if (iid + n < buffer_size) {
|
||||
T *ptr = &(*data_[bid])[iid];
|
||||
RealType *ptr = &(*data_[bid])[iid];
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
new (ptr + i) T();
|
||||
new (ptr + i) RealType();
|
||||
}
|
||||
total_size_ += n;
|
||||
return iterator(this, total_size_ - n, total_size_ - n, total_size_);
|
||||
} else {
|
||||
size_t allocs_in_curr_buffer = buffer_size - iid;
|
||||
size_t allocs_in_next_buffer = n - (buffer_size - iid);
|
||||
T *ptr = &(*data_[bid])[iid];
|
||||
RealType *ptr = &(*data_[bid])[iid];
|
||||
for (size_t i = 0; i < allocs_in_curr_buffer; ++i) {
|
||||
new (ptr + i) T();
|
||||
new (ptr + i) RealType();
|
||||
}
|
||||
allocate_buffer();
|
||||
bid = data_.size() - 1;
|
||||
iid = 0;
|
||||
total_size_ += n;
|
||||
|
||||
T *ptr2 = &(*data_[bid])[iid];
|
||||
RealType *ptr2 = &(*data_[bid])[iid];
|
||||
for (size_t i = 0; i < allocs_in_next_buffer; i++) {
|
||||
new (ptr2 + i) T();
|
||||
new (ptr2 + i) RealType();
|
||||
}
|
||||
return iterator(this, total_size_ - n, total_size_ - n, total_size_);
|
||||
}
|
||||
@@ -358,27 +360,27 @@ public:
|
||||
size_t bid = buffer_id();
|
||||
size_t iid = item_id();
|
||||
if (iid + n < buffer_size) {
|
||||
T *ptr = &(*data_[bid])[iid];
|
||||
RealType *ptr = &(*data_[bid])[iid];
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
new (ptr + i) T();
|
||||
new (ptr + i) RealType();
|
||||
}
|
||||
total_size_ += n;
|
||||
return iterator(this, total_size_ - n, total_size_ - n, total_size_);
|
||||
} else {
|
||||
size_t allocs_in_curr_buffer = buffer_size - iid;
|
||||
size_t allocs_in_next_buffer = n - (buffer_size - iid);
|
||||
T *ptr = &(*data_[bid])[iid];
|
||||
RealType *ptr = &(*data_[bid])[iid];
|
||||
for (size_t i = 0; i < allocs_in_curr_buffer; ++i) {
|
||||
new (ptr + i) T();
|
||||
new (ptr + i) RealType();
|
||||
}
|
||||
allocate_buffer();
|
||||
bid = data_.size() - 1;
|
||||
iid = 0;
|
||||
total_size_ += n;
|
||||
|
||||
T *ptr2 = &(*data_[bid])[iid];
|
||||
RealType *ptr2 = &(*data_[bid])[iid];
|
||||
for (size_t i = 0; i < allocs_in_next_buffer; i++) {
|
||||
new (ptr2 + i) T();
|
||||
new (ptr2 + i) RealType();
|
||||
}
|
||||
return iterator(this, total_size_ - n, total_size_ - n, total_size_);
|
||||
}
|
||||
@@ -405,19 +407,19 @@ public:
|
||||
|
||||
/** @brief searches for item in allocator
|
||||
* only used to find gradient nodes for propagation */
|
||||
iterator find(const T *const item)
|
||||
iterator find(const RealType *const item)
|
||||
{
|
||||
return std::find_if(begin(), end(), [&](const T &val) { return &val == item; });
|
||||
return std::find_if(begin(), end(), [&](const RealType &val) { return &val == item; });
|
||||
}
|
||||
/** @brief vector like access,
|
||||
* currently unused anywhere but very useful for debugging
|
||||
*/
|
||||
T &operator[](std::size_t i)
|
||||
RealType &operator[](std::size_t i)
|
||||
{
|
||||
BOOST_MATH_ASSERT(i < total_size_);
|
||||
return (*data_[i / buffer_size])[i % buffer_size];
|
||||
}
|
||||
const T &operator[](std::size_t i) const
|
||||
const RealType &operator[](std::size_t i) const
|
||||
{
|
||||
BOOST_MATH_ASSERT(i < total_size_);
|
||||
return (*data_[i / buffer_size])[i % buffer_size];
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user