2
0
mirror of https://github.com/boostorg/hana.git synced 2026-02-17 13:52:10 +00:00
Files
hana/test/sandbox/matrix.cpp
2014-06-23 15:37:29 -04:00

352 lines
9.8 KiB
C++

/*
@copyright Louis Dionne 2014
Distributed under the Boost Software License, Version 1.0.
(See accompanying file LICENSE.md or copy at http://boost.org/LICENSE_1_0.txt)
*/
#include <boost/hana/comparable.hpp>
#include <boost/hana/core.hpp>
#include <boost/hana/detail/static_assert.hpp>
#include <boost/hana/functional.hpp>
#include <boost/hana/functor.hpp>
#include <boost/hana/integral.hpp>
#include <boost/hana/list.hpp>
#include <boost/hana/range.hpp>
#include <boost/hana/type.hpp>
using namespace boost::hana;
// Tag used as the hana_datatype of every matrix_type specialization.
struct Matrix;

// A matrix represented by its rows: `rows_` is a sequence of rows, each row
// being itself a sequence of entries.
// NOTE(review): the defaulted `operators::enable` parameter presumably opts
// this type into Hana's operator overloads -- confirm against the library.
template <typename Storage, typename = operators::enable>
struct matrix_type {
    using hana_datatype = Matrix;

    Storage rows_;

    // Number of rows: the length of the row storage.
    constexpr auto nrows() const {
        return length(rows_);
    }

    // Number of columns: the length of the first row.
    constexpr auto ncolumns() const {
        return length(head(rows_));
    }

    // Total number of entries (rows times columns).
    constexpr auto size() const {
        return nrows() * ncolumns();
    }

    // Entry located at row `i` and column `j`.
    template <typename I, typename J>
    constexpr auto at(I i, J j) const {
        return boost::hana::at(
            j,
            boost::hana::at(i, rows_)
        );
    }
};
auto transpose = [](auto m) {
auto new_storage = unpack(zip, m.rows_);
return matrix_type<decltype(new_storage)>{new_storage};
};
// Returns a copy of the row storage of the matrix `m`.
auto rows = [](auto m) {
    auto row_storage = m.rows_;
    return row_storage;
};
// Returns the columns of `m`, computed as the rows of its transpose.
auto columns = [](auto m) {
    auto transposed = transpose(m);
    return rows(transposed);
};
// Lifts `f` to matrices: the returned function accepts any number of matrices
// and produces a matrix whose rows come from zipping the corresponding rows
// of the arguments with `f`.
auto elementwise = [](auto f) {
    return [=](auto ...matrices) {
        // Applies `f` across the corresponding entries of a group of rows.
        auto combine_rows = partial(zip_with, f);
        auto combined = zip_with(combine_rows, matrices.rows_...);
        return matrix_type<decltype(combined)>{combined};
    };
};
// Entry-by-entry sum of two matrices.
template <typename S1, typename S2>
constexpr auto operator+(matrix_type<S1> m1, matrix_type<S2> m2) {
    return elementwise(_ + _)(m1, m2);
}
// Entry-by-entry difference of two matrices.
template <typename S1, typename S2>
constexpr auto operator-(matrix_type<S1> m1, matrix_type<S2> m2) {
    return elementwise(_ - _)(m1, m2);
}
// Scalar product of two sequences: the sum of their pairwise products.
auto scalar_prod = [](auto v1, auto v2) {
    auto products = zip_with(_ * _, v1, v2);
    return sum(products);
};
// Builds a list holding `n` copies of `x`: every index in [0, n) is replaced
// by `x` and the results are collected with `list`.
auto repeat_n = [](auto n, auto x) {
    auto indices = range(int_<0>, n);
    auto copies_of_x = on(list, always(x));
    return unpack(copies_of_x, indices);
};
// Matrix product: each row of `m1` is combined with every column of `m2`
// through `scalar_prod`.
template <typename S1, typename S2>
constexpr auto operator*(matrix_type<S1> m1, matrix_type<S2> m2) {
    // Row of the product obtained from a single row of the left operand.
    // The row is repeated once per column of `m2` so that it can be zipped
    // against those columns.
    auto product_row = [=](auto lhs_row) {
        return zip_with(
            scalar_prod,
            repeat_n(m2.ncolumns(), lhs_row),
            columns(m2)
        );
    };
    auto product_rows = fmap(product_row, rows(m1));
    return matrix_type<decltype(product_rows)>{product_rows};
}
// Creates a matrix row out of the given entries (a plain `list`).
auto row = [](auto ...xs) {
    auto entries = list(xs...);
    return entries;
};
// Creates a matrix from the given rows.
//
// Every row after the first one is checked against the length of the first
// row; a mismatch triggers the static_assert below.
auto matrix = [](auto ...rows) {
    auto storage = list(rows...);
    // The lambda parameter is named `r` so it does not hide the global `row`.
    auto all_same_length = all([=](auto r) {
        return length(r) == length(head(storage));
    }, tail(storage));
    static_assert(all_same_length,
        "matrix(): all rows must have the same number of entries");
    return matrix_type<decltype(storage)>{storage};
};
// Creates a column vector: each entry is wrapped in its own single-entry row
// and the resulting rows are assembled into a matrix.
auto vector = [](auto ...entries) {
    return matrix(row(entries)...);
};
// Returns `x` raised to the power `n`, with exponent(x, 0) == 1.
constexpr int exponent(int x, unsigned int n) {
    int result = 1;
    for (unsigned int remaining = n; remaining != 0; --remaining) {
        result *= x;
    }
    return result;
}
// Returns `xs` without the element at index `n`.
//
// Each element is paired with its index, the pair whose index equals `n` is
// filtered out, and the indices are dropped again.
auto remove_at = [](auto n, auto xs) {
    auto indices = range(int_<0>, length(xs));
    auto indexed = zip(xs, indices);
    // The index is the last component of each (element, index) pair.
    auto is_not_nth = compose(n != _, last);
    auto kept = filter(is_not_nth, indexed);
    return fmap(head, kept);
};
// Forward declaration: matrix_minor needs _det, while the definition of _det
// relies on cofactor, which in turn calls matrix_minor. Declaring _det first
// breaks that cycle.
template <typename>
struct _det;
auto matrix_minor = [](auto m, auto i, auto j) {
auto submatrix_storage = fmap(partial(remove_at, j), remove_at(i, rows(m)));
matrix_type<decltype(submatrix_storage)> submatrix{submatrix_storage};
return _det<decltype(submatrix)>{}(submatrix);
};
auto cofactor = [](auto m, auto i, auto j) {
auto i_plus_j = i + j;
return int_<exponent(-1, i_plus_j())> * matrix_minor(m, i, j);
};
// Function object computing the determinant of a matrix of type `M`.
//
// A matrix with a single entry yields that entry. Otherwise the determinant
// is the scalar product of the first row with the cofactors of that row.
template <typename M>
struct _det {
    constexpr auto operator()(M m) const {
        return eval_if(m.size() == int_<1>,
            always(m.at(int_<0>, int_<0>)),
            // NOTE(review): `range` is called through the branch parameter
            // `_`, presumably to delay instantiation of this branch until it
            // is selected -- confirm against eval_if's contract.
            [=](auto _) {
                auto column_indices = _(range)(int_<0>, m.ncolumns());
                auto first_row_cofactor = partial(cofactor, m, int_<0>);
                auto first_row_cofactors = unpack(
                    on(list, first_row_cofactor),
                    column_indices
                );
                return scalar_prod(head(rows(m)), first_row_cofactors);
            }
        );
    }
};
auto det = [](auto m) {
return _det<decltype(m)>{}(m);
};
namespace boost { namespace hana {
    // Functor instance for matrices: `f` is mapped over every entry by
    // mapping it over each row of the storage.
    template <>
    struct Functor::instance<Matrix> : Functor::fmap_mcd {
        template <typename F, typename M>
        static constexpr auto fmap_impl(F f, M mat) {
            auto mapped_rows = fmap(partial(fmap, f), mat.rows_);
            return matrix_type<decltype(mapped_rows)>{mapped_rows};
        }
    };

    // Comparable instance for matrices: two matrices are equal when they
    // have the same number of rows, the same number of columns, and all of
    // their corresponding rows compare equal.
    template <>
    struct Comparable::instance<Matrix, Matrix> : Comparable::equal_mcd {
        template <typename M1, typename M2>
        static constexpr auto equal_impl(M1 m1, M2 m2) {
            return m1.nrows() == m2.nrows()
                && m1.ncolumns() == m2.ncolumns()
                && all_of(zip_with(_ == _, m1.rows_, m2.rows_));
        }
    };
}}
// Checks size(), ncolumns() and nrows() on a 2x3 matrix with mixed entries.
void test_sizes() {
    auto two_by_three = matrix(
        row(1, '2', 3),
        row('4', char_<'5'>, 6)
    );

    BOOST_HANA_STATIC_ASSERT(two_by_three.size() == 6);
    BOOST_HANA_STATIC_ASSERT(two_by_three.ncolumns() == 3);
    BOOST_HANA_STATIC_ASSERT(two_by_three.nrows() == 2);
}
// Checks at(i, j) for every position of a 3x3 matrix with mixed entries.
void test_at() {
    auto grid = matrix(
        row(1, '2', 3),
        row('4', char_<'5'>, 6),
        row(int_<7>, '8', 9.3)
    );

    // First row.
    BOOST_HANA_STATIC_ASSERT(grid.at(int_<0>, int_<0>) == 1);
    BOOST_HANA_STATIC_ASSERT(grid.at(int_<0>, int_<1>) == '2');
    BOOST_HANA_STATIC_ASSERT(grid.at(int_<0>, int_<2>) == 3);

    // Second row.
    BOOST_HANA_STATIC_ASSERT(grid.at(int_<1>, int_<0>) == '4');
    BOOST_HANA_STATIC_ASSERT(grid.at(int_<1>, int_<1>) == char_<'5'>);
    BOOST_HANA_STATIC_ASSERT(grid.at(int_<1>, int_<2>) == 6);

    // Third row.
    BOOST_HANA_STATIC_ASSERT(grid.at(int_<2>, int_<0>) == int_<7>);
    BOOST_HANA_STATIC_ASSERT(grid.at(int_<2>, int_<1>) == '8');
    BOOST_HANA_STATIC_ASSERT(grid.at(int_<2>, int_<2>) == 9.3);
}
// Exercises == and != between matrices of identical and of different shapes.
void test_comparable() {
    // Single-row matrices.
    BOOST_HANA_STATIC_ASSERT(
        matrix(row(1, 2)) == matrix(row(1, 2))
    );
    BOOST_HANA_STATIC_ASSERT(
        matrix(row(1, 2)) != matrix(row(1, 5))
    );

    // 2x2 matrices differing in at most one entry.
    BOOST_HANA_STATIC_ASSERT(
        matrix(row(1, 2), row(3, 4)) == matrix(row(1, 2), row(3, 4))
    );
    BOOST_HANA_STATIC_ASSERT(
        matrix(row(1, 2), row(3, 4)) != matrix(row(1, 2), row(0, 4))
    );
    BOOST_HANA_STATIC_ASSERT(
        matrix(row(1, 2), row(3, 4)) != matrix(row(0, 2), row(3, 4))
    );

    // Matrices whose dimensions differ.
    BOOST_HANA_STATIC_ASSERT(
        matrix(row(1), row(2)) != matrix(row(3, 4), row(5, 6))
    );
    BOOST_HANA_STATIC_ASSERT(
        matrix(row(1), row(2)) != matrix(row(3, 4))
    );
}
// Checks that fmap applies a function to every entry of a matrix: adding
// int_<1> to each entry of a 3x3 matrix mixing plain ints and int_ constants.
void test_functor() {
    auto original = matrix(
        row(1, int_<2>, 3),
        row(int_<4>, 5, 6),
        row(7, 8, int_<9>)
    );

    BOOST_HANA_STATIC_ASSERT(
        fmap(_ + int_<1>, original) == matrix(
            row(2, int_<3>, 4),
            row(int_<5>, 6, 7),
            row(8, 9, int_<10>)
        )
    );
}
// Checks the entry-by-entry + and - operators on a 2x2 matrix.
void test_operators() {
    auto square = matrix(row(1, 2), row(3, 4));

    BOOST_HANA_STATIC_ASSERT(
        square + square == matrix(row(2, 4), row(6, 8))
    );
    BOOST_HANA_STATIC_ASSERT(
        square - square == matrix(row(0, 0), row(0, 0))
    );
}
// Checks the product of a 2x3 matrix with a 3x2 matrix.
//
// Each expected entry (i, j) is spelled out as the sum over k of
// A(i, k) * B(k, j), with the factors in that order, so the expectation can
// be read directly against the operands.
void test_matrix_multiplication() {
    auto A = matrix(
        row(1, 2, 3),
        row(4, 5, 6)
    );
    auto B = matrix(
        row(1, 2),
        row(3, 4),
        row(5, 6)
    );

    BOOST_HANA_STATIC_ASSERT(A * B == matrix(
        row(1*1 + 2*3 + 3*5, 1*2 + 2*4 + 3*6),
        row(4*1 + 5*3 + 6*5, 4*2 + 5*4 + 6*6)
    ));
}
// Checks that vector() builds a single-column matrix with one row per entry.
void test_vector() {
    auto column = vector(1, '2', int_<3>, 4.2f);

    BOOST_HANA_STATIC_ASSERT(column.size() == 4);
    BOOST_HANA_STATIC_ASSERT(column.nrows() == 4);
    BOOST_HANA_STATIC_ASSERT(column.ncolumns() == 1);
}
// Checks that transposing a 2x3 matrix yields the expected 3x2 matrix.
void test_transpose() {
    auto two_by_three = matrix(
        row(1, 2.2, '3'),
        row(4, '5', 6)
    );

    BOOST_HANA_STATIC_ASSERT(
        transpose(two_by_three) == matrix(
            row(1, 4),
            row(2.2, '5'),
            row('3', 6)
        )
    );
}
// Checks repeat_n for 0 to 4 repetitions of a type object.
void test_repeat_n() {
    struct Tag;

    // Zero repetitions give the empty list.
    BOOST_HANA_STATIC_ASSERT(repeat_n(int_<0>, type<Tag>) == list());

    // n repetitions give a list of n identical types.
    BOOST_HANA_STATIC_ASSERT(
        repeat_n(int_<1>, type<Tag>) == list_t<Tag>
    );
    BOOST_HANA_STATIC_ASSERT(
        repeat_n(int_<2>, type<Tag>) == list_t<Tag, Tag>
    );
    BOOST_HANA_STATIC_ASSERT(
        repeat_n(int_<3>, type<Tag>) == list_t<Tag, Tag, Tag>
    );
    BOOST_HANA_STATIC_ASSERT(
        repeat_n(int_<4>, type<Tag>) == list_t<Tag, Tag, Tag, Tag>
    );
}
// Checks det() on square matrices of size 1 through 5.
void test_determinant() {
    // 1x1: the determinant is the single entry.
    BOOST_HANA_STATIC_ASSERT(det(matrix(row(1))) == 1);
    BOOST_HANA_STATIC_ASSERT(det(matrix(row(2))) == 2);

    // 2x2: 1*4 - 2*3.
    BOOST_HANA_STATIC_ASSERT(
        det(matrix(row(1, 2), row(3, 4))) == -2
    );

    // 3x3.
    auto three_by_three = matrix(
        row(1, 5, 6),
        row(3, 2, 4),
        row(7, 8, 9)
    );
    BOOST_HANA_STATIC_ASSERT(det(three_by_three) == 51);

    // 4x4.
    auto four_by_four = matrix(
        row(1, 5, 6, -3),
        row(3, 2, 4, -5),
        row(7, 8, 9, -1),
        row(8, 2, 1, 10)
    );
    BOOST_HANA_STATIC_ASSERT(det(four_by_four) == 214);

    // 5x5.
    auto five_by_five = matrix(
        row(1,  5,  6, -3, 92),
        row(3,  2,  4, -5, 13),
        row(7,  8,  9, -1,  0),
        row(8,  2,  1, 10, 41),
        row(3, 12, 92, -7, -4)
    );
    BOOST_HANA_STATIC_ASSERT(det(five_by_five) == -3115014);
}
// Checks remove_at for every valid index of lists of one, two and three
// elements.
void test_remove_at() {
    // Removing the first element.
    BOOST_HANA_STATIC_ASSERT(
        remove_at(int_<0>, list(1)) == list()
    );
    BOOST_HANA_STATIC_ASSERT(
        remove_at(int_<0>, list(1, '2')) == list('2')
    );
    BOOST_HANA_STATIC_ASSERT(
        remove_at(int_<0>, list(1, '2', 3.3)) == list('2', 3.3)
    );

    // Removing the second element.
    BOOST_HANA_STATIC_ASSERT(
        remove_at(int_<1>, list(1, '2')) == list(1)
    );
    BOOST_HANA_STATIC_ASSERT(
        remove_at(int_<1>, list(1, '2', 3.3)) == list(1, 3.3)
    );

    // Removing the third element.
    BOOST_HANA_STATIC_ASSERT(
        remove_at(int_<2>, list(1, '2', 3.3)) == list(1, '2')
    );
}
// Checks exponent() for the powers 0 through 4 of the base 3.
void test_exponent() {
    BOOST_HANA_STATIC_ASSERT(exponent(3, 0) == 1);
    BOOST_HANA_STATIC_ASSERT(exponent(3, 1) == 3);
    BOOST_HANA_STATIC_ASSERT(exponent(3, 2) == 9);
    BOOST_HANA_STATIC_ASSERT(exponent(3, 3) == 27);
    BOOST_HANA_STATIC_ASSERT(exponent(3, 4) == 81);
}
// Runs every test of this file: helper utilities first, then the matrix
// interface, then the operations built on top of it.
int main() {
    // Standalone helpers.
    test_repeat_n();
    test_remove_at();
    test_exponent();

    // Basic matrix interface and type class instances.
    test_sizes();
    test_at();
    test_comparable();
    test_functor();

    // Operations on matrices.
    test_operators();
    test_vector();
    test_transpose();
    test_matrix_multiplication();
    test_determinant();

    return 0;
}