diff --git a/boost/numpy.hpp b/boost/numpy.hpp index 1efd4ce8..53b0cfca 100644 --- a/boost/numpy.hpp +++ b/boost/numpy.hpp @@ -23,9 +23,10 @@ namespace numpy { * It should probably be the first line inside BOOST_PYTHON_MODULE. * * @internal This just calls the Numpy C-API functions "import_array()" - * and "import_ufunc()". + * and "import_ufunc()", and then calls + * dtype::register_scalar_converters(). */ -void initialize(); +void initialize(bool register_scalar_converters=true); } // namespace boost::numpy } // namespace boost diff --git a/boost/numpy/dtype.hpp b/boost/numpy/dtype.hpp index 173802e3..5d5c9567 100644 --- a/boost/numpy/dtype.hpp +++ b/boost/numpy/dtype.hpp @@ -12,18 +12,14 @@ #include #include -namespace boost -{ -namespace numpy -{ +namespace boost { namespace numpy { /** * @brief A boost.python "object manager" (subclass of object) for numpy.dtype. * * @todo This could have a lot more interesting accessors. */ -class dtype : public python::object -{ +class dtype : public python::object { static python::detail::new_reference convert(python::object::object_cref arg, bool align); public: @@ -47,19 +43,59 @@ public: /// @brief Return the size of the data type in bytes. int get_itemsize() const; + /** + * @brief Register from-Python converters for NumPy's built-in array scalar types. + * + * This is usually called automatically by initialize(), and shouldn't be called twice + * (doing so just adds unused converters to the Boost.Python registry). + */ + static void register_scalar_converters(); + BOOST_PYTHON_FORWARD_OBJECT_CONSTRUCTORS(dtype, python::object); }; -} // namespace boost::numpy +namespace detail +{ -namespace python -{ -namespace converter -{ +template dtype get_int_dtype(); + +template dtype get_float_dtype(); + +template dtype get_complex_dtype(); + +template ::value> +struct builtin_dtype; + +template +struct builtin_dtype { + static dtype get() { return get_int_dtype< 8*sizeof(T), boost::is_unsigned::value >(); } +}; + +template <> +struct builtin_dtype { + static dtype get(); +}; + +template +struct builtin_dtype { + static dtype get() { return get_float_dtype< 8*sizeof(T) >(); } +}; + +template +struct builtin_dtype< std::complex, false > { + static dtype get() { return get_complex_dtype< 16*sizeof(T) >(); } +}; + +} // namespace detail + +template +inline dtype dtype::get_builtin() { return detail::builtin_dtype::get(); } + +}} // namespace boost::numpy + +namespace boost { namespace python { namespace converter { NUMPY_OBJECT_MANAGER_TRAITS(numpy::dtype); -} // namespace boost::python::converter -} // namespace boost::python -} // namespace boost +}}} // namespace boost::python::converter #endif // !BOOST_NUMPY_DTYPE_HPP_INCLUDED diff --git a/libs/numpy/src/dtype.cpp b/libs/numpy/src/dtype.cpp index e8431592..8c970a57 100644 --- a/libs/numpy/src/dtype.cpp +++ b/libs/numpy/src/dtype.cpp @@ -1,58 +1,78 @@ #define BOOST_NUMPY_INTERNAL #include -#define NUMPY_DTYPE_TRAITS_BUILTIN(ctype,code) \ -template <> struct dtype_traits \ -{ \ - static dtype get() \ - { \ - return dtype(python::detail::new_reference \ - (reinterpret_cast(PyArray_DescrFromType(code)))); \ - } \ -}; \ -template dtype dtype::get_builtin() +#define DTYPE_FROM_CODE(code) \ + dtype(python::detail::new_reference(reinterpret_cast(PyArray_DescrFromType(code)))) -#define NUMPY_DTYPE_TRAITS_COMPLEX(creal, ctype, code) \ -template <> struct dtype_traits< std::complex > \ -{ \ - static dtype get() \ - { \ - if (sizeof(ctype) != sizeof(std::complex)) \ - { \ - PyErr_SetString(PyExc_TypeError, "Cannot reinterpret std::complex as T[2]"); \ - python::throw_error_already_set(); \ - } \ - return dtype(python::detail::new_reference \ - (reinterpret_cast(PyArray_DescrFromType(code)))); \ - } \ -}; \ -template dtype dtype::get_builtin< std::complex >() +#define BUILTIN_INT_DTYPE(bits) \ + template <> struct builtin_int_dtype< bits, false > { \ + static dtype get() { return DTYPE_FROM_CODE(NPY_INT ## bits); } \ + }; \ + template <> struct builtin_int_dtype< bits, true > { \ + static dtype get() { return DTYPE_FROM_CODE(NPY_UINT ## bits); } \ + }; \ + template dtype get_int_dtype< bits, false >(); \ + template dtype get_int_dtype< bits, true >() -namespace boost -{ -namespace python -{ -namespace converter -{ +#define BUILTIN_FLOAT_DTYPE(bits) \ + template <> struct builtin_float_dtype< bits > { \ + static dtype get() { return DTYPE_FROM_CODE(NPY_FLOAT ## bits); } \ + }; \ + template dtype get_float_dtype< bits >() + +#define BUILTIN_COMPLEX_DTYPE(bits) \ + template <> struct builtin_complex_dtype< bits > { \ + static dtype get() { return DTYPE_FROM_CODE(NPY_COMPLEX ## bits); } \ + }; \ + template dtype get_complex_dtype< bits >() + +namespace boost { namespace python { namespace converter { NUMPY_OBJECT_MANAGER_TRAITS_IMPL(PyArrayDescr_Type, numpy::dtype) -} // namespace boost::python::converter -} // namespace boost::python +}}} // namespace boost::python::converter -namespace numpy -{ +namespace boost { namespace numpy { -template struct dtype_traits; +namespace detail { -python::detail::new_reference dtype::convert(python::object const & arg, bool align) -{ +dtype builtin_dtype::get() { return DTYPE_FROM_CODE(NPY_BOOL); } + +template struct builtin_int_dtype; +template struct builtin_float_dtype; +template struct builtin_complex_dtype; + +template dtype get_int_dtype() { + return builtin_int_dtype::get(); +} +template dtype get_float_dtype() { return builtin_float_dtype::get(); } +template dtype get_complex_dtype() { return builtin_complex_dtype::get(); } + +BUILTIN_INT_DTYPE(8); +BUILTIN_INT_DTYPE(16); +BUILTIN_INT_DTYPE(32); +BUILTIN_INT_DTYPE(64); +BUILTIN_FLOAT_DTYPE(32); +BUILTIN_FLOAT_DTYPE(64); +BUILTIN_COMPLEX_DTYPE(64); +BUILTIN_COMPLEX_DTYPE(128); +#if NPY_BITSOF_LONGDOUBLE > NPY_BITSOF_DOUBLE +template <> struct builtin_float_dtype< NPY_BITSOF_LONGDOUBLE > { + static dtype get() { return DTYPE_FROM_CODE(NPY_LONGDOUBLE); } +}; +template dtype get_float_dtype< NPY_BITSOF_LONGDOUBLE >(); +template <> struct builtin_complex_dtype< 2 * NPY_BITSOF_LONGDOUBLE > { + static dtype get() { return DTYPE_FROM_CODE(NPY_CLONGDOUBLE); } +}; +template dtype get_complex_dtype< 2 * NPY_BITSOF_LONGDOUBLE >(); +#endif + +} // namespace detail + +python::detail::new_reference dtype::convert(python::object const & arg, bool align) { PyArray_Descr* obj=NULL; - if (align) - { + if (align) { if (PyArray_DescrAlignConverter(arg.ptr(), &obj) < 0) python::throw_error_already_set(); - } - else - { + } else { if (PyArray_DescrConverter(arg.ptr(), &obj) < 0) python::throw_error_already_set(); } @@ -61,28 +81,72 @@ python::detail::new_reference dtype::convert(python::object const & arg, bool al int dtype::get_itemsize() const { return reinterpret_cast(ptr())->elsize;} -template -dtype dtype::get_builtin() { return dtype_traits::get(); } +namespace { -NUMPY_DTYPE_TRAITS_BUILTIN(bool, NPY_BOOL); -NUMPY_DTYPE_TRAITS_BUILTIN(npy_ubyte, NPY_UBYTE); -NUMPY_DTYPE_TRAITS_BUILTIN(npy_byte, NPY_BYTE); -NUMPY_DTYPE_TRAITS_BUILTIN(npy_ushort, NPY_USHORT); -NUMPY_DTYPE_TRAITS_BUILTIN(npy_short, NPY_SHORT); -NUMPY_DTYPE_TRAITS_BUILTIN(npy_uint, NPY_UINT); -NUMPY_DTYPE_TRAITS_BUILTIN(npy_int, NPY_INT); -NUMPY_DTYPE_TRAITS_BUILTIN(npy_ulong, NPY_ULONG); -NUMPY_DTYPE_TRAITS_BUILTIN(npy_long, NPY_LONG); -NUMPY_DTYPE_TRAITS_BUILTIN(npy_longlong, NPY_LONGLONG); -NUMPY_DTYPE_TRAITS_BUILTIN(npy_float, NPY_FLOAT); -NUMPY_DTYPE_TRAITS_BUILTIN(npy_double, NPY_DOUBLE); -NUMPY_DTYPE_TRAITS_BUILTIN(npy_longdouble, NPY_LONGDOUBLE); -NUMPY_DTYPE_TRAITS_BUILTIN(npy_cfloat, NPY_CFLOAT); -NUMPY_DTYPE_TRAITS_BUILTIN(npy_cdouble, NPY_CDOUBLE); -NUMPY_DTYPE_TRAITS_BUILTIN(npy_clongdouble, NPY_CLONGDOUBLE); -NUMPY_DTYPE_TRAITS_COMPLEX(float, npy_cfloat, NPY_CFLOAT); -NUMPY_DTYPE_TRAITS_COMPLEX(double, npy_cdouble, NPY_CDOUBLE); -NUMPY_DTYPE_TRAITS_COMPLEX(long double, npy_clongdouble, NPY_CLONGDOUBLE); +namespace pyconv = boost::python::converter; + +template +class array_scalar_converter { +public: + + static PyTypeObject const * get_pytype() { + // This implementation depends on the fact that get_builtin returns pointers to objects + // NumPy has declared statically, and that the typeobj member also refers to a static + // object. That means we don't need to do any reference counting. + // In fact, I'm somewhat concerned that increasing the reference count of any of these + // might cause leaks, because I don't think Boost.Python ever decrements it, but it's + // probably a moot point if everything is actually static. + return reinterpret_cast(dtype::get_builtin().ptr())->typeobj; + } + + static void * convertible(PyObject * obj) { + if (obj->ob_type == get_pytype()) { + return obj; + } else { + return 0; + } + } + + static void convert(PyObject * obj, pyconv::rvalue_from_python_stage1_data* data) { + void * storage = reinterpret_cast*>(data)->storage.bytes; + // We assume std::complex is a "standard layout" here and elsewhere; not guaranteed by + // C++03 standard, but true in every known implementation (and guaranteed by C++11). + PyArray_ScalarAsCtype(obj, reinterpret_cast(storage)); + data->convertible = storage; + } + + static void declare() { + pyconv::registry::push_back( + &convertible, &convert, python::type_id() +#ifndef BOOST_PYTHON_NO_PY_SIGNATURES + , &get_pytype +#endif + ); + } + +}; + +} // anonymous + +void dtype::register_scalar_converters() { + array_scalar_converter::declare(); + array_scalar_converter::declare(); + array_scalar_converter::declare(); + array_scalar_converter::declare(); + array_scalar_converter::declare(); + array_scalar_converter::declare(); + array_scalar_converter::declare(); + array_scalar_converter::declare(); + array_scalar_converter::declare(); + array_scalar_converter::declare(); + array_scalar_converter::declare(); + array_scalar_converter< std::complex >::declare(); + array_scalar_converter< std::complex >::declare(); +#if NPY_BITSOF_LONGDOUBLE > NPY_BITSOF_DOUBLE + array_scalar_converter::declare(); + array_scalar_converter< std::complex >::declare(); +#endif +} } // namespace boost::numpy } // namespace boost diff --git a/libs/numpy/src/numpy.cpp b/libs/numpy/src/numpy.cpp index 9898a37f..3fab9d12 100644 --- a/libs/numpy/src/numpy.cpp +++ b/libs/numpy/src/numpy.cpp @@ -1,15 +1,18 @@ #define BOOST_NUMPY_INTERNAL_MAIN #include +#include namespace boost { namespace numpy { -void initialize() +void initialize(bool register_scalar_converters) { import_array(); import_ufunc(); + if (register_scalar_converters) + dtype::register_scalar_converters(); } } diff --git a/libs/numpy/test/SConscript b/libs/numpy/test/SConscript index 2cdf3e6d..2b5b1186 100644 --- a/libs/numpy/test/SConscript +++ b/libs/numpy/test/SConscript @@ -22,7 +22,7 @@ def PythonUnitTest(env, script, dependencies): env.Depends(run, dependencies) return run -for name in ("ufunc", "templates", "ndarray", "indexing", "shapes"): +for name in ("dtype", "ufunc", "templates", "ndarray", "indexing", "shapes"): mod = test_env.LoadableModule("%s_mod" % name, "%s_mod.cpp" % name, LDMODULEPREFIX="") test.extend(PythonUnitTest(test_env, "%s.py" % name, mod)) diff --git a/libs/numpy/test/dtype.py b/libs/numpy/test/dtype.py new file mode 100644 index 00000000..9e800fd9 --- /dev/null +++ b/libs/numpy/test/dtype.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python + +import dtype_mod +import unittest +import numpy + +class DtypeTestCase(unittest.TestCase): + + def testIntegers(self): + for bits in (8, 16, 32, 64): + s = getattr(numpy, "int%d" % bits) + u = getattr(numpy, "uint%d" % bits) + fs = getattr(dtype_mod, "accept_int%d" % bits) + fu = getattr(dtype_mod, "accept_uint%d" % bits) + self.assertEqual(fs(s(1)), numpy.dtype(s)) + self.assertEqual(fu(u(1)), numpy.dtype(u)) + # these should just use the regular Boost.Python converters + self.assertEqual(fs(True), numpy.dtype(s)) + self.assertEqual(fu(True), numpy.dtype(u)) + self.assertEqual(fs(int(1)), numpy.dtype(s)) + self.assertEqual(fu(int(1)), numpy.dtype(u)) + self.assertEqual(fs(long(1)), numpy.dtype(s)) + self.assertEqual(fu(long(1)), numpy.dtype(u)) + for name in ("bool_", "byte", "ubyte", "short", "ushort", "intc", "uintc"): + t = getattr(numpy, name) + ft = getattr(dtype_mod, "accept_%s" % name) + self.assertEqual(ft(t(1)), numpy.dtype(t)) + # these should just use the regular Boost.Python converters + self.assertEqual(ft(True), numpy.dtype(t)) + if name != "bool_": + self.assertEqual(ft(int(1)), numpy.dtype(t)) + self.assertEqual(ft(long(1)), numpy.dtype(t)) + + + def testFloats(self): + f = numpy.float32 + c = numpy.complex64 + self.assertEqual(dtype_mod.accept_float32(f(numpy.pi)), numpy.dtype(f)) + self.assertEqual(dtype_mod.accept_complex64(c(1+2j)), numpy.dtype(c)) + f = numpy.float64 + c = numpy.complex128 + self.assertEqual(dtype_mod.accept_float64(f(numpy.pi)), numpy.dtype(f)) + self.assertEqual(dtype_mod.accept_complex128(c(1+2j)), numpy.dtype(c)) + if hasattr(numpy, "longdouble"): + f = numpy.longdouble + c = numpy.clongdouble + self.assertEqual(dtype_mod.accept_longdouble(f(numpy.pi)), numpy.dtype(f)) + self.assertEqual(dtype_mod.accept_clongdouble(c(1+2j)), numpy.dtype(c)) + + +if __name__=="__main__": + unittest.main() diff --git a/libs/numpy/test/dtype_mod.cpp b/libs/numpy/test/dtype_mod.cpp new file mode 100644 index 00000000..ab16560a --- /dev/null +++ b/libs/numpy/test/dtype_mod.cpp @@ -0,0 +1,42 @@ +#include +#include + +namespace p = boost::python; +namespace np = boost::numpy; + +template +np::dtype accept(T) { + return np::dtype::get_builtin(); +} + + +BOOST_PYTHON_MODULE(dtype_mod) +{ + np::initialize(); + // integers, by number of bits + p::def("accept_int8", accept); + p::def("accept_uint8", accept); + p::def("accept_int16", accept); + p::def("accept_uint16", accept); + p::def("accept_int32", accept); + p::def("accept_uint32", accept); + p::def("accept_int64", accept); + p::def("accept_uint64", accept); + // integers, by C name according to NumPy + p::def("accept_bool_", accept); + p::def("accept_byte", accept); + p::def("accept_ubyte", accept); + p::def("accept_short", accept); + p::def("accept_ushort", accept); + p::def("accept_intc", accept); + p::def("accept_uintc", accept); + // floats and complex + p::def("accept_float32", accept); + p::def("accept_complex64", accept< std::complex >); + p::def("accept_float64", accept); + p::def("accept_complex128", accept< std::complex >); + if (sizeof(long double) > sizeof(double)) { + p::def("accept_longdouble", accept); + p::def("accept_clongdouble", accept< std::complex >); + } +}