mirror of
https://github.com/boostorg/histogram.git
synced 2026-01-30 20:02:13 +00:00
simpler, restrict numpy arrays to type double
This commit is contained in:
@@ -183,41 +183,37 @@ python::object histogram_init(python::tuple args, python::dict kwargs) {
|
||||
}
|
||||
|
||||
struct fetcher {
|
||||
virtual ~fetcher() {}
|
||||
virtual double at(long) const = 0;
|
||||
long n = 0;
|
||||
};
|
||||
|
||||
#ifdef HAVE_NUMPY
|
||||
struct fetcher_seq : public fetcher {
|
||||
fetcher_seq(python::object o)
|
||||
: array(np::from_object(o, np::dtype::get_builtin<double>(), 1)) {
|
||||
fetcher::n = array.shape(0);
|
||||
union {
|
||||
double value = 0;
|
||||
const double* array;
|
||||
};
|
||||
void assign(python::object o) {
|
||||
// skipping check for currently held type, since it is always value
|
||||
python::extract<double> get_double(o);
|
||||
if (get_double.check()) {
|
||||
value = get_double();
|
||||
n = 0;
|
||||
return;
|
||||
}
|
||||
~fetcher_seq() {}
|
||||
double at(long i) const {
|
||||
return reinterpret_cast<const double*>(array.get_data())[i];
|
||||
}
|
||||
np::ndarray array;
|
||||
};
|
||||
#endif
|
||||
|
||||
struct fetcher_val : public fetcher {
|
||||
fetcher_val(double val)
|
||||
: value(val) {}
|
||||
double at(long) const { return value; }
|
||||
double value;
|
||||
};
|
||||
|
||||
std::unique_ptr<fetcher> make_fetcher(python::object o) {
|
||||
python::extract<double> get_double(o);
|
||||
if (get_double.check())
|
||||
return std::unique_ptr<fetcher>(new fetcher_val(get_double()));
|
||||
#ifdef HAVE_NUMPY
|
||||
return std::unique_ptr<fetcher>(new fetcher_seq(o));
|
||||
np::ndarray a = python::extract<np::ndarray>(o);
|
||||
if (a.get_nd() != 1)
|
||||
throw std::invalid_argument("array must be 1 dimensional");
|
||||
if (a.get_dtype() != np::dtype::get_builtin<double>())
|
||||
throw std::invalid_argument("array dtype must be double");
|
||||
array = reinterpret_cast<const double*>(a.get_data());
|
||||
n = a.shape(0);
|
||||
return;
|
||||
#endif
|
||||
throw std::invalid_argument("python object is neither sequence nor number");
|
||||
}
|
||||
throw std::invalid_argument("argument must be a number");
|
||||
}
|
||||
double get(long i) const noexcept {
|
||||
if (n > 0)
|
||||
return array[i];
|
||||
return value;
|
||||
}
|
||||
};
|
||||
|
||||
python::object histogram_fill(python::tuple args, python::dict kwargs) {
|
||||
const auto nargs = python::len(args);
|
||||
@@ -236,35 +232,35 @@ python::object histogram_fill(python::tuple args, python::dict kwargs) {
|
||||
python::throw_error_already_set();
|
||||
}
|
||||
|
||||
std::unique_ptr<fetcher> fetch[BOOST_HISTOGRAM_AXIS_LIMIT];
|
||||
fetcher fetch[BOOST_HISTOGRAM_AXIS_LIMIT];
|
||||
long n = 0;
|
||||
for (auto d = 0u; d < dim; ++d) {
|
||||
fetch[d] = make_fetcher(args[1 + d]);
|
||||
const auto on = fetch[d]->n;
|
||||
if (on > 0) {
|
||||
if (n && on != n) {
|
||||
fetch[d].assign(args[1 + d]);
|
||||
if (fetch[d].n > 0) {
|
||||
if (n && fetch[d].n != n) {
|
||||
PyErr_SetString(PyExc_ValueError, "lengths of sequences do not match");
|
||||
python::throw_error_already_set();
|
||||
}
|
||||
n = on;
|
||||
n = fetch[d].n;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<fetcher> fetch_weight;
|
||||
fetcher fetch_weight;
|
||||
bool use_weight = false;
|
||||
const auto nkwargs = python::len(kwargs);
|
||||
if (nkwargs > 0) {
|
||||
if (nkwargs > 1 || !kwargs.has_key("weight")) {
|
||||
PyErr_SetString(PyExc_RuntimeError, "only keyword weight allowed");
|
||||
python::throw_error_already_set();
|
||||
}
|
||||
fetch_weight = make_fetcher(kwargs.get("weight"));
|
||||
const auto on = fetch_weight->n;
|
||||
if (on > 0) {
|
||||
if (n && on != n) {
|
||||
use_weight = true;
|
||||
fetch_weight.assign(kwargs.get("weight"));
|
||||
if (fetch_weight.n > 0) {
|
||||
if (n && fetch_weight.n != n) {
|
||||
PyErr_SetString(PyExc_ValueError, "length of weight sequence does not match");
|
||||
python::throw_error_already_set();
|
||||
}
|
||||
n = on;
|
||||
n = fetch_weight.n;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -272,9 +268,9 @@ python::object histogram_fill(python::tuple args, python::dict kwargs) {
|
||||
if (!n) ++n;
|
||||
for (auto i = 0l; i < n; ++i) {
|
||||
for (auto d = 0u; d < dim; ++d)
|
||||
v[d] = fetch[d]->at(i);
|
||||
if (fetch_weight) {
|
||||
self.fill(v, v + dim, weight(fetch_weight->at(i)));
|
||||
v[d] = fetch[d].get(i);
|
||||
if (use_weight) {
|
||||
self.fill(v, v + dim, weight(fetch_weight.get(i)));
|
||||
} else {
|
||||
self.fill(v, v + dim);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user