diff --git a/boost/python/numpy/invoke_matching.hpp b/boost/python/numpy/invoke_matching.hpp index 546072f9..4d1a7a0f 100644 --- a/boost/python/numpy/invoke_matching.hpp +++ b/boost/python/numpy/invoke_matching.hpp @@ -9,6 +9,8 @@ #include #include +#include + namespace boost { namespace python { namespace numpy { namespace detail { @@ -29,9 +31,9 @@ template struct dtype_template_invoker { template - void operator()(T * x) const { + void operator()(T *) const { if (dtype::get_builtin() == m_dtype) { - m_func.apply(x); + m_func.template apply(); throw dtype_template_match_found(); } } @@ -48,9 +50,9 @@ template struct dtype_template_invoker< boost::reference_wrapper > { template - void operator()(T * x) const { + void operator()(T *) const { if (dtype::get_builtin() == m_dtype) { - m_func.apply(x); + m_func.template apply(); throw dtype_template_match_found(); } } @@ -66,10 +68,10 @@ private: template struct nd_template_invoker { - template - void operator()(T * x) const { - if (m_nd == T::value) { - m_func.apply(x); + template + void operator()(boost::mpl::integral_c *) const { + if (m_nd == N) { + m_func.template apply(); throw nd_template_match_found(); } } @@ -85,10 +87,10 @@ private: template struct nd_template_invoker< boost::reference_wrapper > { - template - void operator()(T * x) const { - if (m_nd == T::value) { - m_func.apply(x); + template + void operator()(boost::mpl::integral_c *) const { + if (m_nd == N) { + m_func.template apply(); throw nd_template_match_found(); } } @@ -129,31 +131,30 @@ void invoke_matching_dtype(dtype const & dtype_, Function f) { namespace detail { -template -struct array_template_invoker_wrapper { +template +struct array_template_invoker_wrapper_2 { - template - void apply(T * x) const { - invoke_matching_nd(m_nd, m_func.nest(x)); + template + void apply() const { + m_func.template apply(); } - array_template_invoker_wrapper(int nd, Function func) : - m_nd(nd), m_func(func) {} + array_template_invoker_wrapper_2(Function & func) : + m_func(func) {} private: - int m_nd; - Function m_func; + Function & m_func; }; template -struct array_template_invoker_wrapper< DimSequence, boost::reference_wrapper > { +struct array_template_invoker_wrapper_1 { template - void apply(T * x) const { - invoke_matching_nd(m_nd, m_func.nest(x)); + void apply() const { + invoke_matching_nd(m_nd, array_template_invoker_wrapper_2(m_func)); } - array_template_invoker_wrapper(int nd, Function & func) : + array_template_invoker_wrapper_1(int nd, Function & func) : m_nd(nd), m_func(func) {} private: @@ -161,11 +162,19 @@ private: Function & m_func; }; +template +struct array_template_invoker_wrapper_1< DimSequence, boost::reference_wrapper > + : public array_template_invoker_wrapper_1< DimSequence, Function > +{ + array_template_invoker_wrapper_1(int nd, Function & func) : + array_template_invoker_wrapper_1< DimSequence, Function >(nd, func) {} +}; + } // namespace boost::python::numpy::detail template void invoke_matching_array(ndarray const & array_, Function f) { - detail::array_template_invoker_wrapper wrapper(array_.get_nd(), f); + detail::array_template_invoker_wrapper_1 wrapper(array_.get_nd(), f); invoke_matching_dtype(array_.get_dtype(), wrapper); } diff --git a/libs/python/numpy/test/templates_mod.cpp b/libs/python/numpy/test/templates_mod.cpp index 82c5511e..6da9f581 100644 --- a/libs/python/numpy/test/templates_mod.cpp +++ b/libs/python/numpy/test/templates_mod.cpp @@ -9,19 +9,16 @@ struct ArrayFiller { typedef boost::mpl::vector< short, int, float, std::complex > TypeSequence; typedef boost::mpl::vector_c< int, 1, 2 > DimSequence; - template - struct nested { - - void apply(boost::mpl::integral_c * ) const { + template + void apply() const { + if (N == 1) { char * p = argument.get_data(); int stride = argument.strides(0); int size = argument.shape(0); for (int n = 0; n != size; ++n, p += stride) { *reinterpret_cast(p) = static_cast(n); } - } - - void apply(boost::mpl::integral_c * ) const { + } else { char * row_p = argument.get_data(); int row_stride = argument.strides(0); int col_stride = argument.strides(1); @@ -35,14 +32,7 @@ struct ArrayFiller { } } } - - explicit nested(bp::numpy::ndarray const & arg) : argument(arg) {} - - bp::numpy::ndarray argument; - }; - - template - nested nest(T *) const { return nested(argument); } + } bp::numpy::ndarray argument;