2
0
mirror of https://github.com/boostorg/compute.git synced 2026-01-26 06:22:37 +00:00
Files
compute/test/test_blas_gemm.cpp
2013-06-24 09:57:22 +02:00

230 lines
6.5 KiB
C++

//---------------------------------------------------------------------------//
// Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
//
// Distributed under the Boost Software License, Version 1.0
// See accompanying file LICENSE_1_0.txt or copy at
// http://www.boost.org/LICENSE_1_0.txt
//
// See http://kylelutz.github.com/compute for more information.
//---------------------------------------------------------------------------//
#define BOOST_TEST_MODULE TestBlasGemm
#include <boost/test/unit_test.hpp>
#include <boost/compute/malloc.hpp>
#include <boost/compute/algorithm/copy.hpp>
#include <boost/compute/algorithm/fill.hpp>
#include <boost/compute/blas/gemm.hpp>
#include "context_setup.hpp"
BOOST_AUTO_TEST_CASE(gemm_float3x3)
{
float a[9] = { 1.0f, 2.0f, 3.0f,
4.0f, 5.0f, 6.0f,
7.0f, 8.0f, 9.0f };
float b[9] = { 10.0f, 12.0f, 14.0f,
24.0f, 25.0f, 26.0f,
37.0f, 38.0f, 39.0f };
float c[9] = { 0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 0.0f };
boost::compute::device_ptr<float> A =
boost::compute::malloc<float>(9, context);
boost::compute::device_ptr<float> B =
boost::compute::malloc<float>(9, context);
boost::compute::device_ptr<float> C =
boost::compute::malloc<float>(9, context);
boost::compute::copy(a, a + 9, A, queue);
boost::compute::copy(b, b + 9, B, queue);
boost::compute::copy(c, c + 9, C, queue);
// C = A * A
boost::compute::blas::gemm(
boost::compute::blas::row_major,
boost::compute::blas::no_transpose,
boost::compute::blas::no_transpose,
3, 3, 3,
1.0f,
A, 3,
A, 3,
0.0f,
C, 3,
queue
);
boost::compute::copy(C, C + 9, c, queue);
BOOST_CHECK_CLOSE(c[0], 30.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[1], 36.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[2], 42.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[3], 66.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[4], 81.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[5], 96.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[6], 102.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[7], 126.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[8], 150.0f, 1e-4f);
// C = A * B
boost::compute::blas::gemm(
boost::compute::blas::row_major,
boost::compute::blas::no_transpose,
boost::compute::blas::no_transpose,
3, 3, 3,
1.0f,
A, 3,
B, 3,
0.0f,
C, 3,
queue
);
boost::compute::copy(C, C + 9, c, queue);
BOOST_CHECK_CLOSE(c[0], 169.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[1], 176.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[2], 183.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[3], 382.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[4], 401.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[5], 420.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[6], 595.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[7], 626.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[8], 657.0f, 1e-4f);
}
BOOST_AUTO_TEST_CASE(gemm_float2x3)
{
float a[6] = { 2.0f, 1.0f,
4.0f, 3.0f,
6.0f, 5.0f };
float b[6] = { 1.0f, 2.0f, 3.0f,
4.0f, 5.0f, 6.0f };
float c[9] = { 0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 0.0f };
boost::compute::device_ptr<float> A =
boost::compute::malloc<float>(6, context);
boost::compute::device_ptr<float> B =
boost::compute::malloc<float>(6, context);
boost::compute::device_ptr<float> C =
boost::compute::malloc<float>(9, context);
boost::compute::copy(a, a + 6, A, queue);
boost::compute::copy(b, b + 6, B, queue);
boost::compute::copy(c, c + 9, C, queue);
// C = A * B
boost::compute::blas::gemm(
boost::compute::blas::row_major,
boost::compute::blas::no_transpose,
boost::compute::blas::no_transpose,
3, 3, 2,
1.0f,
A, 2,
B, 3,
0.0f,
C, 3,
queue
);
boost::compute::copy(C, C + 9, c, queue);
BOOST_CHECK_CLOSE(c[0], 6.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[1], 9.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[2], 12.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[3], 16.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[4], 23.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[5], 30.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[6], 26.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[7], 37.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[8], 48.0f, 1e-4f);
// C = B * A
boost::compute::blas::gemm(
boost::compute::blas::row_major,
boost::compute::blas::no_transpose,
boost::compute::blas::no_transpose,
2, 2, 3,
1.0f,
B, 3,
A, 2,
0.0f,
C, 2,
queue
);
boost::compute::copy(C, C + 4, c, queue);
BOOST_CHECK_CLOSE(c[0], 28.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[1], 22.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[2], 64.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[3], 49.0f, 1e-4f);
// C = B * A (with alpha = 2)
boost::compute::blas::gemm(
boost::compute::blas::row_major,
boost::compute::blas::no_transpose,
boost::compute::blas::no_transpose,
2, 2, 3,
2.0f,
B, 3,
A, 2,
0.0f,
C, 2,
queue
);
boost::compute::copy(C, C + 4, c, queue);
BOOST_CHECK_CLOSE(c[0], 56.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[1], 44.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[2], 128.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[3], 98.0f, 1e-4f);
// fill C with 4's
boost::compute::fill(C, C + 4, 4.0f, queue);
// C = B * A (with beta = 3)
boost::compute::blas::gemm(
boost::compute::blas::row_major,
boost::compute::blas::no_transpose,
boost::compute::blas::no_transpose,
2, 2, 3,
1.0f,
B, 3,
A, 2,
3.0f,
C, 2,
queue
);
boost::compute::copy(C, C + 4, c, queue);
BOOST_CHECK_CLOSE(c[0], 40.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[1], 34.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[2], 76.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[3], 61.0f, 1e-4f);
// fill C with 3's
boost::compute::fill(C, C + 4, 3.0f, queue);
// C = B * A (with alpha = 3, beta = 2)
boost::compute::blas::gemm(
boost::compute::blas::row_major,
boost::compute::blas::no_transpose,
boost::compute::blas::no_transpose,
2, 2, 3,
3.0f,
B, 3,
A, 2,
2.0f,
C, 2,
queue
);
boost::compute::copy(C, C + 4, c, queue);
BOOST_CHECK_CLOSE(c[0], 90.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[1], 72.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[2], 198.0f, 1e-4f);
BOOST_CHECK_CLOSE(c[3], 153.0f, 1e-4f);
}
BOOST_AUTO_TEST_SUITE_END()