2
0
mirror of https://github.com/boostorg/python.git synced 2026-01-23 17:52:17 +00:00

boost.python.numpy - moved dtype::invoke_matching_template into separate header, added similar code for invocation based on dimensionality

This commit is contained in:
Jim Bosch
2010-10-06 19:05:20 +00:00
parent 42ca807c82
commit ba1416fff0
6 changed files with 220 additions and 82 deletions

View File

@@ -6,12 +6,16 @@ class TestTemplates(unittest.TestCase):
def testTemplates(self):
for dtype in (numpy.int16, numpy.int32, numpy.float32, numpy.complex128):
a1 = numpy.zeros((12,), dtype=dtype)
a2 = numpy.arange(12, dtype=dtype)
templates_mod.fill(a1)
self.assert_((a1 == a2).all())
v = numpy.arange(12, dtype=dtype)
for shape in ((12,), (4, 3), (2, 6)):
a1 = numpy.zeros(shape, dtype=dtype)
a2 = v.reshape(a1.shape)
templates_mod.fill(a1)
self.assert_((a1 == a2).all())
a1 = numpy.zeros((12,), dtype=numpy.float64)
self.assertRaises(TypeError, templates_mod.fill, a1)
a1 = numpy.zeros((12,2,3), dtype=numpy.float32)
self.assertRaises(TypeError, templates_mod.fill, a1)
if __name__=="__main__":
unittest.main()

View File

@@ -1,21 +1,48 @@
#include <boost/python/numpy/numpy.hpp>
#include <boost/mpl/vector.hpp>
#include <boost/mpl/vector_c.hpp>
namespace bp = boost::python;
struct ArrayFiller {
typedef boost::mpl::vector< short, int, float, std::complex<double> > Sequence;
typedef boost::mpl::vector< short, int, float, std::complex<double> > TypeSequence;
typedef boost::mpl::vector_c< int, 1, 2 > DimSequence;
template <typename T>
void apply() const {
char * p = argument.get_data();
int stride = argument.strides(0);
int size = argument.shape(0);
for (int n = 0; n != size; ++n, p += stride) {
*reinterpret_cast<T*>(p) = static_cast<T>(n);
struct nested {
void apply(boost::mpl::integral_c<int,1> * ) const {
char * p = argument.get_data();
int stride = argument.strides(0);
int size = argument.shape(0);
for (int n = 0; n != size; ++n, p += stride) {
*reinterpret_cast<T*>(p) = static_cast<T>(n);
}
}
}
void apply(boost::mpl::integral_c<int,2> * ) const {
char * row_p = argument.get_data();
int row_stride = argument.strides(0);
int col_stride = argument.strides(1);
int rows = argument.shape(0);
int cols = argument.shape(1);
int i = 0;
for (int n = 0; n != rows; ++n, row_p += row_stride) {
char * col_p = row_p;
for (int m = 0; m != cols; ++i, ++m, col_p += col_stride) {
*reinterpret_cast<T*>(col_p) = static_cast<T>(i);
}
}
}
explicit nested(bp::numpy::ndarray const & arg) : argument(arg) {}
bp::numpy::ndarray argument;
};
template <typename T>
nested<T> nest(T *) const { return nested<T>(argument); }
bp::numpy::ndarray argument;
@@ -25,7 +52,7 @@ struct ArrayFiller {
void fill(bp::numpy::ndarray const & arg) {
ArrayFiller filler(arg);
arg.get_dtype().invoke_matching_template< ArrayFiller::Sequence >(filler);
bp::numpy::invoke_matching_array< ArrayFiller::TypeSequence, ArrayFiller::DimSequence >(arg, filler);
}
BOOST_PYTHON_MODULE(templates_mod) {