support count keyword on the python side

This commit is contained in:
Hans Dembinski
2017-11-09 12:10:16 +01:00
parent b40a68bb54
commit 0923ef3be2
4 changed files with 66 additions and 34 deletions

View File

@@ -179,32 +179,33 @@ python::object histogram_init(python::tuple args, python::dict kwargs) {
return pyinit(h);
}
template <typename T>
struct fetcher {
long n = 0;
long n = -1;
union {
double value = 0;
const double* carray;
T value = 0;
const T* carray;
};
python::object keep_alive;
void assign(python::object o) {
// skipping check for currently held type, since it is always value
python::extract<double> get_double(o);
if (get_double.check()) {
value = get_double();
python::extract<T> get_value(o);
if (get_value.check()) {
value = get_value();
n = 0;
return;
}
#ifdef HAVE_NUMPY
np::ndarray a = np::from_object(o, np::dtype::get_builtin<double>(), 1);
carray = reinterpret_cast<const double*>(a.get_data());
np::ndarray a = np::from_object(o, np::dtype::get_builtin<T>(), 1);
carray = reinterpret_cast<const T*>(a.get_data());
n = a.shape(0);
keep_alive = a; // this may be a temporary object
return;
#endif
throw std::invalid_argument("argument must be a number");
}
double get(long i) const noexcept {
T get(long i) const noexcept {
if (n > 0)
return carray[i];
return value;
@@ -228,7 +229,7 @@ python::object histogram_fill(python::tuple args, python::dict kwargs) {
python::throw_error_already_set();
}
fetcher fetch[BOOST_HISTOGRAM_AXIS_LIMIT];
fetcher<double> fetch[BOOST_HISTOGRAM_AXIS_LIMIT];
long n = 0;
for (auto d = 0u; d < dim; ++d) {
fetch[d].assign(args[1 + d]);
@@ -241,22 +242,37 @@ python::object histogram_fill(python::tuple args, python::dict kwargs) {
}
}
fetcher fetch_weight;
bool use_weight = false;
fetcher<double> fetch_weight;
fetcher<unsigned> fetch_count;
const auto nkwargs = python::len(kwargs);
if (nkwargs > 0) {
if (nkwargs > 1 || !kwargs.has_key("weight")) {
PyErr_SetString(PyExc_RuntimeError, "only keyword weight allowed");
const bool use_weight = kwargs.has_key("weight");
const bool use_count = kwargs.has_key("count");
if (nkwargs > 1 || (use_weight == use_count)) { // may not be both true or false
PyErr_SetString(PyExc_RuntimeError, "only keyword weight or count allowed");
python::throw_error_already_set();
}
use_weight = true;
fetch_weight.assign(kwargs.get("weight"));
if (fetch_weight.n > 0) {
if (n > 0 && fetch_weight.n != n) {
PyErr_SetString(PyExc_ValueError, "length of weight sequence does not match");
python::throw_error_already_set();
if (use_weight) {
fetch_weight.assign(kwargs.get("weight"));
if (fetch_weight.n > 0) {
if (n > 0 && fetch_weight.n != n) {
PyErr_SetString(PyExc_ValueError, "length of weight sequence does not match");
python::throw_error_already_set();
}
n = fetch_weight.n;
}
}
if (use_count) {
fetch_count.assign(kwargs.get("count"));
if (fetch_count.n > 0) {
if (n > 0 && fetch_count.n != n) {
PyErr_SetString(PyExc_ValueError, "length of count sequence does not match");
python::throw_error_already_set();
}
n = fetch_count.n;
}
n = fetch_weight.n;
}
}
@@ -265,11 +281,12 @@ python::object histogram_fill(python::tuple args, python::dict kwargs) {
for (auto i = 0l; i < n; ++i) {
for (auto d = 0u; d < dim; ++d)
v[d] = fetch[d].get(i);
if (use_weight) {
if (fetch_weight.n >= 0)
self.fill(v, v + dim, weight(fetch_weight.get(i)));
} else {
else if (fetch_count.n >= 0)
self.fill(v, v + dim, count(fetch_count.get(i)));
else
self.fill(v, v + dim);
}
}
return python::object();