mirror of
https://github.com/boostorg/compute.git
synced 2026-02-20 02:32:15 +00:00
Merge pull request #305 from kylelutz/stable-sort-radix-sort
Use radix_sort() for stable_sort() when possible
This commit is contained in:
@@ -23,6 +23,8 @@
|
||||
#include <boost/compute/algorithm/exclusive_scan.hpp>
|
||||
#include <boost/compute/container/vector.hpp>
|
||||
#include <boost/compute/type_traits/type_name.hpp>
|
||||
#include <boost/compute/type_traits/is_fundamental.hpp>
|
||||
#include <boost/compute/type_traits/is_vector_type.hpp>
|
||||
#include <boost/compute/detail/iterator_range_size.hpp>
|
||||
#include <boost/compute/detail/program_cache.hpp>
|
||||
|
||||
@@ -30,6 +32,16 @@ namespace boost {
|
||||
namespace compute {
|
||||
namespace detail {
|
||||
|
||||
// meta-function returning true if type T is radix-sortable
|
||||
template<class T>
|
||||
struct is_radix_sortable :
|
||||
boost::mpl::and_<
|
||||
typename ::boost::compute::is_fundamental<T>::type,
|
||||
typename boost::mpl::not_<typename is_vector_type<T>::type>::type
|
||||
>
|
||||
{
|
||||
};
|
||||
|
||||
template<size_t N>
|
||||
struct radix_sort_value_type
|
||||
{
|
||||
@@ -77,6 +89,7 @@ const char radix_sort_source[] =
|
||||
"}\n"
|
||||
|
||||
"__kernel void count(__global const T *input,\n"
|
||||
" const uint input_offset,\n"
|
||||
" const uint input_size,\n"
|
||||
" __global uint *global_counts,\n"
|
||||
" __global uint *global_offsets,\n"
|
||||
@@ -95,7 +108,7 @@ const char radix_sort_source[] =
|
||||
|
||||
// reduce local counts
|
||||
" if(gid < input_size){\n"
|
||||
" T value = input[gid];\n"
|
||||
" T value = input[input_offset+gid];\n"
|
||||
" uint bucket = radix(value, low_bit);\n"
|
||||
" atomic_inc(local_counts + bucket);\n"
|
||||
" }\n"
|
||||
@@ -129,16 +142,21 @@ const char radix_sort_source[] =
|
||||
"}\n"
|
||||
|
||||
"__kernel void scatter(__global const T *input,\n"
|
||||
" const uint input_offset,\n"
|
||||
" const uint input_size,\n"
|
||||
" const uint low_bit,\n"
|
||||
" __global const uint *counts,\n"
|
||||
" __global const uint *global_offsets,\n"
|
||||
"#ifndef SORT_BY_KEY\n"
|
||||
" __global T *output)\n"
|
||||
" __global T *output,\n"
|
||||
" const uint output_offset)\n"
|
||||
"#else\n"
|
||||
" __global T *keys_output,\n"
|
||||
" const uint keys_output_offset,\n"
|
||||
" __global T2 *values_input,\n"
|
||||
" __global T2 *values_output)\n"
|
||||
" const uint values_input_offset,\n"
|
||||
" __global T2 *values_output,\n"
|
||||
" const uint values_output_offset)\n"
|
||||
"#endif\n"
|
||||
"{\n"
|
||||
// work-item parameters
|
||||
@@ -150,7 +168,7 @@ const char radix_sort_source[] =
|
||||
" uint bucket;\n"
|
||||
" __local uint local_input[BLOCK_SIZE];\n"
|
||||
" if(gid < input_size){\n"
|
||||
" value = input[gid];\n"
|
||||
" value = input[input_offset+gid];\n"
|
||||
" bucket = radix(value, low_bit);\n"
|
||||
" local_input[lid] = bucket;\n"
|
||||
" }\n"
|
||||
@@ -180,11 +198,12 @@ const char radix_sort_source[] =
|
||||
|
||||
"#ifndef SORT_BY_KEY\n"
|
||||
// write value to output
|
||||
" output[offset + local_offset] = value;\n"
|
||||
" output[output_offset + offset + local_offset] = value;\n"
|
||||
"#else\n"
|
||||
// write key and value if doing sort_by_key
|
||||
" keys_output[offset + local_offset] = value;\n"
|
||||
" values_output[offset + local_offset] = values_input[gid];\n"
|
||||
" keys_output[keys_output_offset+offset + local_offset] = value;\n"
|
||||
" values_output[values_output_offset+offset + local_offset] =\n"
|
||||
" values_input[values_input_offset+gid];\n"
|
||||
"#endif\n"
|
||||
"}\n";
|
||||
|
||||
@@ -194,6 +213,7 @@ inline void radix_sort_impl(const buffer_iterator<T> first,
|
||||
const buffer_iterator<T2> values_first,
|
||||
command_queue &queue)
|
||||
{
|
||||
|
||||
typedef T value_type;
|
||||
typedef typename radix_sort_value_type<sizeof(T)>::type sort_type;
|
||||
|
||||
@@ -222,6 +242,7 @@ inline void radix_sort_impl(const buffer_iterator<T> first,
|
||||
std::string cache_key =
|
||||
std::string("radix_sort_") + type_name<value_type>();
|
||||
|
||||
|
||||
if(sort_by_key){
|
||||
cache_key += std::string("_with_") + type_name<T2>();
|
||||
}
|
||||
@@ -264,18 +285,23 @@ inline void radix_sort_impl(const buffer_iterator<T> first,
|
||||
vector<uint_> counts(block_count * k2, context);
|
||||
|
||||
const buffer *input_buffer = &first.get_buffer();
|
||||
uint_ input_offset = first.get_index();
|
||||
const buffer *output_buffer = &output.get_buffer();
|
||||
uint_ output_offset = 0;
|
||||
const buffer *values_input_buffer = &values_first.get_buffer();
|
||||
uint_ values_input_offset = values_first.get_index();
|
||||
const buffer *values_output_buffer = &values_output.get_buffer();
|
||||
uint_ values_output_offset = 0;
|
||||
|
||||
for(uint_ i = 0; i < sizeof(sort_type) * CHAR_BIT / k; i++){
|
||||
// write counts
|
||||
count_kernel.set_arg(0, *input_buffer);
|
||||
count_kernel.set_arg(1, static_cast<uint_>(count));
|
||||
count_kernel.set_arg(2, counts);
|
||||
count_kernel.set_arg(3, offsets);
|
||||
count_kernel.set_arg(4, block_size * sizeof(uint_), 0);
|
||||
count_kernel.set_arg(5, i * k);
|
||||
count_kernel.set_arg(1, input_offset);
|
||||
count_kernel.set_arg(2, static_cast<uint_>(count));
|
||||
count_kernel.set_arg(3, counts);
|
||||
count_kernel.set_arg(4, offsets);
|
||||
count_kernel.set_arg(5, block_size * sizeof(uint_), 0);
|
||||
count_kernel.set_arg(6, i * k);
|
||||
queue.enqueue_1d_range_kernel(count_kernel,
|
||||
0,
|
||||
block_count * block_size,
|
||||
@@ -322,14 +348,18 @@ inline void radix_sort_impl(const buffer_iterator<T> first,
|
||||
|
||||
// scatter values
|
||||
scatter_kernel.set_arg(0, *input_buffer);
|
||||
scatter_kernel.set_arg(1, static_cast<uint_>(count));
|
||||
scatter_kernel.set_arg(2, i * k);
|
||||
scatter_kernel.set_arg(3, counts);
|
||||
scatter_kernel.set_arg(4, offsets);
|
||||
scatter_kernel.set_arg(5, *output_buffer);
|
||||
scatter_kernel.set_arg(1, input_offset);
|
||||
scatter_kernel.set_arg(2, static_cast<uint_>(count));
|
||||
scatter_kernel.set_arg(3, i * k);
|
||||
scatter_kernel.set_arg(4, counts);
|
||||
scatter_kernel.set_arg(5, offsets);
|
||||
scatter_kernel.set_arg(6, *output_buffer);
|
||||
scatter_kernel.set_arg(7, output_offset);
|
||||
if(sort_by_key){
|
||||
scatter_kernel.set_arg(6, *values_input_buffer);
|
||||
scatter_kernel.set_arg(7, *values_output_buffer);
|
||||
scatter_kernel.set_arg(8, *values_input_buffer);
|
||||
scatter_kernel.set_arg(9, values_input_offset);
|
||||
scatter_kernel.set_arg(10, *values_output_buffer);
|
||||
scatter_kernel.set_arg(11, values_output_offset);
|
||||
}
|
||||
queue.enqueue_1d_range_kernel(scatter_kernel,
|
||||
0,
|
||||
@@ -339,6 +369,8 @@ inline void radix_sort_impl(const buffer_iterator<T> first,
|
||||
// swap buffers
|
||||
std::swap(input_buffer, output_buffer);
|
||||
std::swap(values_input_buffer, values_output_buffer);
|
||||
std::swap(input_offset, output_offset);
|
||||
std::swap(values_input_offset, values_output_offset);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -15,10 +15,51 @@
|
||||
|
||||
#include <boost/compute/system.hpp>
|
||||
#include <boost/compute/command_queue.hpp>
|
||||
#include <boost/compute/algorithm/detail/radix_sort.hpp>
|
||||
#include <boost/compute/algorithm/detail/insertion_sort.hpp>
|
||||
#include <boost/compute/algorithm/reverse.hpp>
|
||||
#include <boost/compute/functional/operator.hpp>
|
||||
|
||||
namespace boost {
|
||||
namespace compute {
|
||||
namespace detail {
|
||||
|
||||
template<class Iterator, class Compare>
|
||||
inline void dispatch_stable_sort(Iterator first,
|
||||
Iterator last,
|
||||
Compare compare,
|
||||
command_queue &queue)
|
||||
{
|
||||
::boost::compute::detail::serial_insertion_sort(
|
||||
first, last, compare, queue
|
||||
);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
inline typename boost::enable_if_c<is_radix_sortable<T>::value>::type
|
||||
dispatch_stable_sort(buffer_iterator<T> first,
|
||||
buffer_iterator<T> last,
|
||||
less<T>,
|
||||
command_queue &queue)
|
||||
{
|
||||
::boost::compute::detail::radix_sort(first, last, queue);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
inline typename boost::enable_if_c<is_radix_sortable<T>::value>::type
|
||||
dispatch_stable_sort(buffer_iterator<T> first,
|
||||
buffer_iterator<T> last,
|
||||
greater<T>,
|
||||
command_queue &queue)
|
||||
{
|
||||
// radix sort in ascending order
|
||||
::boost::compute::detail::radix_sort(first, last, queue);
|
||||
|
||||
// reverse range to descending order
|
||||
::boost::compute::reverse(first, last, queue);
|
||||
}
|
||||
|
||||
} // end detail namespace
|
||||
|
||||
/// Sorts the values in the range [\p first, \p last) according to
|
||||
/// \p compare. The relative order of identical values is preserved.
|
||||
@@ -30,10 +71,9 @@ inline void stable_sort(Iterator first,
|
||||
Compare compare,
|
||||
command_queue &queue = system::default_queue())
|
||||
{
|
||||
return ::boost::compute::detail::serial_insertion_sort(first,
|
||||
last,
|
||||
compare,
|
||||
queue);
|
||||
::boost::compute::detail::dispatch_stable_sort(
|
||||
first, last, compare, queue
|
||||
);
|
||||
}
|
||||
|
||||
/// \overload
|
||||
@@ -46,7 +86,7 @@ inline void stable_sort(Iterator first,
|
||||
|
||||
::boost::compute::less<value_type> less;
|
||||
|
||||
return ::boost::compute::stable_sort(first, last, less, queue);
|
||||
::boost::compute::stable_sort(first, last, less, queue);
|
||||
}
|
||||
|
||||
} // end compute namespace
|
||||
|
||||
@@ -11,6 +11,8 @@
|
||||
#ifndef BOOST_COMPUTE_TYPE_TRAITS_IS_VECTOR_TYPE_HPP
|
||||
#define BOOST_COMPUTE_TYPE_TRAITS_IS_VECTOR_TYPE_HPP
|
||||
|
||||
#include <boost/mpl/bool.hpp>
|
||||
|
||||
#include <boost/compute/type_traits/vector_size.hpp>
|
||||
|
||||
namespace boost {
|
||||
@@ -26,10 +28,8 @@ namespace compute {
|
||||
///
|
||||
/// \see make_vector_type, vector_size
|
||||
template<class T>
|
||||
struct is_vector_type
|
||||
struct is_vector_type : boost::mpl::bool_<vector_size<T>::value != 1>
|
||||
{
|
||||
/// \internal_
|
||||
BOOST_STATIC_CONSTANT(bool, value = (vector_size<T>::value != 1));
|
||||
};
|
||||
|
||||
} // end compute namespace
|
||||
|
||||
Reference in New Issue
Block a user