diff --git a/include/boost/compute/algorithm/gather.hpp b/include/boost/compute/algorithm/gather.hpp index 6fc04241..04017883 100644 --- a/include/boost/compute/algorithm/gather.hpp +++ b/include/boost/compute/algorithm/gather.hpp @@ -48,15 +48,15 @@ struct gather_kernel } // end detail namespace -/// Copies the elements from the range [\p first, \p last) to the range -/// beginning at \p result using the input indices from the range beginning -/// at \p map. +/// Copies the elements using the indices from the range [\p first, \p last) +/// to the range beginning at \p result using the input values from the range +/// beginning at \p input. /// /// \see scatter() template -inline void gather(InputIterator first, - InputIterator last, - MapIterator map, +inline void gather(MapIterator first, + MapIterator last, + InputIterator input, OutputIterator result, command_queue &queue = system::default_queue()) { @@ -71,8 +71,8 @@ inline void gather(InputIterator first, output_value_type>::source(); kernel kernel = kernel::create_with_source(source, "gather", context); - kernel.set_arg(0, first.get_buffer()); - kernel.set_arg(1, map.get_buffer()); + kernel.set_arg(0, input.get_buffer()); + kernel.set_arg(1, first.get_buffer()); kernel.set_arg(2, result.get_buffer()); size_t offset = first.get_index(); diff --git a/test/test_gather.cpp b/test/test_gather.cpp index ec14cf6c..6ec8c986 100644 --- a/test/test_gather.cpp +++ b/test/test_gather.cpp @@ -12,25 +12,58 @@ #include #include +#include #include #include #include "check_macros.hpp" #include "context_setup.hpp" -namespace bc = boost::compute; +namespace compute = boost::compute; BOOST_AUTO_TEST_CASE(gather_int) { int input_data[] = { 1, 2, 3, 4, 5 }; - bc::vector input(input_data, input_data + 5, context); + compute::vector input(5, context); + compute::copy_n(input_data, 5, input.begin(), queue); - int map_data[] = { 0, 4, 1, 3, 2 }; - bc::vector map(map_data, map_data + 5, context); + int indices_data[] = { 0, 4, 1, 3, 2 }; + compute::vector indices(5, context); + compute::copy_n(indices_data, 5, indices.begin(), queue); - bc::vector output(5, context); - bc::gather(input.begin(), input.end(), map.begin(), output.begin()); + compute::vector output(5, context); + compute::gather( + indices.begin(), indices.end(), input.begin(), output.begin(), queue + ); CHECK_RANGE_EQUAL(int, 5, output, (1, 5, 2, 4, 3)); } +BOOST_AUTO_TEST_CASE(copy_index_then_gather) +{ + // input data + int data[] = { 1, 4, 3, 2, 5, 9, 8, 7 }; + compute::vector input(8, context); + compute::copy_n(data, 8, input.begin(), queue); + + // function returning true if the input is odd + BOOST_COMPUTE_FUNCTION(bool, is_odd, (int), + { + return _1 % 2 != 0; + }); + + // copy indices of all odd values + compute::vector odds(5, context); + compute::detail::copy_index_if( + input.begin(), input.end(), odds.begin(), is_odd, queue + ); + CHECK_RANGE_EQUAL(int, 5, odds, (0, 2, 4, 5, 7)); + + // gather all odd values + compute::vector odd_values(5, context); + compute::gather( + odds.begin(), odds.end(), input.begin(), odd_values.begin(), queue + ); + CHECK_RANGE_EQUAL(int, 5, odd_values, (1, 3, 5, 9, 7)); +} + BOOST_AUTO_TEST_SUITE_END()