From 88d19723e33a8373c48ebaf849bf64c33bc4d8dd Mon Sep 17 00:00:00 2001 From: Kyle Lutz Date: Sun, 9 Nov 2014 11:06:47 -0800 Subject: [PATCH] Use radix_sort() for stable_sort() when possible --- .../compute/algorithm/detail/radix_sort.hpp | 12 ++++ .../boost/compute/algorithm/stable_sort.hpp | 50 ++++++++++++-- .../compute/type_traits/is_vector_type.hpp | 6 +- test/test_stable_sort.cpp | 66 +++++++++++++++++-- 4 files changed, 122 insertions(+), 12 deletions(-) diff --git a/include/boost/compute/algorithm/detail/radix_sort.hpp b/include/boost/compute/algorithm/detail/radix_sort.hpp index 15456f0a..852644d2 100644 --- a/include/boost/compute/algorithm/detail/radix_sort.hpp +++ b/include/boost/compute/algorithm/detail/radix_sort.hpp @@ -23,6 +23,8 @@ #include #include #include +#include +#include #include #include @@ -30,6 +32,16 @@ namespace boost { namespace compute { namespace detail { +// meta-function returning true if type T is radix-sortable +template +struct is_radix_sortable : + boost::mpl::and_< + typename ::boost::compute::is_fundamental::type, + typename boost::mpl::not_::type>::type + > +{ +}; + template struct radix_sort_value_type { diff --git a/include/boost/compute/algorithm/stable_sort.hpp b/include/boost/compute/algorithm/stable_sort.hpp index a42213bb..a2d33a99 100644 --- a/include/boost/compute/algorithm/stable_sort.hpp +++ b/include/boost/compute/algorithm/stable_sort.hpp @@ -15,10 +15,51 @@ #include #include +#include #include +#include +#include namespace boost { namespace compute { +namespace detail { + +template +inline void dispatch_stable_sort(Iterator first, + Iterator last, + Compare compare, + command_queue &queue) +{ + ::boost::compute::detail::serial_insertion_sort( + first, last, compare, queue + ); +} + +template +inline typename boost::enable_if_c::value>::type +dispatch_stable_sort(buffer_iterator first, + buffer_iterator last, + less, + command_queue &queue) +{ + ::boost::compute::detail::radix_sort(first, last, queue); +} + +template +inline typename boost::enable_if_c::value>::type +dispatch_stable_sort(buffer_iterator first, + buffer_iterator last, + greater, + command_queue &queue) +{ + // radix sort in ascending order + ::boost::compute::detail::radix_sort(first, last, queue); + + // reverse range to descending order + ::boost::compute::reverse(first, last, queue); +} + +} // end detail namespace /// Sorts the values in the range [\p first, \p last) according to /// \p compare. The relative order of identical values is preserved. @@ -30,10 +71,9 @@ inline void stable_sort(Iterator first, Compare compare, command_queue &queue = system::default_queue()) { - return ::boost::compute::detail::serial_insertion_sort(first, - last, - compare, - queue); + ::boost::compute::detail::dispatch_stable_sort( + first, last, compare, queue + ); } /// \overload @@ -46,7 +86,7 @@ inline void stable_sort(Iterator first, ::boost::compute::less less; - return ::boost::compute::stable_sort(first, last, less, queue); + ::boost::compute::stable_sort(first, last, less, queue); } } // end compute namespace diff --git a/include/boost/compute/type_traits/is_vector_type.hpp b/include/boost/compute/type_traits/is_vector_type.hpp index 603b5d07..8e2ef1d7 100644 --- a/include/boost/compute/type_traits/is_vector_type.hpp +++ b/include/boost/compute/type_traits/is_vector_type.hpp @@ -11,6 +11,8 @@ #ifndef BOOST_COMPUTE_TYPE_TRAITS_IS_VECTOR_TYPE_HPP #define BOOST_COMPUTE_TYPE_TRAITS_IS_VECTOR_TYPE_HPP +#include + #include namespace boost { @@ -26,10 +28,8 @@ namespace compute { /// /// \see make_vector_type, vector_size template -struct is_vector_type +struct is_vector_type : boost::mpl::bool_::value != 1> { - /// \internal_ - BOOST_STATIC_CONSTANT(bool, value = (vector_size::value != 1)); }; } // end compute namespace diff --git a/test/test_stable_sort.cpp b/test/test_stable_sort.cpp index ab6fc1c4..dde59d3c 100644 --- a/test/test_stable_sort.cpp +++ b/test/test_stable_sort.cpp @@ -12,6 +12,7 @@ #include #include +#include #include #include #include @@ -19,16 +20,73 @@ #include "check_macros.hpp" #include "context_setup.hpp" +namespace compute = boost::compute; + BOOST_AUTO_TEST_CASE(sort_int_vector) { int data[] = { -4, 152, -5000, 963, 75321, -456, 0, 1112 }; - boost::compute::vector vector(data, data + 8); + compute::vector vector(data, data + 8, queue); BOOST_CHECK_EQUAL(vector.size(), size_t(8)); - BOOST_CHECK(boost::compute::is_sorted(vector.begin(), vector.end()) == false); + BOOST_CHECK(compute::is_sorted(vector.begin(), vector.end(), queue) == false); - boost::compute::stable_sort(vector.begin(), vector.end()); - BOOST_CHECK(boost::compute::is_sorted(vector.begin(), vector.end()) == true); + compute::stable_sort(vector.begin(), vector.end(), queue); + BOOST_CHECK(compute::is_sorted(vector.begin(), vector.end(), queue) == true); CHECK_RANGE_EQUAL(int, 8, vector, (-5000, -456, -4, 0, 152, 963, 1112, 75321)); + + // sort reversed + compute::stable_sort(vector.begin(), vector.end(), compute::greater(), queue); + CHECK_RANGE_EQUAL(int, 8, vector, (75321, 1112, 963, 152, 0, -4, -456, -5000)); +} + +BOOST_AUTO_TEST_CASE(sort_int2) +{ + using compute::int2_; + + // device vector of int2's + compute::vector vec(context); + vec.push_back(int2_(2, 1), queue); + vec.push_back(int2_(2, 2), queue); + vec.push_back(int2_(1, 2), queue); + vec.push_back(int2_(1, 1), queue); + + // function comparing the first component of each int2 + BOOST_COMPUTE_FUNCTION(bool, compare_first, (int2_ a, int2_ b), + { + return a.x < b.x; + }); + + // ensure vector is not sorted + BOOST_CHECK(compute::is_sorted(vec.begin(), vec.end(), compare_first, queue) == false); + + // sort elements based on their first component + compute::stable_sort(vec.begin(), vec.end(), compare_first, queue); + + // ensure vector is now sorted + BOOST_CHECK(compute::is_sorted(vec.begin(), vec.end(), compare_first, queue) == true); + + // check sorted vector order + std::vector result(vec.size()); + compute::copy(vec.begin(), vec.end(), result.begin(), queue); + BOOST_CHECK_EQUAL(result[0], int2_(1, 2)); + BOOST_CHECK_EQUAL(result[1], int2_(1, 1)); + BOOST_CHECK_EQUAL(result[2], int2_(2, 1)); + BOOST_CHECK_EQUAL(result[3], int2_(2, 2)); + + // function comparing the second component of each int2 + BOOST_COMPUTE_FUNCTION(bool, compare_second, (int2_ a, int2_ b), + { + return a.y < b.y; + }); + + // sort elements based on their second component + compute::stable_sort(vec.begin(), vec.end(), compare_second, queue); + + // check sorted vector order + compute::copy(vec.begin(), vec.end(), result.begin(), queue); + BOOST_CHECK_EQUAL(result[0], int2_(1, 1)); + BOOST_CHECK_EQUAL(result[1], int2_(2, 1)); + BOOST_CHECK_EQUAL(result[2], int2_(1, 2)); + BOOST_CHECK_EQUAL(result[3], int2_(2, 2)); } BOOST_AUTO_TEST_SUITE_END()