From ecf05c4a90b05b031634b7c793a9e25d4e4fe611 Mon Sep 17 00:00:00 2001 From: Mark Borgerding Date: Wed, 20 Sep 2017 09:39:46 -0400 Subject: [PATCH] ndarray.shape(k),strides(k) act more like their python counterparts (negative indexing, bounds checking) (issue #157) --- include/boost/python/numpy/ndarray.hpp | 8 +++---- src/numpy/ndarray.cpp | 24 +++++++++++++++++++ test/numpy/ndarray.cpp | 5 ++++ test/numpy/ndarray.py | 33 ++++++++++++++++++++++++++ 4 files changed, 66 insertions(+), 4 deletions(-) diff --git a/include/boost/python/numpy/ndarray.hpp b/include/boost/python/numpy/ndarray.hpp index 98a4cb15..2cb3b509 100644 --- a/include/boost/python/numpy/ndarray.hpp +++ b/include/boost/python/numpy/ndarray.hpp @@ -86,11 +86,11 @@ public: /// @brief Copy the scalar (deep for all non-object fields). ndarray copy() const; - /// @brief Return the size of the nth dimension. - Py_intptr_t shape(int n) const { return get_shape()[n]; } + /// @brief Return the size of the nth dimension. raises IndexError if k not in [-get_nd() : get_nd()-1 ] + Py_intptr_t shape(int n) const; - /// @brief Return the stride of the nth dimension. - Py_intptr_t strides(int n) const { return get_strides()[n]; } + /// @brief Return the stride of the nth dimension. raises IndexError if k not in [-get_nd() : get_nd()-1] + Py_intptr_t strides(int n) const; /** * @brief Return the array's raw data pointer. diff --git a/src/numpy/ndarray.cpp b/src/numpy/ndarray.cpp index 8ae67b89..625c43b3 100644 --- a/src/numpy/ndarray.cpp +++ b/src/numpy/ndarray.cpp @@ -138,6 +138,30 @@ ndarray from_data_impl(void * data, } // namespace detail +namespace { + int normalize_index(int n,int nlim) // wraps [-nlim:nlim) into [0:nlim), throw IndexError otherwise + { + if (n<0) + n += nlim; // negative indices work backwards from end + if (n < 0 || n >= nlim) + { + PyErr_SetObject(PyExc_IndexError, Py_None); + throw_error_already_set(); + } + return n; + } +} + +Py_intptr_t ndarray::shape(int n) const +{ + return get_shape()[normalize_index(n,get_nd())]; +} + +Py_intptr_t ndarray::strides(int n) const +{ + return get_strides()[normalize_index(n,get_nd())]; +} + ndarray ndarray::view(dtype const & dt) const { return ndarray(python::detail::new_reference diff --git a/test/numpy/ndarray.cpp b/test/numpy/ndarray.cpp index 808872e8..75a10104 100644 --- a/test/numpy/ndarray.cpp +++ b/test/numpy/ndarray.cpp @@ -31,6 +31,9 @@ np::ndarray transpose(np::ndarray arr) { return arr.transpose();} np::ndarray squeeze(np::ndarray arr) { return arr.squeeze();} np::ndarray reshape(np::ndarray arr,p::tuple tup) { return arr.reshape(tup);} +Py_intptr_t shape_index(np::ndarray arr,int k) { return arr.shape(k); } +Py_intptr_t strides_index(np::ndarray arr,int k) { return arr.strides(k); } + BOOST_PYTHON_MODULE(ndarray_ext) { np::initialize(); @@ -43,4 +46,6 @@ BOOST_PYTHON_MODULE(ndarray_ext) p::def("transpose", transpose); p::def("squeeze", squeeze); p::def("reshape", reshape); + p::def("shape_index", shape_index); + p::def("strides_index", strides_index); } diff --git a/test/numpy/ndarray.py b/test/numpy/ndarray.py index fb92a2a2..2acc384a 100644 --- a/test/numpy/ndarray.py +++ b/test/numpy/ndarray.py @@ -75,5 +75,38 @@ class TestNdarray(unittest.TestCase): a2 = ndarray_ext.reshape(a1,(1,4)) self.assertEqual(a2.shape,(1,4)) + def testShapeIndex(self): + a = numpy.arange(24) + a.shape = (1,2,3,4) + def shape_check(i): + print(i) + self.assertEqual(ndarray_ext.shape_index(a,i) ,a.shape[i] ) + for i in range(4): + shape_check(i) + for i in range(-1,-5,-1): + shape_check(i) + try: + ndarray_ext.shape_index(a,4) # out of bounds -- should raise IndexError + self.assertTrue(False) + except IndexError: + pass + + def testStridesIndex(self): + a = numpy.arange(24) + a.shape = (1,2,3,4) + def strides_check(i): + print(i) + self.assertEqual(ndarray_ext.strides_index(a,i) ,a.strides[i] ) + for i in range(4): + strides_check(i) + for i in range(-1,-5,-1): + strides_check(i) + try: + ndarray_ext.strides_index(a,4) # out of bounds -- should raise IndexError + self.assertTrue(False) + except IndexError: + pass + + if __name__=="__main__": unittest.main()