diff --git a/include/boost/math/special_functions/rsqrt.hpp b/include/boost/math/special_functions/rsqrt.hpp index 3b8be8235..4e3ef5b46 100644 --- a/include/boost/math/special_functions/rsqrt.hpp +++ b/include/boost/math/special_functions/rsqrt.hpp @@ -8,37 +8,42 @@ #include #include #include +#include -namespace boost::math { +namespace boost { namespace math { -template +template ::value || + std::is_same::value || + std::is_same::value, bool>::type = true> +inline constexpr Real rsqrt(Real const & x) +{ + return 1 / boost::math::sqrt(x); +} + +template::value && + !std::is_same::value && + !std::is_same::value, bool>::type = true> inline Real rsqrt(Real const & x) { using std::sqrt; - if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) - { + + // if it's so tiny it rounds to 0 as long double, + // no performance gains are possible: + if (x < std::numeric_limits::denorm_min() || x > (std::numeric_limits::max)()) { return 1/sqrt(x); } - else - { - // if it's so tiny it rounds to 0 as long double, - // no performance gains are possible: - if (x < std::numeric_limits::denorm_min() || x > (std::numeric_limits::max)()) { - return 1/sqrt(x); - } - Real x0 = 1/sqrt(static_cast(x)); - // Divide by 512 for leeway: - Real s = sqrt(std::numeric_limits::epsilon())*x0/512; - Real x1 = x0 + x0*(1-x*x0*x0)/2; - while(abs(x1 - x0) > s) { - x0 = x1; - x1 = x0 + x0*(1-x*x0*x0)/2; - } - // Final iteration get ~2ULPs: - return x1 + x1*(1-x*x1*x1)/2;; + Real x0 = 1/sqrt(static_cast(x)); + // Divide by 512 for leeway: + Real s = sqrt(std::numeric_limits::epsilon())*x0/512; + Real x1 = x0 + x0*(1-x*x0*x0)/2; + while(abs(x1 - x0) > s) { + x0 = x1; + x1 = x0 + x0*(1-x*x0*x0)/2; } + // Final iteration get ~2ULPs: + return x1 + x1*(1-x*x1*x1)/2; } -} +}} #endif diff --git a/include/boost/math/special_functions/sqrt.hpp b/include/boost/math/special_functions/sqrt.hpp index fe9fb0d90..cb6f34ebe 100644 --- a/include/boost/math/special_functions/sqrt.hpp +++ b/include/boost/math/special_functions/sqrt.hpp @@ -32,12 +32,19 @@ inline constexpr Real sqrt_impl(Real x) return sqrt_impl_1(x, x > 1 ? x : Real(1)); } +// std::isnan is not constexpr according to the standard +template +inline constexpr bool is_nan(Real x) +{ + return x != x; +} + } // namespace detail template ::value, bool>::type = true> inline constexpr Real sqrt(Real x) { - return detail::sqrt_impl(x); + return detail::is_nan(x) ? NAN : detail::sqrt_impl(x); } template ::value, bool>::type = true> diff --git a/test/rsqrt_test.cpp b/test/rsqrt_test.cpp index f41402f15..89c646dd4 100644 --- a/test/rsqrt_test.cpp +++ b/test/rsqrt_test.cpp @@ -67,7 +67,7 @@ void test_rsqrt() x = (std::numeric_limits::max)(); expected = 1/sqrt(x); computed = rsqrt(x); - if (!CHECK_EQUAL(expected, computed)) { + if (!CHECK_ULP_CLOSE(expected, computed, 2)) { std::cerr << "Reciprocal square root of std::numeric_limits::max() not correctly computed.\n"; }