diff --git a/include/boost/math/distributions/hyperexponential.hpp b/include/boost/math/distributions/hyperexponential.hpp index 40ed787b9..67b5aa4f2 100644 --- a/include/boost/math/distributions/hyperexponential.hpp +++ b/include/boost/math/distributions/hyperexponential.hpp @@ -274,12 +274,33 @@ class hyperexponential_distribution PolicyT()); } + private: template + class is_iterator + { + private: + using yes = char; + struct no { char x[2]; }; + + // Iterators only require pre-increment and dereference operator (24.2.2) + template (), void(), ++std::declval(), void())> + static yes test(const U&&); + + template + static no test(...); + + public: + static constexpr bool value = (sizeof(test(0)) == sizeof(char)); + }; + + + + // Two arg constructor from 2 ranges, we SFINAE this out of existence if // either argument type is incrementable as in that case the type is // probably an iterator: public: template ::value && - !std::is_pointer::value, bool>::type = true> + typename std::enable_if::value && + !is_iterator::value, bool>::type = true> hyperexponential_distribution(ProbRangeT const& prob_range, RateRangeT const& rate_range) : probs_(std::begin(prob_range), std::end(prob_range)), @@ -300,8 +321,8 @@ class hyperexponential_distribution // Note that we allow different argument types here to allow for // construction from an array plus a pointer into that array. public: template ::value || - std::is_pointer::value, bool>::type = true> + typename std::enable_if::value || + is_iterator::value, bool>::type = true> hyperexponential_distribution(RateIterT const& rate_first, RateIterT2 const& rate_last) : probs_(std::distance(rate_first, rate_last), 1), // will be normalized below diff --git a/test/test_hyperexponential_dist.cpp b/test/test_hyperexponential_dist.cpp index 20436d9bc..fec7fc4ac 100644 --- a/test/test_hyperexponential_dist.cpp +++ b/test/test_hyperexponential_dist.cpp @@ -386,4 +386,19 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(error_cases, RealT, test_types) BOOST_MATH_CHECK_THROW(dist_t(probs2, rates), std::domain_error); BOOST_MATH_CHECK_THROW(dist_t(probs.begin(), probs.begin(), rates.begin(), rates.begin()), std::domain_error); BOOST_MATH_CHECK_THROW(dist_t(rates.begin(), rates.begin()), std::domain_error); + + // Test C++20 ranges + #if (__cplusplus > 202000L || _MSVC_LANG > 202000L) && defined(__cpp_lib_ranges) + #include + #include + + std::array probs_array {1,2}; + std::array rates_array {1,2,3}; + BOOST_MATH_CHECK_THROW(dist_t(std::ranges::begin(probs_array), std::ranges::end(probs_array), std::ranges::begin(rates_array), std::ranges::end(rates_array)), std::domain_error); + + const auto probs_range = probs_array | std::views::all; + const auto rates_range = rates_array | std::views::all; + + BOOST_MATH_CHECK_THROW(dist_t(probs_range, rates_range), std::domain_error); + #endif }