From ba1416fff0a20b21e09e28f8fcd399ea0d7bd1e6 Mon Sep 17 00:00:00 2001 From: Jim Bosch Date: Wed, 6 Oct 2010 19:05:20 +0000 Subject: [PATCH] boost.python.numpy - moved dtype::invoke_matching_template into separate header, added similar code for invocation based on dimensionality --- boost/python/numpy/dtype.hpp | 68 --------- boost/python/numpy/invoke_matching.hpp | 175 +++++++++++++++++++++++ boost/python/numpy/ndarray.hpp | 1 - boost/python/numpy/numpy.hpp | 1 + libs/python/numpy/test/templates.py | 12 +- libs/python/numpy/test/templates_mod.cpp | 45 ++++-- 6 files changed, 220 insertions(+), 82 deletions(-) create mode 100644 boost/python/numpy/invoke_matching.hpp diff --git a/boost/python/numpy/dtype.hpp b/boost/python/numpy/dtype.hpp index edbde0e0..35f39bba 100644 --- a/boost/python/numpy/dtype.hpp +++ b/boost/python/numpy/dtype.hpp @@ -46,76 +46,8 @@ public: BOOST_PYTHON_FORWARD_OBJECT_CONSTRUCTORS(dtype, object); - template - void invoke_matching_template(Function f) const; - }; -namespace detail { - -struct add_pointer_meta { - - template - struct apply { - typedef typename boost::add_pointer::type type; - }; - -}; - -struct dtype_template_match_found {}; - -template -struct dtype_template_invoker { - - template - void operator()(T * x) const { - if (dtype::get_builtin() == m_dtype) { - m_func.template apply(); - throw dtype_template_match_found(); - } - } - - dtype_template_invoker(dtype const & dtype_, Function func) : - m_dtype(dtype_), m_func(func) {} - -private: - dtype const & m_dtype; - Function m_func; -}; - -template -struct dtype_template_invoker< boost::reference_wrapper > { - - template - void operator()(T * x) const { - if (dtype::get_builtin() == m_dtype) { - m_func.template apply(); - throw dtype_template_match_found(); - } - } - - dtype_template_invoker(dtype const & dtype_, Function & func) : - m_dtype(dtype_), m_func(func) {} - -private: - dtype const & m_dtype; - Function & m_func; -}; - -} // namespace boost::python::numpy::detail - -template -void dtype::invoke_matching_template(Function f) const { - detail::dtype_template_invoker invoker(*this, f); - try { - boost::mpl::for_each< Sequence, detail::add_pointer_meta >(invoker); - } catch (detail::dtype_template_match_found &) { - return; - } - PyErr_SetString(PyExc_TypeError, "numpy.dtype not found in template list."); - throw_error_already_set(); -} - } // namespace boost::python::numpy namespace converter { diff --git a/boost/python/numpy/invoke_matching.hpp b/boost/python/numpy/invoke_matching.hpp new file mode 100644 index 00000000..546072f9 --- /dev/null +++ b/boost/python/numpy/invoke_matching.hpp @@ -0,0 +1,175 @@ +#ifndef BOOST_PYTHON_NUMPY_INVOKE_MATCHING_HPP_INCLUDED +#define BOOST_PYTHON_NUMPY_INVOKE_MATCHING_HPP_INCLUDED + +/** + * @file boost/python/numpy/ndarray.hpp + * @brief Object manager and various utilities for numpy.ndarray. + */ + +#include +#include + +namespace boost { namespace python { namespace numpy { + +namespace detail { + +struct add_pointer_meta { + + template + struct apply { + typedef typename boost::add_pointer::type type; + }; + +}; + +struct dtype_template_match_found {}; +struct nd_template_match_found {}; + +template +struct dtype_template_invoker { + + template + void operator()(T * x) const { + if (dtype::get_builtin() == m_dtype) { + m_func.apply(x); + throw dtype_template_match_found(); + } + } + + dtype_template_invoker(dtype const & dtype_, Function func) : + m_dtype(dtype_), m_func(func) {} + +private: + dtype const & m_dtype; + Function m_func; +}; + +template +struct dtype_template_invoker< boost::reference_wrapper > { + + template + void operator()(T * x) const { + if (dtype::get_builtin() == m_dtype) { + m_func.apply(x); + throw dtype_template_match_found(); + } + } + + dtype_template_invoker(dtype const & dtype_, Function & func) : + m_dtype(dtype_), m_func(func) {} + +private: + dtype const & m_dtype; + Function & m_func; +}; + +template +struct nd_template_invoker { + + template + void operator()(T * x) const { + if (m_nd == T::value) { + m_func.apply(x); + throw nd_template_match_found(); + } + } + + nd_template_invoker(int nd, Function func) : + m_nd(nd), m_func(func) {} + +private: + int m_nd; + Function m_func; +}; + +template +struct nd_template_invoker< boost::reference_wrapper > { + + template + void operator()(T * x) const { + if (m_nd == T::value) { + m_func.apply(x); + throw nd_template_match_found(); + } + } + + nd_template_invoker(int nd, Function & func) : + m_nd(nd), m_func(func) {} + +private: + int m_nd; + Function & m_func; +}; + +} // namespace boost::python::numpy::detail + +template +void invoke_matching_nd(int nd, Function f) { + detail::nd_template_invoker invoker(nd, f); + try { + boost::mpl::for_each< Sequence, detail::add_pointer_meta >(invoker); + } catch (detail::nd_template_match_found &) { + return; + } + PyErr_SetString(PyExc_TypeError, "number of dimensions not found in template list."); + throw_error_already_set(); +} + +template +void invoke_matching_dtype(dtype const & dtype_, Function f) { + detail::dtype_template_invoker invoker(dtype_, f); + try { + boost::mpl::for_each< Sequence, detail::add_pointer_meta >(invoker); + } catch (detail::dtype_template_match_found &) { + return; + } + PyErr_SetString(PyExc_TypeError, "dtype not found in template list."); + throw_error_already_set(); +} + +namespace detail { + +template +struct array_template_invoker_wrapper { + + template + void apply(T * x) const { + invoke_matching_nd(m_nd, m_func.nest(x)); + } + + array_template_invoker_wrapper(int nd, Function func) : + m_nd(nd), m_func(func) {} + +private: + int m_nd; + Function m_func; +}; + +template +struct array_template_invoker_wrapper< DimSequence, boost::reference_wrapper > { + + template + void apply(T * x) const { + invoke_matching_nd(m_nd, m_func.nest(x)); + } + + array_template_invoker_wrapper(int nd, Function & func) : + m_nd(nd), m_func(func) {} + +private: + int m_nd; + Function & m_func; +}; + +} // namespace boost::python::numpy::detail + +template +void invoke_matching_array(ndarray const & array_, Function f) { + detail::array_template_invoker_wrapper wrapper(array_.get_nd(), f); + invoke_matching_dtype(array_.get_dtype(), wrapper); +} + + +}}} // namespace boost::python::numpy + +#endif // !BOOST_PYTHON_NUMPY_INVOKE_MATCHING_HPP_INCLUDED diff --git a/boost/python/numpy/ndarray.hpp b/boost/python/numpy/ndarray.hpp index 6ca0f499..0b53cc1f 100644 --- a/boost/python/numpy/ndarray.hpp +++ b/boost/python/numpy/ndarray.hpp @@ -275,7 +275,6 @@ inline ndarray::bitflag operator&(ndarray::bitflag a, ndarray::bitflag b) { return ndarray::bitflag(int(a) & int(b)); } - } // namespace boost::python::numpy namespace converter { diff --git a/boost/python/numpy/numpy.hpp b/boost/python/numpy/numpy.hpp index 9e0f6a44..c40859d8 100644 --- a/boost/python/numpy/numpy.hpp +++ b/boost/python/numpy/numpy.hpp @@ -11,6 +11,7 @@ #include #include #include +#include namespace boost { namespace python { namespace numpy { diff --git a/libs/python/numpy/test/templates.py b/libs/python/numpy/test/templates.py index c2967c1c..e848141f 100755 --- a/libs/python/numpy/test/templates.py +++ b/libs/python/numpy/test/templates.py @@ -6,12 +6,16 @@ class TestTemplates(unittest.TestCase): def testTemplates(self): for dtype in (numpy.int16, numpy.int32, numpy.float32, numpy.complex128): - a1 = numpy.zeros((12,), dtype=dtype) - a2 = numpy.arange(12, dtype=dtype) - templates_mod.fill(a1) - self.assert_((a1 == a2).all()) + v = numpy.arange(12, dtype=dtype) + for shape in ((12,), (4, 3), (2, 6)): + a1 = numpy.zeros(shape, dtype=dtype) + a2 = v.reshape(a1.shape) + templates_mod.fill(a1) + self.assert_((a1 == a2).all()) a1 = numpy.zeros((12,), dtype=numpy.float64) self.assertRaises(TypeError, templates_mod.fill, a1) + a1 = numpy.zeros((12,2,3), dtype=numpy.float32) + self.assertRaises(TypeError, templates_mod.fill, a1) if __name__=="__main__": unittest.main() diff --git a/libs/python/numpy/test/templates_mod.cpp b/libs/python/numpy/test/templates_mod.cpp index 63e69f4d..82c5511e 100644 --- a/libs/python/numpy/test/templates_mod.cpp +++ b/libs/python/numpy/test/templates_mod.cpp @@ -1,21 +1,48 @@ #include #include +#include namespace bp = boost::python; struct ArrayFiller { - typedef boost::mpl::vector< short, int, float, std::complex > Sequence; + typedef boost::mpl::vector< short, int, float, std::complex > TypeSequence; + typedef boost::mpl::vector_c< int, 1, 2 > DimSequence; template - void apply() const { - char * p = argument.get_data(); - int stride = argument.strides(0); - int size = argument.shape(0); - for (int n = 0; n != size; ++n, p += stride) { - *reinterpret_cast(p) = static_cast(n); + struct nested { + + void apply(boost::mpl::integral_c * ) const { + char * p = argument.get_data(); + int stride = argument.strides(0); + int size = argument.shape(0); + for (int n = 0; n != size; ++n, p += stride) { + *reinterpret_cast(p) = static_cast(n); + } } - } + + void apply(boost::mpl::integral_c * ) const { + char * row_p = argument.get_data(); + int row_stride = argument.strides(0); + int col_stride = argument.strides(1); + int rows = argument.shape(0); + int cols = argument.shape(1); + int i = 0; + for (int n = 0; n != rows; ++n, row_p += row_stride) { + char * col_p = row_p; + for (int m = 0; m != cols; ++i, ++m, col_p += col_stride) { + *reinterpret_cast(col_p) = static_cast(i); + } + } + } + + explicit nested(bp::numpy::ndarray const & arg) : argument(arg) {} + + bp::numpy::ndarray argument; + }; + + template + nested nest(T *) const { return nested(argument); } bp::numpy::ndarray argument; @@ -25,7 +52,7 @@ struct ArrayFiller { void fill(bp::numpy::ndarray const & arg) { ArrayFiller filler(arg); - arg.get_dtype().invoke_matching_template< ArrayFiller::Sequence >(filler); + bp::numpy::invoke_matching_array< ArrayFiller::TypeSequence, ArrayFiller::DimSequence >(arg, filler); } BOOST_PYTHON_MODULE(templates_mod) {