From 03c8436f4c77ec2760fe5ca86301d0907031dc18 Mon Sep 17 00:00:00 2001 From: Hans Dembinski Date: Sat, 4 Nov 2017 00:32:22 +0100 Subject: [PATCH] fix --- include/boost/histogram/axis_ostream_operators.hpp | 1 - src/python/axis.cpp | 13 +++++++++---- test/python_suite_test.py | 14 +++++++++++++- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/include/boost/histogram/axis_ostream_operators.hpp b/include/boost/histogram/axis_ostream_operators.hpp index dd5e0753..96d8ca67 100644 --- a/include/boost/histogram/axis_ostream_operators.hpp +++ b/include/boost/histogram/axis_ostream_operators.hpp @@ -16,7 +16,6 @@ namespace boost { namespace histogram { - namespace axis { namespace detail { diff --git a/src/python/axis.cpp b/src/python/axis.cpp index 6dcedd73..96cdb69a 100644 --- a/src/python/axis.cpp +++ b/src/python/axis.cpp @@ -207,11 +207,16 @@ python::object make_regular(unsigned bin, double lower, double upper, return object(axis::regular( bin, lower, upper, label, uoflow)); else if (trans.substr(0, 3) == "pow") { - const double val = lexical_cast(trans.substr(4, trans.size()-1)); - return object(axis::regular( - bin, lower, upper, label, uoflow, axis::transform::pow(val))); + try { + const double val = lexical_cast(trans.substr(4, trans.size()-5)); + return object(axis::regular( + bin, lower, upper, label, uoflow, axis::transform::pow(val))); + } catch (...) { + PyErr_SetString(PyExc_ValueError, "pow argument not recognized"); + throw_error_already_set(); + } } - PyErr_SetString(PyExc_KeyError, "transform signature not recognized"); + PyErr_SetString(PyExc_ValueError, "transform signature not recognized"); throw_error_already_set(); return object(); } diff --git a/test/python_suite_test.py b/test/python_suite_test.py index 60b8ed82..64f83958 100644 --- a/test/python_suite_test.py +++ b/test/python_suite_test.py @@ -29,6 +29,11 @@ class test_regular(unittest.TestCase): regular(1, 1.0, 2.0, label="ra") regular(1, 1.0, 2.0, uoflow=False) regular(1, 1.0, 2.0, label="ra", uoflow=False) + regular(1, 1.0, 2.0, trans="log") + regular(1, 1.0, 2.0, trans="sqrt") + regular(1, 0.5, 1.0, trans="cos") + regular(1, 1.0, 2.0, trans="pow(1.5)") + regular(1, 1.0, 2.0, trans="pow[2]") with self.assertRaises(TypeError): regular() with self.assertRaises(TypeError): @@ -49,6 +54,12 @@ class test_regular(unittest.TestCase): regular(1, 1.0, 2.0, label="ra", uoflow="True") with self.assertRaises(TypeError): regular(1, 1.0, 2.0, bad_keyword="ra") + with self.assertRaises(ValueError): + regular(1, 1.0, 2.0, trans="bla") + with self.assertRaises(ValueError): + regular(1, 1.0, 2.0, trans="pow") + with self.assertRaises(ValueError): + regular(1, 1.0, 2.0, trans="pow()") a = regular(4, 1.0, 2.0) self.assertEqual(a, regular(4, 1.0, 2.0)) self.assertNotEqual(a, regular(3, 1.0, 2.0)) @@ -63,7 +74,8 @@ class test_regular(unittest.TestCase): for s in ("regular(4, 1.1, 2.2)", "regular(4, 1.1, 2.2, label='ra')", "regular(4, 1.1, 2.2, uoflow=False)", - "regular(4, 1.1, 2.2, label='ra', uoflow=False)"): + "regular(4, 1.1, 2.2, label='ra', uoflow=False)", + "regular(4, 1.1, 2.2, trans='log')"): self.assertEqual(str(eval(s)), s) def test_getitem(self):