From 313dcbb6281fcad2ac2e171418a4efb934a00447 Mon Sep 17 00:00:00 2001 From: Jim Bosch Date: Sat, 21 Apr 2012 16:46:28 -0400 Subject: [PATCH] added as_matrix call policy --- boost/numpy/matrix.hpp | 14 ++++++++++++++ libs/numpy/test/ndarray.py | 10 ++++++++++ libs/numpy/test/ndarray_mod.cpp | 2 ++ 3 files changed, 26 insertions(+) diff --git a/boost/numpy/matrix.hpp b/boost/numpy/matrix.hpp index db5c8d24..408565cc 100644 --- a/boost/numpy/matrix.hpp +++ b/boost/numpy/matrix.hpp @@ -52,6 +52,20 @@ public: }; +/** + * @brief CallPolicies that causes a function that returns a numpy.ndarray to + * return a numpy.matrix instead. + */ +template +struct as_matrix : Base { + static PyObject * postcall(PyObject *, PyObject * result) { + python::object a = python::object(python::handle<>(result)); + numpy::matrix m(a, false); + Py_INCREF(m.ptr()); + return m.ptr(); + } +}; + } // namespace boost::numpy namespace python { diff --git a/libs/numpy/test/ndarray.py b/libs/numpy/test/ndarray.py index f989f0ac..ff844e69 100644 --- a/libs/numpy/test/ndarray.py +++ b/libs/numpy/test/ndarray.py @@ -16,6 +16,16 @@ class TestNdarray(unittest.TestCase): self.assertEqual(shape,a1.shape) self.assert_((a1 == a2).all()) + def testNdzeros_matrix(self): + for dtp in (numpy.int16, numpy.int32, numpy.float32, numpy.complex128): + dt = numpy.dtype(dtp) + shape = (6, 10) + a1 = ndarray_mod.zeros_matrix(shape, dt) + a2 = numpy.matrix(numpy.zeros(shape, dtype=dtp)) + self.assertEqual(shape,a1.shape) + self.assert_((a1 == a2).all()) + self.assertEqual(type(a1), type(a2)) + def testNdarray(self): a = range(0,60) for dtp in (numpy.int16, numpy.int32, numpy.float32, numpy.complex128): diff --git a/libs/numpy/test/ndarray_mod.cpp b/libs/numpy/test/ndarray_mod.cpp index df16b021..44f1bf17 100644 --- a/libs/numpy/test/ndarray_mod.cpp +++ b/libs/numpy/test/ndarray_mod.cpp @@ -24,10 +24,12 @@ np::ndarray c_empty(p::tuple shape, np::dtype dt) np::ndarray transpose(np::ndarray arr) { return arr.transpose();} np::ndarray squeeze(np::ndarray arr) { return arr.squeeze();} np::ndarray reshape(np::ndarray arr,p::tuple tup) { return arr.reshape(tup);} + BOOST_PYTHON_MODULE(ndarray_mod) { np::initialize(); p::def("zeros", zeros); + p::def("zeros_matrix", zeros, np::as_matrix<>()); p::def("array", array2); p::def("array", array1); p::def("empty", empty1);