2
0
mirror of https://github.com/boostorg/compute.git synced 2026-01-27 06:42:19 +00:00

Merge pull request #322 from roshanr95/merge

Merge
This commit is contained in:
Kyle Lutz
2014-12-02 19:27:49 -08:00
4 changed files with 65 additions and 31 deletions

View File

@@ -69,8 +69,8 @@ public:
"{\n" <<
" a_index = (start + end)/2;\n" <<
" b_index = target - a_index - 1;\n" <<
" if(!" << comp(first2[expr<uint_>("b_index")],
first1[expr<uint_>("a_index")]) << ")\n" <<
" if(!(" << comp(first2[expr<uint_>("b_index")],
first1[expr<uint_>("a_index")]) << "))\n" <<
" start = a_index + 1;\n" <<
" else end = a_index;\n" <<
"}\n" <<

View File

@@ -41,13 +41,15 @@ public:
}
template<class InputIterator1, class InputIterator2,
class OutputIterator1, class OutputIterator2>
class OutputIterator1, class OutputIterator2,
class Compare>
void set_range(InputIterator1 first1,
InputIterator1 last1,
InputIterator2 first2,
InputIterator2 last2,
OutputIterator1 result_a,
OutputIterator2 result_b)
OutputIterator2 result_b,
Compare comp)
{
m_a_count = iterator_range_size(first1, last1);
m_a_count_arg = add_arg<uint_>("a_count");
@@ -65,14 +67,27 @@ public:
"{\n" <<
" a_index = (start + end)/2;\n" <<
" b_index = target - a_index - 1;\n" <<
" if(" << first1[expr<uint_>("a_index")] <<
" <=" << first2[expr<uint_>("b_index")] << ")\n" <<
" if(!(" << comp(first2[expr<uint_>("b_index")],
first1[expr<uint_>("a_index")]) << "))\n" <<
" start = a_index + 1;\n" <<
" else end = a_index;\n" <<
"}\n" <<
result_a[expr<uint_>("i")] << " = start;\n" <<
result_b[expr<uint_>("i")] << " = target - start;\n";
}
template<class InputIterator1, class InputIterator2,
class OutputIterator1, class OutputIterator2>
void set_range(InputIterator1 first1,
InputIterator1 last1,
InputIterator2 first2,
InputIterator2 last2,
OutputIterator1 result_a,
OutputIterator2 result_b)
{
typedef typename std::iterator_traits<InputIterator1>::value_type value_type;
::boost::compute::less<value_type> less_than;
set_range(first1, last1, first2, last2, result_a, result_b, less_than);
}
event exec(command_queue &queue)

View File

@@ -41,13 +41,14 @@ public:
template<class InputIterator1, class InputIterator2,
class InputIterator3, class InputIterator4,
class OutputIterator>
class OutputIterator, class Compare>
void set_range(InputIterator1 first1,
InputIterator2 first2,
InputIterator3 tile_first1,
InputIterator3 tile_last1,
InputIterator4 tile_first2,
OutputIterator result)
OutputIterator result,
Compare comp)
{
m_count = iterator_range_size(tile_first1, tile_last1) - 1;
@@ -60,8 +61,8 @@ public:
"uint index = i*" << tile_size << ";\n" <<
"while(start1<end1 && start2<end2)\n" <<
"{\n" <<
" if(" << first1[expr<uint_>("start1")] << " <= " <<
first2[expr<uint_>("start2")] << ")\n" <<
" if(!(" << comp(first2[expr<uint_>("start2")],
first1[expr<uint_>("start1")]) << "))\n" <<
" {\n" <<
result[expr<uint_>("index")] <<
" = " << first1[expr<uint_>("start1")] << ";\n" <<
@@ -92,6 +93,21 @@ public:
"}\n";
}
template<class InputIterator1, class InputIterator2,
class InputIterator3, class InputIterator4,
class OutputIterator>
void set_range(InputIterator1 first1,
InputIterator2 first2,
InputIterator3 tile_first1,
InputIterator3 tile_last1,
InputIterator4 tile_first2,
OutputIterator result)
{
typedef typename std::iterator_traits<InputIterator1>::value_type value_type;
::boost::compute::less<value_type> less_than;
set_range(first1, first2, tile_first1, tile_last1, tile_first2, result, less_than);
}
event exec(command_queue &queue)
{
if(m_count == 0) {
@@ -118,15 +134,17 @@ private:
/// \param last2 Iterator pointing to end of second set
/// \param result Iterator pointing to start of range in which the result
/// will be stored
/// \param comp Comparator which performs less than function
/// \param queue Queue on which to execute
///
template<class InputIterator1, class InputIterator2, class OutputIterator>
template<class InputIterator1, class InputIterator2, class OutputIterator, class Compare>
inline OutputIterator
merge_with_merge_path(InputIterator1 first1,
InputIterator1 last1,
InputIterator2 first2,
InputIterator2 last2,
OutputIterator result,
Compare comp,
command_queue &queue = system::default_queue())
{
int tile_size = 1024;
@@ -141,7 +159,7 @@ merge_with_merge_path(InputIterator1 first1,
merge_path_kernel tiling_kernel;
tiling_kernel.tile_size = 1024;
tiling_kernel.set_range(first1, last1, first2, last2,
tile_a.begin()+1, tile_b.begin()+1);
tile_a.begin()+1, tile_b.begin()+1, comp);
fill_n(tile_a.begin(), 1, 0, queue);
fill_n(tile_b.begin(), 1, 0, queue);
tiling_kernel.exec(queue);
@@ -153,13 +171,28 @@ merge_with_merge_path(InputIterator1 first1,
serial_merge_kernel merge_kernel;
merge_kernel.tile_size = 1024;
merge_kernel.set_range(first1, first2, tile_a.begin(), tile_a.end(),
tile_b.begin(), result);
tile_b.begin(), result, comp);
merge_kernel.exec(queue);
return result + count1 + count2;
}
/// \overload
template<class InputIterator1, class InputIterator2, class OutputIterator>
inline OutputIterator
merge_with_merge_path(InputIterator1 first1,
InputIterator1 last1,
InputIterator2 first2,
InputIterator2 last2,
OutputIterator result,
command_queue &queue = system::default_queue())
{
typedef typename std::iterator_traits<InputIterator1>::value_type value_type;
::boost::compute::less<value_type> less_than;
return merge_with_merge_path(first1, last1, first2, last2, result, less_than, queue);
}
} //end detail namespace
} //end compute namespace
} //end boost namespace

View File

@@ -48,23 +48,7 @@ inline OutputIterator merge(InputIterator1 first1,
Compare comp,
command_queue &queue = system::default_queue())
{
size_t size1 = detail::iterator_range_size(first1, last1);
size_t size2 = detail::iterator_range_size(first2, last2);
// handle trivial cases
if(size1 == 0 && size2 == 0){
return result;
}
else if(size1 == 0){
return ::boost::compute::copy(first2, last2, result, queue);
}
else if(size2 == 0){
return ::boost::compute::copy(first1, last1, result, queue);
}
return detail::serial_merge(
first1, last1, first2, last2, result, comp, queue
);
return detail::merge_with_merge_path(first1, last1, first2, last2, result, comp, queue);
}
/// \overload
@@ -76,7 +60,9 @@ inline OutputIterator merge(InputIterator1 first1,
OutputIterator result,
command_queue &queue = system::default_queue())
{
return detail::merge_with_merge_path(first1, last1, first2, last2, result, queue);
typedef typename std::iterator_traits<InputIterator1>::value_type value_type;
less<value_type> less_than;
return merge(first1, last1, first2, last2, result, less_than, queue);
}
} // end compute namespace