From b848b44c40602a23cd283e4a2b9dd8e7505a77a6 Mon Sep 17 00:00:00 2001 From: Roshan Date: Tue, 2 Dec 2014 12:58:55 +0530 Subject: [PATCH 1/2] Fix bug in balanced path Fixed a bug with precedence when the comparison function is not enclosed in parentheses --- include/boost/compute/algorithm/detail/balanced_path.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/boost/compute/algorithm/detail/balanced_path.hpp b/include/boost/compute/algorithm/detail/balanced_path.hpp index 7441d8dc..0b86b4bf 100644 --- a/include/boost/compute/algorithm/detail/balanced_path.hpp +++ b/include/boost/compute/algorithm/detail/balanced_path.hpp @@ -69,8 +69,8 @@ public: "{\n" << " a_index = (start + end)/2;\n" << " b_index = target - a_index - 1;\n" << - " if(!" << comp(first2[expr("b_index")], - first1[expr("a_index")]) << ")\n" << + " if(!(" << comp(first2[expr("b_index")], + first1[expr("a_index")]) << "))\n" << " start = a_index + 1;\n" << " else end = a_index;\n" << "}\n" << From 12bbd995a3437a56c95338657b022fca4ec920e2 Mon Sep 17 00:00:00 2001 From: Roshan Date: Tue, 2 Dec 2014 12:59:36 +0530 Subject: [PATCH 2/2] Add support for comparator in merge --- .../compute/algorithm/detail/merge_path.hpp | 23 +++++++-- .../detail/merge_with_merge_path.hpp | 47 ++++++++++++++++--- include/boost/compute/algorithm/merge.hpp | 22 ++------- 3 files changed, 63 insertions(+), 29 deletions(-) diff --git a/include/boost/compute/algorithm/detail/merge_path.hpp b/include/boost/compute/algorithm/detail/merge_path.hpp index 91f2a8db..6747e89d 100644 --- a/include/boost/compute/algorithm/detail/merge_path.hpp +++ b/include/boost/compute/algorithm/detail/merge_path.hpp @@ -41,13 +41,15 @@ public: } template + class OutputIterator1, class OutputIterator2, + class Compare> void set_range(InputIterator1 first1, InputIterator1 last1, InputIterator2 first2, InputIterator2 last2, OutputIterator1 result_a, - OutputIterator2 result_b) + OutputIterator2 result_b, + Compare comp) { m_a_count = iterator_range_size(first1, last1); m_a_count_arg = add_arg("a_count"); @@ -65,14 +67,27 @@ public: "{\n" << " a_index = (start + end)/2;\n" << " b_index = target - a_index - 1;\n" << - " if(" << first1[expr("a_index")] << - " <=" << first2[expr("b_index")] << ")\n" << + " if(!(" << comp(first2[expr("b_index")], + first1[expr("a_index")]) << "))\n" << " start = a_index + 1;\n" << " else end = a_index;\n" << "}\n" << result_a[expr("i")] << " = start;\n" << result_b[expr("i")] << " = target - start;\n"; + } + template + void set_range(InputIterator1 first1, + InputIterator1 last1, + InputIterator2 first2, + InputIterator2 last2, + OutputIterator1 result_a, + OutputIterator2 result_b) + { + typedef typename std::iterator_traits::value_type value_type; + ::boost::compute::less less_than; + set_range(first1, last1, first2, last2, result_a, result_b, less_than); } event exec(command_queue &queue) diff --git a/include/boost/compute/algorithm/detail/merge_with_merge_path.hpp b/include/boost/compute/algorithm/detail/merge_with_merge_path.hpp index 73b6b6a1..b34aeb3a 100644 --- a/include/boost/compute/algorithm/detail/merge_with_merge_path.hpp +++ b/include/boost/compute/algorithm/detail/merge_with_merge_path.hpp @@ -41,13 +41,14 @@ public: template + class OutputIterator, class Compare> void set_range(InputIterator1 first1, InputIterator2 first2, InputIterator3 tile_first1, InputIterator3 tile_last1, InputIterator4 tile_first2, - OutputIterator result) + OutputIterator result, + Compare comp) { m_count = iterator_range_size(tile_first1, tile_last1) - 1; @@ -60,8 +61,8 @@ public: "uint index = i*" << tile_size << ";\n" << "while(start1("start1")] << " <= " << - first2[expr("start2")] << ")\n" << + " if(!(" << comp(first2[expr("start2")], + first1[expr("start1")]) << "))\n" << " {\n" << result[expr("index")] << " = " << first1[expr("start1")] << ";\n" << @@ -92,6 +93,21 @@ public: "}\n"; } + template + void set_range(InputIterator1 first1, + InputIterator2 first2, + InputIterator3 tile_first1, + InputIterator3 tile_last1, + InputIterator4 tile_first2, + OutputIterator result) + { + typedef typename std::iterator_traits::value_type value_type; + ::boost::compute::less less_than; + set_range(first1, first2, tile_first1, tile_last1, tile_first2, result, less_than); + } + event exec(command_queue &queue) { if(m_count == 0) { @@ -118,15 +134,17 @@ private: /// \param last2 Iterator pointing to end of second set /// \param result Iterator pointing to start of range in which the result /// will be stored +/// \param comp Comparator which performs less than function /// \param queue Queue on which to execute /// -template +template inline OutputIterator merge_with_merge_path(InputIterator1 first1, InputIterator1 last1, InputIterator2 first2, InputIterator2 last2, OutputIterator result, + Compare comp, command_queue &queue = system::default_queue()) { int tile_size = 1024; @@ -141,7 +159,7 @@ merge_with_merge_path(InputIterator1 first1, merge_path_kernel tiling_kernel; tiling_kernel.tile_size = 1024; tiling_kernel.set_range(first1, last1, first2, last2, - tile_a.begin()+1, tile_b.begin()+1); + tile_a.begin()+1, tile_b.begin()+1, comp); fill_n(tile_a.begin(), 1, 0, queue); fill_n(tile_b.begin(), 1, 0, queue); tiling_kernel.exec(queue); @@ -153,13 +171,28 @@ merge_with_merge_path(InputIterator1 first1, serial_merge_kernel merge_kernel; merge_kernel.tile_size = 1024; merge_kernel.set_range(first1, first2, tile_a.begin(), tile_a.end(), - tile_b.begin(), result); + tile_b.begin(), result, comp); merge_kernel.exec(queue); return result + count1 + count2; } +/// \overload +template +inline OutputIterator +merge_with_merge_path(InputIterator1 first1, + InputIterator1 last1, + InputIterator2 first2, + InputIterator2 last2, + OutputIterator result, + command_queue &queue = system::default_queue()) +{ + typedef typename std::iterator_traits::value_type value_type; + ::boost::compute::less less_than; + return merge_with_merge_path(first1, last1, first2, last2, result, less_than, queue); +} + } //end detail namespace } //end compute namespace } //end boost namespace diff --git a/include/boost/compute/algorithm/merge.hpp b/include/boost/compute/algorithm/merge.hpp index 5ed88fdf..25c7218b 100644 --- a/include/boost/compute/algorithm/merge.hpp +++ b/include/boost/compute/algorithm/merge.hpp @@ -48,23 +48,7 @@ inline OutputIterator merge(InputIterator1 first1, Compare comp, command_queue &queue = system::default_queue()) { - size_t size1 = detail::iterator_range_size(first1, last1); - size_t size2 = detail::iterator_range_size(first2, last2); - - // handle trivial cases - if(size1 == 0 && size2 == 0){ - return result; - } - else if(size1 == 0){ - return ::boost::compute::copy(first2, last2, result, queue); - } - else if(size2 == 0){ - return ::boost::compute::copy(first1, last1, result, queue); - } - - return detail::serial_merge( - first1, last1, first2, last2, result, comp, queue - ); + return detail::merge_with_merge_path(first1, last1, first2, last2, result, comp, queue); } /// \overload @@ -76,7 +60,9 @@ inline OutputIterator merge(InputIterator1 first1, OutputIterator result, command_queue &queue = system::default_queue()) { - return detail::merge_with_merge_path(first1, last1, first2, last2, result, queue); + typedef typename std::iterator_traits::value_type value_type; + less less_than; + return merge(first1, last1, first2, last2, result, less_than, queue); } } // end compute namespace