2
0
mirror of https://github.com/boostorg/compute.git synced 2026-02-20 02:32:15 +00:00

Merge pull request #305 from kylelutz/stable-sort-radix-sort

Use radix_sort() for stable_sort() when possible
This commit is contained in:
Kyle Lutz
2014-11-10 21:24:14 -08:00
5 changed files with 170 additions and 31 deletions

View File

@@ -23,6 +23,8 @@
#include <boost/compute/algorithm/exclusive_scan.hpp>
#include <boost/compute/container/vector.hpp>
#include <boost/compute/type_traits/type_name.hpp>
#include <boost/compute/type_traits/is_fundamental.hpp>
#include <boost/compute/type_traits/is_vector_type.hpp>
#include <boost/compute/detail/iterator_range_size.hpp>
#include <boost/compute/detail/program_cache.hpp>
@@ -30,6 +32,16 @@ namespace boost {
namespace compute {
namespace detail {
// meta-function returning true if type T is radix-sortable
template<class T>
struct is_radix_sortable :
boost::mpl::and_<
typename ::boost::compute::is_fundamental<T>::type,
typename boost::mpl::not_<typename is_vector_type<T>::type>::type
>
{
};
template<size_t N>
struct radix_sort_value_type
{
@@ -77,6 +89,7 @@ const char radix_sort_source[] =
"}\n"
"__kernel void count(__global const T *input,\n"
" const uint input_offset,\n"
" const uint input_size,\n"
" __global uint *global_counts,\n"
" __global uint *global_offsets,\n"
@@ -95,7 +108,7 @@ const char radix_sort_source[] =
// reduce local counts
" if(gid < input_size){\n"
" T value = input[gid];\n"
" T value = input[input_offset+gid];\n"
" uint bucket = radix(value, low_bit);\n"
" atomic_inc(local_counts + bucket);\n"
" }\n"
@@ -129,16 +142,21 @@ const char radix_sort_source[] =
"}\n"
"__kernel void scatter(__global const T *input,\n"
" const uint input_offset,\n"
" const uint input_size,\n"
" const uint low_bit,\n"
" __global const uint *counts,\n"
" __global const uint *global_offsets,\n"
"#ifndef SORT_BY_KEY\n"
" __global T *output)\n"
" __global T *output,\n"
" const uint output_offset)\n"
"#else\n"
" __global T *keys_output,\n"
" const uint keys_output_offset,\n"
" __global T2 *values_input,\n"
" __global T2 *values_output)\n"
" const uint values_input_offset,\n"
" __global T2 *values_output,\n"
" const uint values_output_offset)\n"
"#endif\n"
"{\n"
// work-item parameters
@@ -150,7 +168,7 @@ const char radix_sort_source[] =
" uint bucket;\n"
" __local uint local_input[BLOCK_SIZE];\n"
" if(gid < input_size){\n"
" value = input[gid];\n"
" value = input[input_offset+gid];\n"
" bucket = radix(value, low_bit);\n"
" local_input[lid] = bucket;\n"
" }\n"
@@ -180,11 +198,12 @@ const char radix_sort_source[] =
"#ifndef SORT_BY_KEY\n"
// write value to output
" output[offset + local_offset] = value;\n"
" output[output_offset + offset + local_offset] = value;\n"
"#else\n"
// write key and value if doing sort_by_key
" keys_output[offset + local_offset] = value;\n"
" values_output[offset + local_offset] = values_input[gid];\n"
" keys_output[keys_output_offset+offset + local_offset] = value;\n"
" values_output[values_output_offset+offset + local_offset] =\n"
" values_input[values_input_offset+gid];\n"
"#endif\n"
"}\n";
@@ -194,6 +213,7 @@ inline void radix_sort_impl(const buffer_iterator<T> first,
const buffer_iterator<T2> values_first,
command_queue &queue)
{
typedef T value_type;
typedef typename radix_sort_value_type<sizeof(T)>::type sort_type;
@@ -222,6 +242,7 @@ inline void radix_sort_impl(const buffer_iterator<T> first,
std::string cache_key =
std::string("radix_sort_") + type_name<value_type>();
if(sort_by_key){
cache_key += std::string("_with_") + type_name<T2>();
}
@@ -264,18 +285,23 @@ inline void radix_sort_impl(const buffer_iterator<T> first,
vector<uint_> counts(block_count * k2, context);
const buffer *input_buffer = &first.get_buffer();
uint_ input_offset = first.get_index();
const buffer *output_buffer = &output.get_buffer();
uint_ output_offset = 0;
const buffer *values_input_buffer = &values_first.get_buffer();
uint_ values_input_offset = values_first.get_index();
const buffer *values_output_buffer = &values_output.get_buffer();
uint_ values_output_offset = 0;
for(uint_ i = 0; i < sizeof(sort_type) * CHAR_BIT / k; i++){
// write counts
count_kernel.set_arg(0, *input_buffer);
count_kernel.set_arg(1, static_cast<uint_>(count));
count_kernel.set_arg(2, counts);
count_kernel.set_arg(3, offsets);
count_kernel.set_arg(4, block_size * sizeof(uint_), 0);
count_kernel.set_arg(5, i * k);
count_kernel.set_arg(1, input_offset);
count_kernel.set_arg(2, static_cast<uint_>(count));
count_kernel.set_arg(3, counts);
count_kernel.set_arg(4, offsets);
count_kernel.set_arg(5, block_size * sizeof(uint_), 0);
count_kernel.set_arg(6, i * k);
queue.enqueue_1d_range_kernel(count_kernel,
0,
block_count * block_size,
@@ -322,14 +348,18 @@ inline void radix_sort_impl(const buffer_iterator<T> first,
// scatter values
scatter_kernel.set_arg(0, *input_buffer);
scatter_kernel.set_arg(1, static_cast<uint_>(count));
scatter_kernel.set_arg(2, i * k);
scatter_kernel.set_arg(3, counts);
scatter_kernel.set_arg(4, offsets);
scatter_kernel.set_arg(5, *output_buffer);
scatter_kernel.set_arg(1, input_offset);
scatter_kernel.set_arg(2, static_cast<uint_>(count));
scatter_kernel.set_arg(3, i * k);
scatter_kernel.set_arg(4, counts);
scatter_kernel.set_arg(5, offsets);
scatter_kernel.set_arg(6, *output_buffer);
scatter_kernel.set_arg(7, output_offset);
if(sort_by_key){
scatter_kernel.set_arg(6, *values_input_buffer);
scatter_kernel.set_arg(7, *values_output_buffer);
scatter_kernel.set_arg(8, *values_input_buffer);
scatter_kernel.set_arg(9, values_input_offset);
scatter_kernel.set_arg(10, *values_output_buffer);
scatter_kernel.set_arg(11, values_output_offset);
}
queue.enqueue_1d_range_kernel(scatter_kernel,
0,
@@ -339,6 +369,8 @@ inline void radix_sort_impl(const buffer_iterator<T> first,
// swap buffers
std::swap(input_buffer, output_buffer);
std::swap(values_input_buffer, values_output_buffer);
std::swap(input_offset, output_offset);
std::swap(values_input_offset, values_output_offset);
}
}

View File

@@ -15,10 +15,51 @@
#include <boost/compute/system.hpp>
#include <boost/compute/command_queue.hpp>
#include <boost/compute/algorithm/detail/radix_sort.hpp>
#include <boost/compute/algorithm/detail/insertion_sort.hpp>
#include <boost/compute/algorithm/reverse.hpp>
#include <boost/compute/functional/operator.hpp>
namespace boost {
namespace compute {
namespace detail {
template<class Iterator, class Compare>
inline void dispatch_stable_sort(Iterator first,
Iterator last,
Compare compare,
command_queue &queue)
{
::boost::compute::detail::serial_insertion_sort(
first, last, compare, queue
);
}
template<class T>
inline typename boost::enable_if_c<is_radix_sortable<T>::value>::type
dispatch_stable_sort(buffer_iterator<T> first,
buffer_iterator<T> last,
less<T>,
command_queue &queue)
{
::boost::compute::detail::radix_sort(first, last, queue);
}
template<class T>
inline typename boost::enable_if_c<is_radix_sortable<T>::value>::type
dispatch_stable_sort(buffer_iterator<T> first,
buffer_iterator<T> last,
greater<T>,
command_queue &queue)
{
// radix sort in ascending order
::boost::compute::detail::radix_sort(first, last, queue);
// reverse range to descending order
::boost::compute::reverse(first, last, queue);
}
} // end detail namespace
/// Sorts the values in the range [\p first, \p last) according to
/// \p compare. The relative order of identical values is preserved.
@@ -30,10 +71,9 @@ inline void stable_sort(Iterator first,
Compare compare,
command_queue &queue = system::default_queue())
{
return ::boost::compute::detail::serial_insertion_sort(first,
last,
compare,
queue);
::boost::compute::detail::dispatch_stable_sort(
first, last, compare, queue
);
}
/// \overload
@@ -46,7 +86,7 @@ inline void stable_sort(Iterator first,
::boost::compute::less<value_type> less;
return ::boost::compute::stable_sort(first, last, less, queue);
::boost::compute::stable_sort(first, last, less, queue);
}
} // end compute namespace

View File

@@ -11,6 +11,8 @@
#ifndef BOOST_COMPUTE_TYPE_TRAITS_IS_VECTOR_TYPE_HPP
#define BOOST_COMPUTE_TYPE_TRAITS_IS_VECTOR_TYPE_HPP
#include <boost/mpl/bool.hpp>
#include <boost/compute/type_traits/vector_size.hpp>
namespace boost {
@@ -26,10 +28,8 @@ namespace compute {
///
/// \see make_vector_type, vector_size
template<class T>
struct is_vector_type
struct is_vector_type : boost::mpl::bool_<vector_size<T>::value != 1>
{
/// \internal_
BOOST_STATIC_CONSTANT(bool, value = (vector_size<T>::value != 1));
};
} // end compute namespace