mirror of
https://github.com/boostorg/python.git
synced 2026-01-24 06:02:14 +00:00
boost.python.numpy - added dtype template invoker
This commit is contained in:
@@ -1,8 +1,12 @@
|
||||
Import("bp_numpy_env")
|
||||
|
||||
ufunc_mod = bp_numpy_env.SharedLibrary("ufunc_mod", "ufunc_mod.cpp", SHLIBPREFIX="",
|
||||
LIBS="boost_python_numpy")
|
||||
ufunc_test = bp_numpy_env.PythonUnitTest("ufunc.py", ufunc_mod)
|
||||
test = []
|
||||
|
||||
for name in ("ufunc", "templates"):
|
||||
mod = bp_numpy_env.SharedLibrary("%s_mod" % name, "%s_mod.cpp" % name, SHLIBPREFIX="",
|
||||
LIBS="boost_python_numpy")
|
||||
test.extend(
|
||||
bp_numpy_env.PythonUnitTest("%s.py" % name, mod)
|
||||
)
|
||||
|
||||
test = ufunc_test
|
||||
Return("test")
|
||||
|
||||
17
libs/python/numpy/test/templates.py
Executable file
17
libs/python/numpy/test/templates.py
Executable file
@@ -0,0 +1,17 @@
|
||||
import templates_mod
|
||||
import unittest
|
||||
import numpy
|
||||
|
||||
class TestTemplates(unittest.TestCase):
|
||||
|
||||
def testTemplates(self):
|
||||
for dtype in (numpy.int16, numpy.int32, numpy.float32, numpy.complex128):
|
||||
a1 = numpy.zeros((12,), dtype=dtype)
|
||||
a2 = numpy.arange(12, dtype=dtype)
|
||||
templates_mod.fill(a1)
|
||||
self.assert_((a1 == a2).all())
|
||||
a1 = numpy.zeros((12,), dtype=numpy.float64)
|
||||
self.assertRaises(TypeError, templates_mod.fill, a1)
|
||||
|
||||
if __name__=="__main__":
|
||||
unittest.main()
|
||||
34
libs/python/numpy/test/templates_mod.cpp
Normal file
34
libs/python/numpy/test/templates_mod.cpp
Normal file
@@ -0,0 +1,34 @@
|
||||
#include <boost/python/numpy/numpy.hpp>
|
||||
#include <boost/mpl/vector.hpp>
|
||||
|
||||
namespace bp = boost::python;
|
||||
|
||||
struct ArrayFiller {
|
||||
|
||||
typedef boost::mpl::vector< short, int, float, std::complex<double> > Sequence;
|
||||
|
||||
template <typename T>
|
||||
void apply() const {
|
||||
char * p = argument.get_data();
|
||||
int stride = argument.strides(0);
|
||||
int size = argument.shape(0);
|
||||
for (int n = 0; n != size; ++n, p += stride) {
|
||||
*reinterpret_cast<T*>(p) = static_cast<T>(n);
|
||||
}
|
||||
}
|
||||
|
||||
bp::numpy::ndarray argument;
|
||||
|
||||
explicit ArrayFiller(bp::numpy::ndarray const & arg) : argument(arg) {}
|
||||
|
||||
};
|
||||
|
||||
void fill(bp::numpy::ndarray const & arg) {
|
||||
ArrayFiller filler(arg);
|
||||
arg.get_dtype().invoke_matching_template< ArrayFiller::Sequence >(filler);
|
||||
}
|
||||
|
||||
BOOST_PYTHON_MODULE(templates_mod) {
|
||||
bp::numpy::initialize();
|
||||
bp::def("fill", &fill);
|
||||
}
|
||||
Reference in New Issue
Block a user