2
0
mirror of https://github.com/boostorg/math.git synced 2026-01-30 08:02:11 +00:00

Allow for constexpr rsqrt in certain contexts

This commit is contained in:
Matt Borland
2021-07-10 15:02:53 +03:00
parent 9677ed0a0a
commit 7c8ee612d3
3 changed files with 36 additions and 24 deletions

View File

@@ -8,37 +8,42 @@
#include <cmath>
#include <type_traits>
#include <limits>
#include <boost/math/special_functions/sqrt.hpp>
namespace boost::math {
namespace boost { namespace math {
template<typename Real>
template <typename Real, typename std::enable_if<std::is_same<Real, float>::value ||
std::is_same<Real, double>::value ||
std::is_same<Real, long double>::value, bool>::type = true>
inline constexpr Real rsqrt(Real const & x)
{
return 1 / boost::math::sqrt(x);
}
template<typename Real, typename std::enable_if<!std::is_same<Real, float>::value &&
!std::is_same<Real, double>::value &&
!std::is_same<Real, long double>::value, bool>::type = true>
inline Real rsqrt(Real const & x)
{
using std::sqrt;
if constexpr (std::is_same_v<Real, float> || std::is_same_v<Real, double> || std::is_same_v<Real, long double>)
{
// if it's so tiny it rounds to 0 as long double,
// no performance gains are possible:
if (x < std::numeric_limits<long double>::denorm_min() || x > (std::numeric_limits<long double>::max)()) {
return 1/sqrt(x);
}
else
{
// if it's so tiny it rounds to 0 as long double,
// no performance gains are possible:
if (x < std::numeric_limits<long double>::denorm_min() || x > (std::numeric_limits<long double>::max)()) {
return 1/sqrt(x);
}
Real x0 = 1/sqrt(static_cast<long double>(x));
// Divide by 512 for leeway:
Real s = sqrt(std::numeric_limits<Real>::epsilon())*x0/512;
Real x1 = x0 + x0*(1-x*x0*x0)/2;
while(abs(x1 - x0) > s) {
x0 = x1;
x1 = x0 + x0*(1-x*x0*x0)/2;
}
// Final iteration get ~2ULPs:
return x1 + x1*(1-x*x1*x1)/2;;
Real x0 = 1/sqrt(static_cast<long double>(x));
// Divide by 512 for leeway:
Real s = sqrt(std::numeric_limits<Real>::epsilon())*x0/512;
Real x1 = x0 + x0*(1-x*x0*x0)/2;
while(abs(x1 - x0) > s) {
x0 = x1;
x1 = x0 + x0*(1-x*x0*x0)/2;
}
// Final iteration get ~2ULPs:
return x1 + x1*(1-x*x1*x1)/2;
}
}
}}
#endif

View File

@@ -32,12 +32,19 @@ inline constexpr Real sqrt_impl(Real x)
return sqrt_impl_1(x, x > 1 ? x : Real(1));
}
// std::isnan is not constexpr according to the standard
template <typename Real>
inline constexpr bool is_nan(Real x)
{
return x != x;
}
} // namespace detail
template <typename Real, typename std::enable_if<std::is_floating_point<Real>::value, bool>::type = true>
inline constexpr Real sqrt(Real x)
{
return detail::sqrt_impl<Real>(x);
return detail::is_nan(x) ? NAN : detail::sqrt_impl<Real>(x);
}
template <typename Z, typename std::enable_if<std::is_integral<Z>::value, bool>::type = true>

View File

@@ -67,7 +67,7 @@ void test_rsqrt()
x = (std::numeric_limits<Real>::max)();
expected = 1/sqrt(x);
computed = rsqrt(x);
if (!CHECK_EQUAL(expected, computed)) {
if (!CHECK_ULP_CLOSE(expected, computed, 2)) {
std::cerr << "Reciprocal square root of std::numeric_limits<Real>::max() not correctly computed.\n";
}