mirror of
https://github.com/boostorg/histogram.git
synced 2026-01-30 07:52:11 +00:00
support count keyword on the python side
This commit is contained in:
@@ -179,32 +179,33 @@ python::object histogram_init(python::tuple args, python::dict kwargs) {
|
||||
return pyinit(h);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct fetcher {
|
||||
long n = 0;
|
||||
long n = -1;
|
||||
union {
|
||||
double value = 0;
|
||||
const double* carray;
|
||||
T value = 0;
|
||||
const T* carray;
|
||||
};
|
||||
python::object keep_alive;
|
||||
|
||||
void assign(python::object o) {
|
||||
// skipping check for currently held type, since it is always value
|
||||
python::extract<double> get_double(o);
|
||||
if (get_double.check()) {
|
||||
value = get_double();
|
||||
python::extract<T> get_value(o);
|
||||
if (get_value.check()) {
|
||||
value = get_value();
|
||||
n = 0;
|
||||
return;
|
||||
}
|
||||
#ifdef HAVE_NUMPY
|
||||
np::ndarray a = np::from_object(o, np::dtype::get_builtin<double>(), 1);
|
||||
carray = reinterpret_cast<const double*>(a.get_data());
|
||||
np::ndarray a = np::from_object(o, np::dtype::get_builtin<T>(), 1);
|
||||
carray = reinterpret_cast<const T*>(a.get_data());
|
||||
n = a.shape(0);
|
||||
keep_alive = a; // this may be a temporary object
|
||||
return;
|
||||
#endif
|
||||
throw std::invalid_argument("argument must be a number");
|
||||
}
|
||||
double get(long i) const noexcept {
|
||||
T get(long i) const noexcept {
|
||||
if (n > 0)
|
||||
return carray[i];
|
||||
return value;
|
||||
@@ -228,7 +229,7 @@ python::object histogram_fill(python::tuple args, python::dict kwargs) {
|
||||
python::throw_error_already_set();
|
||||
}
|
||||
|
||||
fetcher fetch[BOOST_HISTOGRAM_AXIS_LIMIT];
|
||||
fetcher<double> fetch[BOOST_HISTOGRAM_AXIS_LIMIT];
|
||||
long n = 0;
|
||||
for (auto d = 0u; d < dim; ++d) {
|
||||
fetch[d].assign(args[1 + d]);
|
||||
@@ -241,22 +242,37 @@ python::object histogram_fill(python::tuple args, python::dict kwargs) {
|
||||
}
|
||||
}
|
||||
|
||||
fetcher fetch_weight;
|
||||
bool use_weight = false;
|
||||
fetcher<double> fetch_weight;
|
||||
fetcher<unsigned> fetch_count;
|
||||
const auto nkwargs = python::len(kwargs);
|
||||
if (nkwargs > 0) {
|
||||
if (nkwargs > 1 || !kwargs.has_key("weight")) {
|
||||
PyErr_SetString(PyExc_RuntimeError, "only keyword weight allowed");
|
||||
const bool use_weight = kwargs.has_key("weight");
|
||||
const bool use_count = kwargs.has_key("count");
|
||||
if (nkwargs > 1 || (use_weight == use_count)) { // may not be both true or false
|
||||
PyErr_SetString(PyExc_RuntimeError, "only keyword weight or count allowed");
|
||||
python::throw_error_already_set();
|
||||
}
|
||||
use_weight = true;
|
||||
fetch_weight.assign(kwargs.get("weight"));
|
||||
if (fetch_weight.n > 0) {
|
||||
if (n > 0 && fetch_weight.n != n) {
|
||||
PyErr_SetString(PyExc_ValueError, "length of weight sequence does not match");
|
||||
python::throw_error_already_set();
|
||||
|
||||
if (use_weight) {
|
||||
fetch_weight.assign(kwargs.get("weight"));
|
||||
if (fetch_weight.n > 0) {
|
||||
if (n > 0 && fetch_weight.n != n) {
|
||||
PyErr_SetString(PyExc_ValueError, "length of weight sequence does not match");
|
||||
python::throw_error_already_set();
|
||||
}
|
||||
n = fetch_weight.n;
|
||||
}
|
||||
}
|
||||
|
||||
if (use_count) {
|
||||
fetch_count.assign(kwargs.get("count"));
|
||||
if (fetch_count.n > 0) {
|
||||
if (n > 0 && fetch_count.n != n) {
|
||||
PyErr_SetString(PyExc_ValueError, "length of count sequence does not match");
|
||||
python::throw_error_already_set();
|
||||
}
|
||||
n = fetch_count.n;
|
||||
}
|
||||
n = fetch_weight.n;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -265,11 +281,12 @@ python::object histogram_fill(python::tuple args, python::dict kwargs) {
|
||||
for (auto i = 0l; i < n; ++i) {
|
||||
for (auto d = 0u; d < dim; ++d)
|
||||
v[d] = fetch[d].get(i);
|
||||
if (use_weight) {
|
||||
if (fetch_weight.n >= 0)
|
||||
self.fill(v, v + dim, weight(fetch_weight.get(i)));
|
||||
} else {
|
||||
else if (fetch_count.n >= 0)
|
||||
self.fill(v, v + dim, count(fetch_count.get(i)));
|
||||
else
|
||||
self.fill(v, v + dim);
|
||||
}
|
||||
}
|
||||
|
||||
return python::object();
|
||||
|
||||
Reference in New Issue
Block a user