a little better

This commit is contained in:
hans.dembinski@gmail.com
2017-03-27 15:35:22 +02:00
parent f38085daa9
commit 10df0cc8e1

View File

@@ -290,27 +290,30 @@ struct storage_access {
template <typename T>
using array = adaptive_storage<>::array<T>;
struct dtype_visitor : public static_visitor<std::pair<unsigned, python::object>> {
struct dtype_visitor : public static_visitor<std::pair<int, python::object>> {
template <typename Array>
std::pair<unsigned, python::object> operator()(const Array& /*unused*/) const {
std::pair<unsigned, python::object> p;
std::pair<int, python::object> operator()(const Array& /*unused*/) const {
std::pair<int, python::object> p;
p.first = sizeof(typename Array::value_type);
p.second = python::str("|u") + python::str(p.first);
return p;
}
std::pair<unsigned, python::object> operator()(const array<void>& /*unused*/) const {
return std::pair<unsigned, python::object>();
std::pair<int, python::object> operator()(const array<void>& /*unused*/) const {
std::pair<int, python::object> p;
p.first = 0; // communicate that the type was array<void>
return p;
}
std::pair<unsigned, python::object> operator()(const array<mp_int>& /*unused*/) const {
std::pair<unsigned, python::object> p;
std::pair<int, python::object> operator()(const array<mp_int>& /*unused*/) const {
std::pair<int, python::object> p;
p.first = sizeof(double);
p.second = python::str("|f") + python::str(p.first);
return p;
}
std::pair<unsigned, python::object> operator()(const array<weight>& /*unused*/) const {
std::pair<unsigned, python::object> p;
std::pair<int, python::object> operator()(const array<weight>& /*unused*/) const {
std::pair<int, python::object> p;
p.first = sizeof(double);
p.second = python::str("|f") + python::str(p.first);
p.first *= -1; // communicate that the type was array<weight>
return p;
}
};
@@ -324,7 +327,7 @@ struct storage_access {
return python::make_tuple(reinterpret_cast<uintptr_t>(b.begin()), false);
}
python::object operator()(const array<void>& /*unused*/) const {
return python::object();
return python::object(); // is never called
}
python::object operator()(const array<mp_int>& b) const {
// cannot pass cpp_int to numpy; make new
@@ -339,24 +342,30 @@ struct storage_access {
PyArray_STRIDES((PyArrayObject *)ptr)[i] = python::extract<npy_intp>(strides[i]);
}
auto *buf = (double *)PyArray_DATA((PyArrayObject *)ptr);
for (int i = 0; i < b.size; ++i)
for (int i = 0; i < b.size; ++i) {
buf[i] = static_cast<double>(b[i]);
}
return python::object(python::handle<>(ptr));
}
};
static python::object array_interface(dynamic_histogram<> &self) {
auto &b = self.storage_.buffer_;
if (auto* pa = get<array<void>>(&b)) {
// buffer not created yet, do that now
b = array<uint8_t>(pa->size);
}
python::dict d;
python::list shapes;
python::list strides;
auto dtype = apply_visitor(dtype_visitor(), b);
auto stride = dtype.first;
if (get<array<weight>>(&b)) {
if (stride == 0) {
// buffer not created yet, do that now
auto a = array<uint8_t>(self.storage_.size());
dtype = dtype_visitor()(a);
b = std::move(a);
stride = dtype.first;
} else
if (stride < 0) {
// buffer is weight, needs special treatment
stride *= -1;
strides.append(stride);
stride *= 2;
shapes.append(2);