diff --git a/boost/python/numpy/dtype.hpp b/boost/python/numpy/dtype.hpp index 53c61f34..cb1666b5 100644 --- a/boost/python/numpy/dtype.hpp +++ b/boost/python/numpy/dtype.hpp @@ -9,6 +9,9 @@ #include #include +#include +#include + namespace boost { namespace python { namespace numpy { @@ -43,8 +46,57 @@ public: BOOST_PYTHON_FORWARD_OBJECT_CONSTRUCTORS(dtype, object); + template + void invoke_matching_template(Function f); + }; +namespace detail { + +struct add_pointer_meta { + + template + struct apply { + typedef typename boost::add_pointer::type type; + }; + +}; + +struct dtype_template_match_found {}; + +template +struct dtype_template_invoker { + + template + void operator()(T * x) const { + if (dtype::get_builtin() == m_dtype) { + m_func.template apply(); + throw dtype_template_match_found(); + } + } + + dtype_template_invoker(dtype const & dtype_, Function func) : + m_dtype(dtype_), m_func(func) {} + +private: + dtype const & m_dtype; + Function m_func; +}; + +} // namespace boost::python::numpy::detail + +template +void dtype::invoke_matching_template(Function f) { + detail::dtype_template_invoker invoker(*this, f); + try { + boost::mpl::for_each< Sequence, detail::add_pointer_meta >(invoker); + } catch (detail::dtype_template_match_found &) { + return; + } + PyErr_SetString(PyExc_TypeError, "numpy.dtype not found in template list."); + throw_error_already_set(); +} + } // namespace boost::python::numpy namespace converter { diff --git a/libs/python/numpy/test/SConscript b/libs/python/numpy/test/SConscript index 7dc0d24b..6e7d8d27 100644 --- a/libs/python/numpy/test/SConscript +++ b/libs/python/numpy/test/SConscript @@ -1,8 +1,12 @@ Import("bp_numpy_env") -ufunc_mod = bp_numpy_env.SharedLibrary("ufunc_mod", "ufunc_mod.cpp", SHLIBPREFIX="", - LIBS="boost_python_numpy") -ufunc_test = bp_numpy_env.PythonUnitTest("ufunc.py", ufunc_mod) +test = [] + +for name in ("ufunc", "templates"): + mod = bp_numpy_env.SharedLibrary("%s_mod" % name, "%s_mod.cpp" % name, SHLIBPREFIX="", + LIBS="boost_python_numpy") + test.extend( + bp_numpy_env.PythonUnitTest("%s.py" % name, mod) + ) -test = ufunc_test Return("test") diff --git a/libs/python/numpy/test/templates.py b/libs/python/numpy/test/templates.py new file mode 100755 index 00000000..c2967c1c --- /dev/null +++ b/libs/python/numpy/test/templates.py @@ -0,0 +1,17 @@ +import templates_mod +import unittest +import numpy + +class TestTemplates(unittest.TestCase): + + def testTemplates(self): + for dtype in (numpy.int16, numpy.int32, numpy.float32, numpy.complex128): + a1 = numpy.zeros((12,), dtype=dtype) + a2 = numpy.arange(12, dtype=dtype) + templates_mod.fill(a1) + self.assert_((a1 == a2).all()) + a1 = numpy.zeros((12,), dtype=numpy.float64) + self.assertRaises(TypeError, templates_mod.fill, a1) + +if __name__=="__main__": + unittest.main() diff --git a/libs/python/numpy/test/templates_mod.cpp b/libs/python/numpy/test/templates_mod.cpp new file mode 100644 index 00000000..63e69f4d --- /dev/null +++ b/libs/python/numpy/test/templates_mod.cpp @@ -0,0 +1,34 @@ +#include +#include + +namespace bp = boost::python; + +struct ArrayFiller { + + typedef boost::mpl::vector< short, int, float, std::complex > Sequence; + + template + void apply() const { + char * p = argument.get_data(); + int stride = argument.strides(0); + int size = argument.shape(0); + for (int n = 0; n != size; ++n, p += stride) { + *reinterpret_cast(p) = static_cast(n); + } + } + + bp::numpy::ndarray argument; + + explicit ArrayFiller(bp::numpy::ndarray const & arg) : argument(arg) {} + +}; + +void fill(bp::numpy::ndarray const & arg) { + ArrayFiller filler(arg); + arg.get_dtype().invoke_matching_template< ArrayFiller::Sequence >(filler); +} + +BOOST_PYTHON_MODULE(templates_mod) { + bp::numpy::initialize(); + bp::def("fill", &fill); +}