diff --git a/include/boost/compute/algorithm/detail/radix_sort.hpp b/include/boost/compute/algorithm/detail/radix_sort.hpp index eacf8dbf..852644d2 100644 --- a/include/boost/compute/algorithm/detail/radix_sort.hpp +++ b/include/boost/compute/algorithm/detail/radix_sort.hpp @@ -23,6 +23,8 @@ #include #include #include +#include +#include #include #include @@ -30,6 +32,16 @@ namespace boost { namespace compute { namespace detail { +// meta-function returning true if type T is radix-sortable +template +struct is_radix_sortable : + boost::mpl::and_< + typename ::boost::compute::is_fundamental::type, + typename boost::mpl::not_::type>::type + > +{ +}; + template struct radix_sort_value_type { @@ -77,6 +89,7 @@ const char radix_sort_source[] = "}\n" "__kernel void count(__global const T *input,\n" +" const uint input_offset,\n" " const uint input_size,\n" " __global uint *global_counts,\n" " __global uint *global_offsets,\n" @@ -95,7 +108,7 @@ const char radix_sort_source[] = // reduce local counts " if(gid < input_size){\n" -" T value = input[gid];\n" +" T value = input[input_offset+gid];\n" " uint bucket = radix(value, low_bit);\n" " atomic_inc(local_counts + bucket);\n" " }\n" @@ -129,16 +142,21 @@ const char radix_sort_source[] = "}\n" "__kernel void scatter(__global const T *input,\n" +" const uint input_offset,\n" " const uint input_size,\n" " const uint low_bit,\n" " __global const uint *counts,\n" " __global const uint *global_offsets,\n" "#ifndef SORT_BY_KEY\n" -" __global T *output)\n" +" __global T *output,\n" +" const uint output_offset)\n" "#else\n" " __global T *keys_output,\n" +" const uint keys_output_offset,\n" " __global T2 *values_input,\n" -" __global T2 *values_output)\n" +" const uint values_input_offset,\n" +" __global T2 *values_output,\n" +" const uint values_output_offset)\n" "#endif\n" "{\n" // work-item parameters @@ -150,7 +168,7 @@ const char radix_sort_source[] = " uint bucket;\n" " __local uint local_input[BLOCK_SIZE];\n" " if(gid < input_size){\n" -" value = input[gid];\n" +" value = input[input_offset+gid];\n" " bucket = radix(value, low_bit);\n" " local_input[lid] = bucket;\n" " }\n" @@ -180,11 +198,12 @@ const char radix_sort_source[] = "#ifndef SORT_BY_KEY\n" // write value to output -" output[offset + local_offset] = value;\n" +" output[output_offset + offset + local_offset] = value;\n" "#else\n" // write key and value if doing sort_by_key -" keys_output[offset + local_offset] = value;\n" -" values_output[offset + local_offset] = values_input[gid];\n" +" keys_output[keys_output_offset+offset + local_offset] = value;\n" +" values_output[values_output_offset+offset + local_offset] =\n" +" values_input[values_input_offset+gid];\n" "#endif\n" "}\n"; @@ -194,6 +213,7 @@ inline void radix_sort_impl(const buffer_iterator first, const buffer_iterator values_first, command_queue &queue) { + typedef T value_type; typedef typename radix_sort_value_type::type sort_type; @@ -222,6 +242,7 @@ inline void radix_sort_impl(const buffer_iterator first, std::string cache_key = std::string("radix_sort_") + type_name(); + if(sort_by_key){ cache_key += std::string("_with_") + type_name(); } @@ -264,18 +285,23 @@ inline void radix_sort_impl(const buffer_iterator first, vector counts(block_count * k2, context); const buffer *input_buffer = &first.get_buffer(); + uint_ input_offset = first.get_index(); const buffer *output_buffer = &output.get_buffer(); + uint_ output_offset = 0; const buffer *values_input_buffer = &values_first.get_buffer(); + uint_ values_input_offset = values_first.get_index(); const buffer *values_output_buffer = &values_output.get_buffer(); + uint_ values_output_offset = 0; for(uint_ i = 0; i < sizeof(sort_type) * CHAR_BIT / k; i++){ // write counts count_kernel.set_arg(0, *input_buffer); - count_kernel.set_arg(1, static_cast(count)); - count_kernel.set_arg(2, counts); - count_kernel.set_arg(3, offsets); - count_kernel.set_arg(4, block_size * sizeof(uint_), 0); - count_kernel.set_arg(5, i * k); + count_kernel.set_arg(1, input_offset); + count_kernel.set_arg(2, static_cast(count)); + count_kernel.set_arg(3, counts); + count_kernel.set_arg(4, offsets); + count_kernel.set_arg(5, block_size * sizeof(uint_), 0); + count_kernel.set_arg(6, i * k); queue.enqueue_1d_range_kernel(count_kernel, 0, block_count * block_size, @@ -322,14 +348,18 @@ inline void radix_sort_impl(const buffer_iterator first, // scatter values scatter_kernel.set_arg(0, *input_buffer); - scatter_kernel.set_arg(1, static_cast(count)); - scatter_kernel.set_arg(2, i * k); - scatter_kernel.set_arg(3, counts); - scatter_kernel.set_arg(4, offsets); - scatter_kernel.set_arg(5, *output_buffer); + scatter_kernel.set_arg(1, input_offset); + scatter_kernel.set_arg(2, static_cast(count)); + scatter_kernel.set_arg(3, i * k); + scatter_kernel.set_arg(4, counts); + scatter_kernel.set_arg(5, offsets); + scatter_kernel.set_arg(6, *output_buffer); + scatter_kernel.set_arg(7, output_offset); if(sort_by_key){ - scatter_kernel.set_arg(6, *values_input_buffer); - scatter_kernel.set_arg(7, *values_output_buffer); + scatter_kernel.set_arg(8, *values_input_buffer); + scatter_kernel.set_arg(9, values_input_offset); + scatter_kernel.set_arg(10, *values_output_buffer); + scatter_kernel.set_arg(11, values_output_offset); } queue.enqueue_1d_range_kernel(scatter_kernel, 0, @@ -339,6 +369,8 @@ inline void radix_sort_impl(const buffer_iterator first, // swap buffers std::swap(input_buffer, output_buffer); std::swap(values_input_buffer, values_output_buffer); + std::swap(input_offset, output_offset); + std::swap(values_input_offset, values_output_offset); } } diff --git a/include/boost/compute/algorithm/stable_sort.hpp b/include/boost/compute/algorithm/stable_sort.hpp index a42213bb..a2d33a99 100644 --- a/include/boost/compute/algorithm/stable_sort.hpp +++ b/include/boost/compute/algorithm/stable_sort.hpp @@ -15,10 +15,51 @@ #include #include +#include #include +#include +#include namespace boost { namespace compute { +namespace detail { + +template +inline void dispatch_stable_sort(Iterator first, + Iterator last, + Compare compare, + command_queue &queue) +{ + ::boost::compute::detail::serial_insertion_sort( + first, last, compare, queue + ); +} + +template +inline typename boost::enable_if_c::value>::type +dispatch_stable_sort(buffer_iterator first, + buffer_iterator last, + less, + command_queue &queue) +{ + ::boost::compute::detail::radix_sort(first, last, queue); +} + +template +inline typename boost::enable_if_c::value>::type +dispatch_stable_sort(buffer_iterator first, + buffer_iterator last, + greater, + command_queue &queue) +{ + // radix sort in ascending order + ::boost::compute::detail::radix_sort(first, last, queue); + + // reverse range to descending order + ::boost::compute::reverse(first, last, queue); +} + +} // end detail namespace /// Sorts the values in the range [\p first, \p last) according to /// \p compare. The relative order of identical values is preserved. @@ -30,10 +71,9 @@ inline void stable_sort(Iterator first, Compare compare, command_queue &queue = system::default_queue()) { - return ::boost::compute::detail::serial_insertion_sort(first, - last, - compare, - queue); + ::boost::compute::detail::dispatch_stable_sort( + first, last, compare, queue + ); } /// \overload @@ -46,7 +86,7 @@ inline void stable_sort(Iterator first, ::boost::compute::less less; - return ::boost::compute::stable_sort(first, last, less, queue); + ::boost::compute::stable_sort(first, last, less, queue); } } // end compute namespace diff --git a/include/boost/compute/type_traits/is_vector_type.hpp b/include/boost/compute/type_traits/is_vector_type.hpp index 603b5d07..8e2ef1d7 100644 --- a/include/boost/compute/type_traits/is_vector_type.hpp +++ b/include/boost/compute/type_traits/is_vector_type.hpp @@ -11,6 +11,8 @@ #ifndef BOOST_COMPUTE_TYPE_TRAITS_IS_VECTOR_TYPE_HPP #define BOOST_COMPUTE_TYPE_TRAITS_IS_VECTOR_TYPE_HPP +#include + #include namespace boost { @@ -26,10 +28,8 @@ namespace compute { /// /// \see make_vector_type, vector_size template -struct is_vector_type +struct is_vector_type : boost::mpl::bool_::value != 1> { - /// \internal_ - BOOST_STATIC_CONSTANT(bool, value = (vector_size::value != 1)); }; } // end compute namespace diff --git a/test/test_radix_sort.cpp b/test/test_radix_sort.cpp index ba1d0ad4..f0bf7a8d 100644 --- a/test/test_radix_sort.cpp +++ b/test/test_radix_sort.cpp @@ -177,4 +177,13 @@ BOOST_AUTO_TEST_CASE(sort_double_vector) ); } +BOOST_AUTO_TEST_CASE(sort_partial_vector) +{ + int data[] = { 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 }; + boost::compute::vector vec(data, data + 10, queue); + + boost::compute::detail::radix_sort(vec.begin() + 2, vec.end() - 2, queue); + CHECK_RANGE_EQUAL(int, 10, vec, (9, 8, 2, 3, 4, 5, 6, 7, 1, 0)); +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/test/test_stable_sort.cpp b/test/test_stable_sort.cpp index ab6fc1c4..dde59d3c 100644 --- a/test/test_stable_sort.cpp +++ b/test/test_stable_sort.cpp @@ -12,6 +12,7 @@ #include #include +#include #include #include #include @@ -19,16 +20,73 @@ #include "check_macros.hpp" #include "context_setup.hpp" +namespace compute = boost::compute; + BOOST_AUTO_TEST_CASE(sort_int_vector) { int data[] = { -4, 152, -5000, 963, 75321, -456, 0, 1112 }; - boost::compute::vector vector(data, data + 8); + compute::vector vector(data, data + 8, queue); BOOST_CHECK_EQUAL(vector.size(), size_t(8)); - BOOST_CHECK(boost::compute::is_sorted(vector.begin(), vector.end()) == false); + BOOST_CHECK(compute::is_sorted(vector.begin(), vector.end(), queue) == false); - boost::compute::stable_sort(vector.begin(), vector.end()); - BOOST_CHECK(boost::compute::is_sorted(vector.begin(), vector.end()) == true); + compute::stable_sort(vector.begin(), vector.end(), queue); + BOOST_CHECK(compute::is_sorted(vector.begin(), vector.end(), queue) == true); CHECK_RANGE_EQUAL(int, 8, vector, (-5000, -456, -4, 0, 152, 963, 1112, 75321)); + + // sort reversed + compute::stable_sort(vector.begin(), vector.end(), compute::greater(), queue); + CHECK_RANGE_EQUAL(int, 8, vector, (75321, 1112, 963, 152, 0, -4, -456, -5000)); +} + +BOOST_AUTO_TEST_CASE(sort_int2) +{ + using compute::int2_; + + // device vector of int2's + compute::vector vec(context); + vec.push_back(int2_(2, 1), queue); + vec.push_back(int2_(2, 2), queue); + vec.push_back(int2_(1, 2), queue); + vec.push_back(int2_(1, 1), queue); + + // function comparing the first component of each int2 + BOOST_COMPUTE_FUNCTION(bool, compare_first, (int2_ a, int2_ b), + { + return a.x < b.x; + }); + + // ensure vector is not sorted + BOOST_CHECK(compute::is_sorted(vec.begin(), vec.end(), compare_first, queue) == false); + + // sort elements based on their first component + compute::stable_sort(vec.begin(), vec.end(), compare_first, queue); + + // ensure vector is now sorted + BOOST_CHECK(compute::is_sorted(vec.begin(), vec.end(), compare_first, queue) == true); + + // check sorted vector order + std::vector result(vec.size()); + compute::copy(vec.begin(), vec.end(), result.begin(), queue); + BOOST_CHECK_EQUAL(result[0], int2_(1, 2)); + BOOST_CHECK_EQUAL(result[1], int2_(1, 1)); + BOOST_CHECK_EQUAL(result[2], int2_(2, 1)); + BOOST_CHECK_EQUAL(result[3], int2_(2, 2)); + + // function comparing the second component of each int2 + BOOST_COMPUTE_FUNCTION(bool, compare_second, (int2_ a, int2_ b), + { + return a.y < b.y; + }); + + // sort elements based on their second component + compute::stable_sort(vec.begin(), vec.end(), compare_second, queue); + + // check sorted vector order + compute::copy(vec.begin(), vec.end(), result.begin(), queue); + BOOST_CHECK_EQUAL(result[0], int2_(1, 1)); + BOOST_CHECK_EQUAL(result[1], int2_(2, 1)); + BOOST_CHECK_EQUAL(result[2], int2_(1, 2)); + BOOST_CHECK_EQUAL(result[3], int2_(2, 2)); } BOOST_AUTO_TEST_SUITE_END()