simpler, restrict numpy arrays to type double

This commit is contained in:
Hans Dembinski
2017-11-03 18:35:01 +01:00
parent b2a2cc123d
commit 7717852168
3 changed files with 61 additions and 69 deletions

View File

@@ -183,41 +183,37 @@ python::object histogram_init(python::tuple args, python::dict kwargs) {
}
struct fetcher {
virtual ~fetcher() {}
virtual double at(long) const = 0;
long n = 0;
};
#ifdef HAVE_NUMPY
struct fetcher_seq : public fetcher {
fetcher_seq(python::object o)
: array(np::from_object(o, np::dtype::get_builtin<double>(), 1)) {
fetcher::n = array.shape(0);
union {
double value = 0;
const double* array;
};
void assign(python::object o) {
// skipping check for currently held type, since it is always value
python::extract<double> get_double(o);
if (get_double.check()) {
value = get_double();
n = 0;
return;
}
~fetcher_seq() {}
double at(long i) const {
return reinterpret_cast<const double*>(array.get_data())[i];
}
np::ndarray array;
};
#endif
struct fetcher_val : public fetcher {
fetcher_val(double val)
: value(val) {}
double at(long) const { return value; }
double value;
};
std::unique_ptr<fetcher> make_fetcher(python::object o) {
python::extract<double> get_double(o);
if (get_double.check())
return std::unique_ptr<fetcher>(new fetcher_val(get_double()));
#ifdef HAVE_NUMPY
return std::unique_ptr<fetcher>(new fetcher_seq(o));
np::ndarray a = python::extract<np::ndarray>(o);
if (a.get_nd() != 1)
throw std::invalid_argument("array must be 1 dimensional");
if (a.get_dtype() != np::dtype::get_builtin<double>())
throw std::invalid_argument("array dtype must be double");
array = reinterpret_cast<const double*>(a.get_data());
n = a.shape(0);
return;
#endif
throw std::invalid_argument("python object is neither sequence nor number");
}
throw std::invalid_argument("argument must be a number");
}
double get(long i) const noexcept {
if (n > 0)
return array[i];
return value;
}
};
python::object histogram_fill(python::tuple args, python::dict kwargs) {
const auto nargs = python::len(args);
@@ -236,35 +232,35 @@ python::object histogram_fill(python::tuple args, python::dict kwargs) {
python::throw_error_already_set();
}
std::unique_ptr<fetcher> fetch[BOOST_HISTOGRAM_AXIS_LIMIT];
fetcher fetch[BOOST_HISTOGRAM_AXIS_LIMIT];
long n = 0;
for (auto d = 0u; d < dim; ++d) {
fetch[d] = make_fetcher(args[1 + d]);
const auto on = fetch[d]->n;
if (on > 0) {
if (n && on != n) {
fetch[d].assign(args[1 + d]);
if (fetch[d].n > 0) {
if (n && fetch[d].n != n) {
PyErr_SetString(PyExc_ValueError, "lengths of sequences do not match");
python::throw_error_already_set();
}
n = on;
n = fetch[d].n;
}
}
std::unique_ptr<fetcher> fetch_weight;
fetcher fetch_weight;
bool use_weight = false;
const auto nkwargs = python::len(kwargs);
if (nkwargs > 0) {
if (nkwargs > 1 || !kwargs.has_key("weight")) {
PyErr_SetString(PyExc_RuntimeError, "only keyword weight allowed");
python::throw_error_already_set();
}
fetch_weight = make_fetcher(kwargs.get("weight"));
const auto on = fetch_weight->n;
if (on > 0) {
if (n && on != n) {
use_weight = true;
fetch_weight.assign(kwargs.get("weight"));
if (fetch_weight.n > 0) {
if (n && fetch_weight.n != n) {
PyErr_SetString(PyExc_ValueError, "length of weight sequence does not match");
python::throw_error_already_set();
}
n = on;
n = fetch_weight.n;
}
}
@@ -272,9 +268,9 @@ python::object histogram_fill(python::tuple args, python::dict kwargs) {
if (!n) ++n;
for (auto i = 0l; i < n; ++i) {
for (auto d = 0u; d < dim; ++d)
v[d] = fetch[d]->at(i);
if (fetch_weight) {
self.fill(v, v + dim, weight(fetch_weight->at(i)));
v[d] = fetch[d].get(i);
if (use_weight) {
self.fill(v, v + dim, weight(fetch_weight.get(i)));
} else {
self.fill(v, v + dim);
}