diff --git a/examples/guide_listing_8.py b/examples/guide_listing_8.py index 4b5ae024..924c64f2 100644 --- a/examples/guide_listing_8.py +++ b/examples/guide_listing_8.py @@ -8,5 +8,5 @@ for i in range(10): h.fill(i) # do this instead, it is very fast -v = np.arange(10) +v = np.arange(10, dtype=float) h.fill(v) # fills the histogram with each value in the array diff --git a/src/python/histogram.cpp b/src/python/histogram.cpp index 3f5dbc0b..458a3b3d 100644 --- a/src/python/histogram.cpp +++ b/src/python/histogram.cpp @@ -183,41 +183,37 @@ python::object histogram_init(python::tuple args, python::dict kwargs) { } struct fetcher { - virtual ~fetcher() {} - virtual double at(long) const = 0; long n = 0; -}; - -#ifdef HAVE_NUMPY -struct fetcher_seq : public fetcher { - fetcher_seq(python::object o) - : array(np::from_object(o, np::dtype::get_builtin(), 1)) { - fetcher::n = array.shape(0); + union { + double value = 0; + const double* array; + }; + void assign(python::object o) { + // skipping check for currently held type, since it is always value + python::extract get_double(o); + if (get_double.check()) { + value = get_double(); + n = 0; + return; } - ~fetcher_seq() {} - double at(long i) const { - return reinterpret_cast(array.get_data())[i]; - } - np::ndarray array; -}; -#endif - -struct fetcher_val : public fetcher { - fetcher_val(double val) - : value(val) {} - double at(long) const { return value; } - double value; -}; - -std::unique_ptr make_fetcher(python::object o) { - python::extract get_double(o); - if (get_double.check()) - return std::unique_ptr(new fetcher_val(get_double())); #ifdef HAVE_NUMPY - return std::unique_ptr(new fetcher_seq(o)); + np::ndarray a = python::extract(o); + if (a.get_nd() != 1) + throw std::invalid_argument("array must be 1 dimensional"); + if (a.get_dtype() != np::dtype::get_builtin()) + throw std::invalid_argument("array dtype must be double"); + array = reinterpret_cast(a.get_data()); + n = a.shape(0); + return; #endif - throw std::invalid_argument("python object is neither sequence nor number"); -} + throw std::invalid_argument("argument must be a number"); + } + double get(long i) const noexcept { + if (n > 0) + return array[i]; + return value; + } +}; python::object histogram_fill(python::tuple args, python::dict kwargs) { const auto nargs = python::len(args); @@ -236,35 +232,35 @@ python::object histogram_fill(python::tuple args, python::dict kwargs) { python::throw_error_already_set(); } - std::unique_ptr fetch[BOOST_HISTOGRAM_AXIS_LIMIT]; + fetcher fetch[BOOST_HISTOGRAM_AXIS_LIMIT]; long n = 0; for (auto d = 0u; d < dim; ++d) { - fetch[d] = make_fetcher(args[1 + d]); - const auto on = fetch[d]->n; - if (on > 0) { - if (n && on != n) { + fetch[d].assign(args[1 + d]); + if (fetch[d].n > 0) { + if (n && fetch[d].n != n) { PyErr_SetString(PyExc_ValueError, "lengths of sequences do not match"); python::throw_error_already_set(); } - n = on; + n = fetch[d].n; } } - std::unique_ptr fetch_weight; + fetcher fetch_weight; + bool use_weight = false; const auto nkwargs = python::len(kwargs); if (nkwargs > 0) { if (nkwargs > 1 || !kwargs.has_key("weight")) { PyErr_SetString(PyExc_RuntimeError, "only keyword weight allowed"); python::throw_error_already_set(); } - fetch_weight = make_fetcher(kwargs.get("weight")); - const auto on = fetch_weight->n; - if (on > 0) { - if (n && on != n) { + use_weight = true; + fetch_weight.assign(kwargs.get("weight")); + if (fetch_weight.n > 0) { + if (n && fetch_weight.n != n) { PyErr_SetString(PyExc_ValueError, "length of weight sequence does not match"); python::throw_error_already_set(); } - n = on; + n = fetch_weight.n; } } @@ -272,9 +268,9 @@ python::object histogram_fill(python::tuple args, python::dict kwargs) { if (!n) ++n; for (auto i = 0l; i < n; ++i) { for (auto d = 0u; d < dim; ++d) - v[d] = fetch[d]->at(i); - if (fetch_weight) { - self.fill(v, v + dim, weight(fetch_weight->at(i))); + v[d] = fetch[d].get(i); + if (use_weight) { + self.fill(v, v + dim, weight(fetch_weight.get(i))); } else { self.fill(v, v + dim); } diff --git a/test/python_suite_test.py b/test/python_suite_test.py index d8137351..60b8ed82 100644 --- a/test/python_suite_test.py +++ b/test/python_suite_test.py @@ -739,11 +739,9 @@ class test_histogram(unittest.TestCase): @unittest.skipUnless(have_numpy, "requires build with numpy-support") def test_fill_with_numpy_array_0(self): - ar = lambda *args: numpy.array(args) + ar = lambda *args: numpy.array(args, dtype=float) a = histogram(integer(0, 3, uoflow=False)) - a.fill(ar(-1, 0, 1, 2, 1, 4)) - a.fill((-1, 0)) - a.fill([1, 2]) + a.fill(ar(-1, 0, 1, 2, 1, 4, -1, 0, 1, 2)) self.assertEqual(a.value(0), 2) self.assertEqual(a.value(1), 3) self.assertEqual(a.value(2), 2) @@ -752,7 +750,7 @@ class test_histogram(unittest.TestCase): a.fill(numpy.empty((2, 2))) with self.assertRaises(ValueError): a.fill(numpy.empty(2), 1) - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): a.fill("abc") a = histogram(integer(0, 2, uoflow=False), @@ -765,11 +763,10 @@ class test_histogram(unittest.TestCase): self.assertEqual(a.value(1, 1), 0) with self.assertRaises(ValueError): - a.fill((1, 2, 3)) + a.fill(ar(1, 2, 3)) a = histogram(integer(0, 3, uoflow=False)) - a.fill([0, 0, 1, 2]) - a.fill((1, 0, 2, 2)) + a.fill(ar(0, 0, 1, 2, 1, 0, 2, 2)) self.assertEqual(a.value(0), 3) self.assertEqual(a.value(1), 2) self.assertEqual(a.value(2), 3) @@ -777,12 +774,12 @@ class test_histogram(unittest.TestCase): @unittest.skipUnless(have_numpy, "requires build with numpy-support") def test_fill_with_numpy_array_1(self): - ar = lambda *args: numpy.array(args) + ar = lambda *args: numpy.array(args, dtype=float) a = histogram(integer(0, 3, uoflow=True)) - v = numpy.array([-1, 0, 1, 2, 3, 4]) - w = numpy.array([ 2, 3, 4, 5, 6, 7]) + v = ar(-1, 0, 1, 2, 3, 4) + w = ar( 2, 3, 4, 5, 6, 7) a.fill(v, weight=w) - a.fill([0, 1], weight=[2, 3]) + a.fill(ar(0, 1), weight=ar(2, 3)) self.assertEqual(a.value(-1), 2) self.assertEqual(a.value(0), 5) self.assertEqual(a.value(1), 7) @@ -791,22 +788,22 @@ class test_histogram(unittest.TestCase): self.assertEqual(a.variance(0), 13) self.assertEqual(a.variance(1), 25) self.assertEqual(a.variance(2), 25) - a.fill([1, 2], weight=1) - a.fill(0, weight=[1, 2]) + a.fill(ar(1, 2), weight=1) + a.fill(0, weight=ar(1, 2)) self.assertEqual(a.value(0), 8) self.assertEqual(a.value(1), 8) self.assertEqual(a.value(2), 6) with self.assertRaises(RuntimeError): - a.fill([1, 2], foo=[1, 1]) + a.fill(ar(1, 2), foo=ar(1, 1)) with self.assertRaises(ValueError): - a.fill([1, 2], weight=[1]) - with self.assertRaises(ValueError): - a.fill([1, 2], weight="ab") + a.fill(ar(1, 2), weight=ar(1)) + with self.assertRaises(TypeError): + a.fill(ar(1, 2), weight="ab") with self.assertRaises(RuntimeError): - a.fill([1, 2], weight=[1, 1], foo=1) + a.fill(ar(1, 2), weight=ar(1, 1), foo=1) with self.assertRaises(ValueError): - a.fill([1, 2], weight=[[1, 1], [2, 2]]) + a.fill(ar(1, 2), weight=ar([1, 1], [2, 2])) a = histogram(integer(0, 2, uoflow=False), regular(2, 0, 2, uoflow=False)) @@ -816,8 +813,7 @@ class test_histogram(unittest.TestCase): self.assertEqual(a.value(1, 0), 1) self.assertEqual(a.value(1, 1), 0) a = histogram(integer(0, 3, uoflow=False)) - a.fill([0, 0, 1, 2]) - a.fill((1, 0, 2, 2)) + a.fill(ar(0, 0, 1, 2, 1, 0, 2, 2)) self.assertEqual(a.value(0), 3) self.assertEqual(a.value(1), 2) self.assertEqual(a.value(2), 3)