From faddbf236845708d5e176c9c039d1da0a582bf35 Mon Sep 17 00:00:00 2001 From: Kyle Lutz Date: Tue, 2 Dec 2014 21:49:51 -0800 Subject: [PATCH] Refactor dispatch_sort() function --- .../compute/algorithm/detail/fixed_sort.hpp | 126 ------------------ include/boost/compute/algorithm/sort.hpp | 106 ++++++++++----- 2 files changed, 75 insertions(+), 157 deletions(-) delete mode 100644 include/boost/compute/algorithm/detail/fixed_sort.hpp diff --git a/include/boost/compute/algorithm/detail/fixed_sort.hpp b/include/boost/compute/algorithm/detail/fixed_sort.hpp deleted file mode 100644 index 7f264719..00000000 --- a/include/boost/compute/algorithm/detail/fixed_sort.hpp +++ /dev/null @@ -1,126 +0,0 @@ -//---------------------------------------------------------------------------// -// Copyright (c) 2013 Kyle Lutz -// -// Distributed under the Boost Software License, Version 1.0 -// See accompanying file LICENSE_1_0.txt or copy at -// http://www.boost.org/LICENSE_1_0.txt -// -// See http://kylelutz.github.com/compute for more information. -//---------------------------------------------------------------------------// - -#ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_FIXED_SORT_HPP -#define BOOST_COMPUTE_ALGORITHM_DETAIL_FIXED_SORT_HPP - -#include -#include -#include -#include -#include - -namespace boost { -namespace compute { -namespace detail { - -// sort two values -template -inline void sort2(const buffer &buffer, command_queue &queue) -{ - const context &context = queue.get_context(); - - boost::shared_ptr cache = - detail::get_program_cache(context); - std::string cache_key = - std::string("fixed_sort2_") + type_name(); - - program sort2_program = cache->get(cache_key); - if(!sort2_program.get()){ - const char source[] = - "__kernel void sort2(__global T *input)\n" - "{\n" - " const T x = input[0];\n" - " const T y = input[1];\n" - " if(y < x){\n" - " input[0] = y;\n" - " input[1] = x;\n" - " }\n" - "}\n"; - - sort2_program = program::build_with_source( - source, context, std::string("-DT=") + type_name() - ); - - cache->insert(cache_key, sort2_program); - } - - kernel sort2_kernel = sort2_program.create_kernel("sort2"); - sort2_kernel.set_arg(0, buffer); - queue.enqueue_task(sort2_kernel); -} - -// sort three values -template -inline void sort3(const buffer &buffer, command_queue &queue) -{ - const context &context = queue.get_context(); - - boost::shared_ptr cache = - detail::get_program_cache(context); - std::string cache_key = - std::string("fixed_sort3_") + type_name(); - - program sort3_program = cache->get(cache_key); - if(!sort3_program.get()){ - const char source[] = - "__kernel void sort3(__global T *input)\n" - "{\n" - " const T x = input[0];\n" - " const T y = input[1];\n" - " const T z = input[2];\n" - " if(y < x){\n" - " if(z < x){\n" - " if(z < y){\n" - " input[0] = z;\n" - " input[1] = y;\n" - " input[2] = x;\n" - " }\n" - " else {\n" - " input[0] = y;\n" - " input[1] = z;\n" - " input[2] = x;\n" - " }\n" - " }\n" - " else {\n" - " input[0] = y;\n" - " input[1] = x;\n" - " }\n" - " }\n" - " else {\n" - " if(z < x){\n" - " input[0] = z;\n" - " input[1] = x;\n" - " input[2] = y;\n" - " }\n" - " else if(z < y){\n" - " input[1] = z;\n" - " input[2] = y;\n" - " }\n" - " }\n" - "}\n"; - - sort3_program = program::build_with_source( - source, context, std::string("-DT=") + type_name() - ); - - cache->insert(cache_key, sort3_program); - } - - kernel sort3_kernel = sort3_program.create_kernel("sort3"); - sort3_kernel.set_arg(0, buffer); - queue.enqueue_task(sort3_kernel); -} - -} // end detail namespace -} // end compute namespace -} // end boost namespace - -#endif // BOOST_COMPUTE_ALGORITHM_DETAIL_FIXED_SORT_HPP diff --git a/include/boost/compute/algorithm/sort.hpp b/include/boost/compute/algorithm/sort.hpp index 257ec983..9f2c3794 100644 --- a/include/boost/compute/algorithm/sort.hpp +++ b/include/boost/compute/algorithm/sort.hpp @@ -15,40 +15,34 @@ #include -#include #include #include -#include #include #include +#include #include #include +#include namespace boost { namespace compute { namespace detail { -// sort() for device iterators -template -inline void dispatch_sort(Iterator first, - Iterator last, - command_queue &queue, - typename boost::enable_if< - is_device_iterator - >::type* = 0) +template +inline void dispatch_device_sort(buffer_iterator first, + buffer_iterator last, + less, + command_queue &queue, + typename boost::enable_if_c< + is_radix_sortable::value + >::type* = 0) { - typedef typename std::iterator_traits::value_type T; - size_t count = detail::iterator_range_size(first, last); + if(count < 2){ + // nothing to do return; } - else if(count == 2){ - ::boost::compute::detail::sort2(first.get_buffer(), queue); - } - else if(count == 3){ - ::boost::compute::detail::sort3(first.get_buffer(), queue); - } else if(count <= 32){ ::boost::compute::detail::serial_insertion_sort(first, last, queue); } @@ -57,10 +51,64 @@ inline void dispatch_sort(Iterator first, } } -// sort() for host iterators -template +template +inline void dispatch_device_sort(buffer_iterator first, + buffer_iterator last, + greater compare, + command_queue &queue, + typename boost::enable_if_c< + is_radix_sortable::value + >::type* = 0) +{ + size_t count = detail::iterator_range_size(first, last); + + if(count < 2){ + // nothing to do + return; + } + else if(count <= 32){ + ::boost::compute::detail::serial_insertion_sort( + first, last, compare, queue + ); + } + else { + // radix sort in ascending order + ::boost::compute::detail::radix_sort(first, last, queue); + + // reverse range to descending order + ::boost::compute::reverse(first, last, queue); + } +} + +template +inline void dispatch_device_sort(Iterator first, + Iterator last, + Compare compare, + command_queue &queue) +{ + ::boost::compute::detail::serial_insertion_sort( + first, last, compare, queue + ); +} + +// sort() for device iterators +template inline void dispatch_sort(Iterator first, Iterator last, + Compare compare, + command_queue &queue, + typename boost::enable_if< + is_device_iterator + >::type* = 0) +{ + dispatch_device_sort(first, last, compare, queue); +} + +// sort() for host iterators +template +inline void dispatch_sort(Iterator first, + Iterator last, + Compare compare, command_queue &queue, typename boost::disable_if< is_device_iterator @@ -76,7 +124,7 @@ inline void dispatch_sort(Iterator first, ); // sort mapped buffer - dispatch_sort(view.begin(), view.end(), queue); + dispatch_device_sort(view.begin(), view.end(), compare, queue); // return results to host view.map(queue); @@ -118,15 +166,7 @@ inline void sort(Iterator first, Compare compare, command_queue &queue = system::default_queue()) { - size_t count = detail::iterator_range_size(first, last); - if(count < 2){ - return; - } - - return ::boost::compute::detail::serial_insertion_sort(first, - last, - compare, - queue); + ::boost::compute::detail::dispatch_sort(first, last, compare, queue); } /// \overload @@ -135,7 +175,11 @@ inline void sort(Iterator first, Iterator last, command_queue &queue = system::default_queue()) { - detail::dispatch_sort(first, last, queue); + typedef typename std::iterator_traits::value_type value_type; + + ::boost::compute::sort( + first, last, ::boost::compute::less(), queue + ); } } // end compute namespace