From 2f7742ffeca81a8163c76a8d8dabca9245970e54 Mon Sep 17 00:00:00 2001 From: Jim Bosch Date: Tue, 18 Sep 2012 23:13:29 -0400 Subject: [PATCH] add more permissive equivalence test for dtypes, start using it in tests --- boost/numpy/dtype.hpp | 10 +++++++++ libs/numpy/src/dtype.cpp | 7 +++++++ libs/numpy/test/dtype.py | 39 +++++++++++++++++++---------------- libs/numpy/test/dtype_mod.cpp | 3 ++- 4 files changed, 40 insertions(+), 19 deletions(-) diff --git a/boost/numpy/dtype.hpp b/boost/numpy/dtype.hpp index 5d5c9567..f3af59fa 100644 --- a/boost/numpy/dtype.hpp +++ b/boost/numpy/dtype.hpp @@ -43,6 +43,14 @@ public: /// @brief Return the size of the data type in bytes. int get_itemsize() const; + /** + * @brief Compare two dtypes for equivalence. + * + * This is more permissive than equality tests. For instance, if long and int are the same + * size, the dtypes corresponding to each will be equivalent, but not equal. + */ + friend bool equivalent(dtype const & a, dtype const & b); + /** * @brief Register from-Python converters for NumPy's built-in array scalar types. * @@ -55,6 +63,8 @@ public: }; +bool equivalent(dtype const & a, dtype const & b); + namespace detail { diff --git a/libs/numpy/src/dtype.cpp b/libs/numpy/src/dtype.cpp index 8c970a57..013c4c31 100644 --- a/libs/numpy/src/dtype.cpp +++ b/libs/numpy/src/dtype.cpp @@ -81,6 +81,13 @@ python::detail::new_reference dtype::convert(python::object const & arg, bool al int dtype::get_itemsize() const { return reinterpret_cast(ptr())->elsize;} +bool equivalent(dtype const & a, dtype const & b) { + return PyArray_EquivTypes( + reinterpret_cast(a.ptr()), + reinterpret_cast(b.ptr()) + ); +} + namespace { namespace pyconv = boost::python::converter; diff --git a/libs/numpy/test/dtype.py b/libs/numpy/test/dtype.py index 9e800fd9..05652905 100644 --- a/libs/numpy/test/dtype.py +++ b/libs/numpy/test/dtype.py @@ -6,46 +6,49 @@ import numpy class DtypeTestCase(unittest.TestCase): + def assertEquivalent(self, a, b): + return self.assert_(dtype_mod.equivalent(a, b), "%r is not equivalent to %r") + def testIntegers(self): for bits in (8, 16, 32, 64): s = getattr(numpy, "int%d" % bits) u = getattr(numpy, "uint%d" % bits) fs = getattr(dtype_mod, "accept_int%d" % bits) fu = getattr(dtype_mod, "accept_uint%d" % bits) - self.assertEqual(fs(s(1)), numpy.dtype(s)) - self.assertEqual(fu(u(1)), numpy.dtype(u)) + self.assertEquivalent(fs(s(1)), numpy.dtype(s)) + self.assertEquivalent(fu(u(1)), numpy.dtype(u)) # these should just use the regular Boost.Python converters - self.assertEqual(fs(True), numpy.dtype(s)) - self.assertEqual(fu(True), numpy.dtype(u)) - self.assertEqual(fs(int(1)), numpy.dtype(s)) - self.assertEqual(fu(int(1)), numpy.dtype(u)) - self.assertEqual(fs(long(1)), numpy.dtype(s)) - self.assertEqual(fu(long(1)), numpy.dtype(u)) + self.assertEquivalent(fs(True), numpy.dtype(s)) + self.assertEquivalent(fu(True), numpy.dtype(u)) + self.assertEquivalent(fs(int(1)), numpy.dtype(s)) + self.assertEquivalent(fu(int(1)), numpy.dtype(u)) + self.assertEquivalent(fs(long(1)), numpy.dtype(s)) + self.assertEquivalent(fu(long(1)), numpy.dtype(u)) for name in ("bool_", "byte", "ubyte", "short", "ushort", "intc", "uintc"): t = getattr(numpy, name) ft = getattr(dtype_mod, "accept_%s" % name) - self.assertEqual(ft(t(1)), numpy.dtype(t)) + self.assertEquivalent(ft(t(1)), numpy.dtype(t)) # these should just use the regular Boost.Python converters - self.assertEqual(ft(True), numpy.dtype(t)) + self.assertEquivalent(ft(True), numpy.dtype(t)) if name != "bool_": - self.assertEqual(ft(int(1)), numpy.dtype(t)) - self.assertEqual(ft(long(1)), numpy.dtype(t)) + self.assertEquivalent(ft(int(1)), numpy.dtype(t)) + self.assertEquivalent(ft(long(1)), numpy.dtype(t)) def testFloats(self): f = numpy.float32 c = numpy.complex64 - self.assertEqual(dtype_mod.accept_float32(f(numpy.pi)), numpy.dtype(f)) - self.assertEqual(dtype_mod.accept_complex64(c(1+2j)), numpy.dtype(c)) + self.assertEquivalent(dtype_mod.accept_float32(f(numpy.pi)), numpy.dtype(f)) + self.assertEquivalent(dtype_mod.accept_complex64(c(1+2j)), numpy.dtype(c)) f = numpy.float64 c = numpy.complex128 - self.assertEqual(dtype_mod.accept_float64(f(numpy.pi)), numpy.dtype(f)) - self.assertEqual(dtype_mod.accept_complex128(c(1+2j)), numpy.dtype(c)) + self.assertEquivalent(dtype_mod.accept_float64(f(numpy.pi)), numpy.dtype(f)) + self.assertEquivalent(dtype_mod.accept_complex128(c(1+2j)), numpy.dtype(c)) if hasattr(numpy, "longdouble"): f = numpy.longdouble c = numpy.clongdouble - self.assertEqual(dtype_mod.accept_longdouble(f(numpy.pi)), numpy.dtype(f)) - self.assertEqual(dtype_mod.accept_clongdouble(c(1+2j)), numpy.dtype(c)) + self.assertEquivalent(dtype_mod.accept_longdouble(f(numpy.pi)), numpy.dtype(f)) + self.assertEquivalent(dtype_mod.accept_clongdouble(c(1+2j)), numpy.dtype(c)) if __name__=="__main__": diff --git a/libs/numpy/test/dtype_mod.cpp b/libs/numpy/test/dtype_mod.cpp index ab16560a..2ef52ca6 100644 --- a/libs/numpy/test/dtype_mod.cpp +++ b/libs/numpy/test/dtype_mod.cpp @@ -9,10 +9,11 @@ np::dtype accept(T) { return np::dtype::get_builtin(); } - BOOST_PYTHON_MODULE(dtype_mod) { np::initialize(); + // wrap dtype equivalence test, since it isn't available in Python API. + p::def("equivalent", np::equivalent); // integers, by number of bits p::def("accept_int8", accept); p::def("accept_uint8", accept);