This commit is contained in:
Hans Dembinski
2018-11-01 23:44:17 +01:00
parent 107fdd6c0f
commit cb85d65489
3 changed files with 32 additions and 62 deletions

View File

@@ -156,8 +156,8 @@ public:
[&x](const auto& a) {
using A = detail::unqual<decltype(a)>;
using expected_t = axis::traits::arg<A>;
return detail::static_if<std::is_convertible<detail::unqual<U>, expected_t>>(
[&x](const auto& a) -> int { return a(static_cast<expected_t>(x)); },
return detail::static_if<std::is_convertible<U, expected_t>>(
[&x](const auto& a) -> int { return a(x); },
[](const auto&) -> int {
throw std::invalid_argument(detail::cat(
"cannot convert ", boost::core::demangled_name(BOOST_CORE_TYPEID(U)),

View File

@@ -305,47 +305,12 @@ inline void linearize(optional_index& out, const int axis_size, const int axis_s
out.stride *= (j < axis_shape) * axis_shape; // set to 0, if j is invalid
}
template <std::size_t D, typename Axes>
void indices_to_index(optional_index&, const Axes&) noexcept {}
template <std::size_t D, typename Axes, typename... Us>
void indices_to_index(optional_index& idx, const Axes& axes, const int j,
const Us... us) {
const auto& a = axis_get<D>(axes);
const auto a_size = static_cast<int>(a.size());
const auto a_shape = static_cast<int>(axis::traits::extend(a));
idx.stride *= (-1 <= j && j <= a_size); // set to 0, if j is invalid
linearize(idx, a_size, a_shape, j);
indices_to_index<(D + 1)>(idx, axes, us...);
}
template <typename... Ts, typename Iterator>
void indices_to_index_iter(mp11::mp_size_t<0>, optional_index&, const std::tuple<Ts...>&,
Iterator) {}
template <std::size_t N, typename... Ts, typename Iterator>
void indices_to_index_iter(mp11::mp_size_t<N>, optional_index& idx,
const std::tuple<Ts...>& axes, Iterator iter) {
constexpr auto D = mp11::mp_size_t<sizeof...(Ts)>() - N;
const auto& a = std::get<D>(axes);
const auto a_size = static_cast<int>(a.size());
const auto a_shape = axis::traits::extend(a);
const auto j = static_cast<int>(*iter);
idx.stride *= (-1 <= j && j <= a_size); // set to 0, if j is invalid
linearize(idx, a_size, a_shape, j);
indices_to_index_iter(mp11::mp_size_t<(N - 1)>(), idx, axes, ++iter);
}
template <typename... Ts, typename Iterator>
void indices_to_index_iter(optional_index& idx, const std::vector<Ts...>& axes,
Iterator iter) {
for (const auto& a : axes) {
const auto a_size = static_cast<int>(a.size());
const auto a_shape = axis::traits::extend(a);
const auto j = static_cast<int>(*iter++);
idx.stride *= (-1 <= j && j <= a_size); // set to 0, if j is invalid
linearize(idx, a_size, a_shape, j);
}
template <typename T>
void linearize2(optional_index& out, const T& axis, int j) {
const auto a_size = static_cast<int>(axis.size());
const auto a_shape = axis::traits::extend(axis);
out.stride *= (-1 <= j && j <= a_size); // set to 0, if j is invalid
linearize(out, a_size, a_shape, j);
}
template <std::size_t D, typename... Ts>
@@ -501,37 +466,40 @@ optional_index call_impl(iterable_container_tag, const std::vector<Ts...>& axes,
* the exception and do something sensible.
*/
template <typename A, typename... Us>
optional_index at_impl(no_container_tag, const A& axes, const Us&... us) {
dimension_check(axes, mp11::mp_size_t<sizeof...(Us)>());
auto index = detail::optional_index();
detail::indices_to_index<0>(index, axes, static_cast<int>(us)...);
return index;
template <typename A, typename U>
optional_index at_impl(no_container_tag, const A& axes, const U& u) {
return at_impl(static_container_tag(), axes, std::forward_as_tuple(u));
}
template <typename A, typename U>
optional_index at_impl(static_container_tag, const A& axes, const U& u) {
return mp11::tuple_apply(
[&axes](const auto&... us) { return at_impl(no_container_tag(), axes, us...); }, u);
dimension_check(axes, mp11::mp_size<unqual<U>>());
detail::optional_index idx;
mp11::mp_for_each<mp11::mp_iota<mp_size<U>>>([&](auto I) {
linearize2(idx, axis_get<decltype(I)::value>(axes), static_cast<int>(std::get<I>(u)));
});
return idx;
}
template <typename... Ts, typename U>
optional_index at_impl(iterable_container_tag, const std::tuple<Ts...>& axes,
const U& u) {
dimension_check(axes, std::distance(std::begin(u), std::end(u)));
auto index = detail::optional_index();
detail::indices_to_index_iter(mp11::mp_size_t<sizeof...(Ts)>(), index, axes,
std::begin(u));
return index;
detail::optional_index idx;
auto xit = std::begin(u);
mp11::mp_for_each<mp11::mp_iota_c<sizeof...(Ts)>>(
[&](auto I) { linearize2(idx, std::get<I>(axes), static_cast<int>(*xit++)); });
return idx;
}
template <typename... Ts, typename U>
optional_index at_impl(iterable_container_tag, const std::vector<Ts...>& axes,
const U& u) {
dimension_check(axes, std::distance(std::begin(u), std::end(u)));
auto index = detail::optional_index();
detail::indices_to_index_iter(index, axes, std::begin(u));
return index;
detail::optional_index idx;
auto xit = std::begin(u);
for (const auto& a : axes) linearize2(idx, a, static_cast<int>(*xit++));
return idx;
}
} // namespace detail

View File

@@ -148,7 +148,8 @@ public:
template <typename... Ts>
void operator()(const Ts&... ts) {
// case with one argument needs special treatment, specialized below
const auto index = detail::call_impl(detail::no_container_tag(), axes_, ts...);
const auto index = detail::call_impl(detail::static_container_tag(), axes_,
std::forward_as_tuple(ts...));
if (index) storage_(*index);
}
@@ -163,7 +164,8 @@ public:
template <typename U, typename... Ts>
void operator()(const weight_type<U>& w, const Ts&... ts) {
// case with one argument needs special treatment, specialized below
const auto index = detail::call_impl(detail::no_container_tag(), axes_, ts...);
const auto index = detail::call_impl(detail::static_container_tag(), axes_,
std::forward_as_tuple(ts...));
if (index) storage_(*index, w);
}
@@ -178,8 +180,8 @@ public:
template <typename... Ts>
const_reference at(const Ts&... ts) const {
// case with one argument is ambiguous, is specialized below
const auto index =
detail::at_impl(detail::no_container_tag(), axes_, static_cast<int>(ts)...);
const auto index = detail::at_impl(detail::static_container_tag(), axes_,
std::forward_as_tuple(static_cast<int>(ts)...));
if (!index) throw std::out_of_range("indices out of bounds");
return storage_[*index];
}