From c1612f683891c42ad086607120d8c634e40193bf Mon Sep 17 00:00:00 2001 From: Arnav Sharma Date: Mon, 28 Aug 2023 14:09:10 +0530 Subject: [PATCH] Gtestsuite Framework and Unit Tests for Pack and Compute Extension APIs - Added framework for unit testing of BLAS and CBLAS interfaces for the Pack and Compute Extension APIs. - These test the integrated functionality of the trio of ?gemm_pack_get_size(), ?gemm_pack() and ?gemm_compute() APIs. - Note: Only MKL can be used as reference for now. AMD-Internal: [CPUPL-3560] Change-Id: I801654447a716da06c9ccf9db01d553817871571 --- .../inc/level3/ref_gemm_compute.h | 69 +++ .../src/level3/ref_gemm_compute.cpp | 200 ++++++++ .../testsuite/level3/gemm/dgemm_generic.cpp | 2 +- .../gemm_compute/dgemm_compute_generic.cpp | 187 +++++++ .../level3/gemm_compute/gemm_compute.h | 456 ++++++++++++++++++ .../gemm_compute/gemm_compute_IIT_ERS.cpp | 222 +++++++++ .../gemm_compute/sgemm_compute_generic.cpp | 189 ++++++++ .../level3/gemm_compute/test_gemm_compute.h | 79 +++ 8 files changed, 1403 insertions(+), 1 deletion(-) create mode 100644 gtestsuite/testinghelpers/inc/level3/ref_gemm_compute.h create mode 100644 gtestsuite/testinghelpers/src/level3/ref_gemm_compute.cpp create mode 100644 gtestsuite/testsuite/level3/gemm_compute/dgemm_compute_generic.cpp create mode 100644 gtestsuite/testsuite/level3/gemm_compute/gemm_compute.h create mode 100644 gtestsuite/testsuite/level3/gemm_compute/gemm_compute_IIT_ERS.cpp create mode 100644 gtestsuite/testsuite/level3/gemm_compute/sgemm_compute_generic.cpp create mode 100644 gtestsuite/testsuite/level3/gemm_compute/test_gemm_compute.h diff --git a/gtestsuite/testinghelpers/inc/level3/ref_gemm_compute.h b/gtestsuite/testinghelpers/inc/level3/ref_gemm_compute.h new file mode 100644 index 0000000000..283a2b06ec --- /dev/null +++ b/gtestsuite/testinghelpers/inc/level3/ref_gemm_compute.h @@ -0,0 +1,69 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include "common/testing_helpers.h" + +/* + * ========================================================================== + * GEMM Compute performs one of the matrix-matrix operations + * C := op( A )*op( B ) + beta*C, + * where op( A ) is one of + * op( A ) = alpha * A or op( A ) = alpha * A**T + * op( A ) = A or op( A ) = A**T + * op( B ) is one of + * op( B ) = alpha * B or op( B ) = alpha * B**T + * op( B ) = B or op( B ) = B**T + * alpha and beta are scalars, and A, B and C are matrices, with op( A ) + * an m by k matrix, op( B ) a k by n matrix and C an m by n matrix, + * where either op( A ) or op( B ) or both may be reordered. + ========================================================================== +*/ + +namespace testinghelpers { + +template +void ref_gemm_compute ( + char storage, char trnsa, char trnsb, + char pcka, char pckb, + gtint_t m, gtint_t n, gtint_t k, + T alpha, + T* ap, gtint_t lda, + T* bp, gtint_t ldb, + T beta, + T* cp, gtint_t ldc +); + +} //end of namespace testinghelpers diff --git a/gtestsuite/testinghelpers/src/level3/ref_gemm_compute.cpp b/gtestsuite/testinghelpers/src/level3/ref_gemm_compute.cpp new file mode 100644 index 0000000000..2b15ffea2b --- /dev/null +++ b/gtestsuite/testinghelpers/src/level3/ref_gemm_compute.cpp @@ -0,0 +1,200 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include +#include "level3/ref_gemm_compute.h" + +/* + * ========================================================================== + * GEMM Pack and Compute Extension performs the GEMM matrix-matrix operations + * by first packing/reordering A/B matrix and computing the GEMM operation + * on the packed buffer. + * + * Pack: + * Reorders the A or B matrix or both the matrices and scales them with + * alpha. + * + * Compute: + * C := A * B + beta*C, + * where, + * Either A or B or both A and B matrices are packed matrices. + * Alpha and beta are scalars, and A, B and C are matrices, with A + * an m by k matrix, B a k by n matrix and C an m by n matrix, + * where either A or B or both may be scaled by alpha and reordered. + * ========================================================================== + */ + +namespace testinghelpers { + +template +void ref_gemm_compute(char storage, char trnsa, char trnsb, char pcka, char pckb, gtint_t m, gtint_t n, gtint_t k, T alpha, + T* ap, gtint_t lda, T* bp, gtint_t ldb, T beta, T* cp, gtint_t ldc) +{ + T unit_alpha = 1.0; + enum CBLAS_ORDER cblas_order; + enum CBLAS_TRANSPOSE cblas_transa; + enum CBLAS_TRANSPOSE cblas_transb; + + char_to_cblas_order( storage, &cblas_order ); + char_to_cblas_trans( trnsa, &cblas_transa ); + char_to_cblas_trans( trnsb, &cblas_transb ); + + using scalar_t = std::conditional_t::is_complex, T&, T>; + + typedef gint_t (*Fptr_ref_cblas_gemm_pack_get_size)( const CBLAS_IDENTIFIER, + const f77_int, const f77_int, const f77_int ); + Fptr_ref_cblas_gemm_pack_get_size ref_cblas_gemm_pack_get_size; + + typedef void (*Fptr_ref_cblas_gemm_pack)( const CBLAS_ORDER, const CBLAS_IDENTIFIER, const CBLAS_TRANSPOSE, + const f77_int, const f77_int, const f77_int, const T, const T*, f77_int, + T*); + Fptr_ref_cblas_gemm_pack ref_cblas_gemm_pack; + + typedef void (*Fptr_ref_cblas_gemm_compute)( const CBLAS_ORDER, const f77_int, const f77_int, + const f77_int, const f77_int, const f77_int, const T*, f77_int, + const T*, f77_int, const scalar_t, T*, f77_int); + Fptr_ref_cblas_gemm_compute ref_cblas_gemm_compute; + + // Call C function + /* Check the typename T passed to this function template and call respective function.*/ + if (typeid(T) == typeid(float)) + { + ref_cblas_gemm_pack_get_size = (Fptr_ref_cblas_gemm_pack_get_size)refCBLASModule.loadSymbol("cblas_sgemm_pack_get_size"); + ref_cblas_gemm_pack = (Fptr_ref_cblas_gemm_pack)refCBLASModule.loadSymbol("cblas_sgemm_pack"); + ref_cblas_gemm_compute = (Fptr_ref_cblas_gemm_compute)refCBLASModule.loadSymbol("cblas_sgemm_compute"); + } + else if (typeid(T) == typeid(double)) + { + ref_cblas_gemm_pack_get_size = (Fptr_ref_cblas_gemm_pack_get_size)refCBLASModule.loadSymbol("cblas_dgemm_pack_get_size"); + ref_cblas_gemm_pack = (Fptr_ref_cblas_gemm_pack)refCBLASModule.loadSymbol("cblas_dgemm_pack"); + ref_cblas_gemm_compute = (Fptr_ref_cblas_gemm_compute)refCBLASModule.loadSymbol("cblas_dgemm_compute"); + } + else + { + throw std::runtime_error("Error in ref_gemm.cpp: Invalid typename is passed function template."); + } + if( !ref_cblas_gemm_compute ) { + throw std::runtime_error("Error in ref_gemm.cpp: Function pointer == 0 -- symbol not found."); + } + + err_t err = BLIS_SUCCESS; + + if ( ( pcka == 'P' || pcka == 'p' ) && ( pckb == 'P' || pckb == 'p' ) ) + { + // Reorder A + CBLAS_IDENTIFIER cblas_identifierA = CblasAMatrix; + CBLAS_STORAGE cblas_packed = CblasPacked; + gtint_t bufSizeA = ref_cblas_gemm_pack_get_size( cblas_identifierA, + m, + n, + k ); + + T* aBuffer = (T*) bli_malloc_user( bufSizeA, &err ); + + ref_cblas_gemm_pack( cblas_order, cblas_identifierA, cblas_transa, + m, n, k, alpha, ap, lda, aBuffer ); + + // Reorder B + CBLAS_IDENTIFIER cblas_identifierB = CblasBMatrix; + gtint_t bufSizeB = ref_cblas_gemm_pack_get_size( cblas_identifierB, + m, + n, + k ); + + T* bBuffer = (T*) bli_malloc_user( bufSizeB, &err ); + + ref_cblas_gemm_pack( cblas_order, cblas_identifierB, cblas_transb, + m, n, k, unit_alpha, bp, ldb, bBuffer ); + + ref_cblas_gemm_compute( cblas_order, cblas_packed, cblas_packed, + m, n, k, aBuffer, lda, bBuffer, ldb, beta, cp, ldc ); + + bli_free_user( aBuffer ); + bli_free_user( bBuffer ); + } + else if ( ( pcka == 'P' || pcka == 'p' ) ) + { + // Reorder A + CBLAS_IDENTIFIER cblas_identifier = CblasAMatrix; + CBLAS_STORAGE cblas_packed = CblasPacked; + gtint_t bufSizeA = ref_cblas_gemm_pack_get_size( cblas_identifier, + m, + n, + k ); + + T* aBuffer = (T*) bli_malloc_user( bufSizeA, &err ); + + ref_cblas_gemm_pack( cblas_order, cblas_identifier, cblas_transa, + m, n, k, alpha, ap, lda, aBuffer ); + + ref_cblas_gemm_compute( cblas_order, cblas_packed, cblas_transb, + m, n, k, aBuffer, lda, bp, ldb, beta, cp, ldc ); + + bli_free_user( aBuffer ); + } + else if ( ( pckb == 'P' || pckb == 'p' ) ) + { + // Reorder B + CBLAS_IDENTIFIER cblas_identifier = CblasBMatrix; + CBLAS_STORAGE cblas_packed = CblasPacked; + gtint_t bufSizeB = ref_cblas_gemm_pack_get_size( cblas_identifier, + m, + n, + k ); + + T* bBuffer = (T*) bli_malloc_user( bufSizeB, &err ); + + ref_cblas_gemm_pack( cblas_order, cblas_identifier, cblas_transb, + m, n, k, alpha, bp, ldb, bBuffer ); + + ref_cblas_gemm_compute( cblas_order, cblas_transa, cblas_packed, + m, n, k, ap, lda, bBuffer, ldb, beta, cp, ldc ); + + bli_free_user( bBuffer ); + } + else + { + ref_cblas_gemm_compute( cblas_order, cblas_transa, cblas_transb, + m, n, k, ap, lda, bp, ldb, beta, cp, ldc ); + } +} + +// Explicit template instantiations +template void ref_gemm_compute(char, char, char, char, char, gtint_t, gtint_t, gtint_t, float, + float*, gtint_t, float*, gtint_t, float, float*, gtint_t ); +template void ref_gemm_compute(char, char, char, char, char, gtint_t, gtint_t, gtint_t, double, + double*, gtint_t, double*, gtint_t, double, double*, gtint_t ); + +} //end of namespace testinghelpers diff --git a/gtestsuite/testsuite/level3/gemm/dgemm_generic.cpp b/gtestsuite/testsuite/level3/gemm/dgemm_generic.cpp index b74f63aea2..8d07668cc4 100644 --- a/gtestsuite/testsuite/level3/gemm/dgemm_generic.cpp +++ b/gtestsuite/testsuite/level3/gemm/dgemm_generic.cpp @@ -79,7 +79,7 @@ TEST_P(DGemmTest, RandomData) gtint_t ldc_inc = std::get<10>(GetParam()); // Set the threshold for the errors: - double thresh = 10*m*n*k*testinghelpers::getEpsilon(); + double thresh = 10*m*n*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters diff --git a/gtestsuite/testsuite/level3/gemm_compute/dgemm_compute_generic.cpp b/gtestsuite/testsuite/level3/gemm_compute/dgemm_compute_generic.cpp new file mode 100644 index 0000000000..82b89b7191 --- /dev/null +++ b/gtestsuite/testsuite/level3/gemm_compute/dgemm_compute_generic.cpp @@ -0,0 +1,187 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_gemm_compute.h" + +class DGemmComputeTest : + public ::testing::TestWithParam> {}; + +TEST_P(DGemmComputeTest, RandomData) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix a is n,c,t + char transa = std::get<1>(GetParam()); + // denotes whether matrix b is n,c,t + char transb = std::get<2>(GetParam()); + // denotes whether matrix a is packed (p) or unpacked (u) + char packa = std::get<3>(GetParam()); + // denotes whether matrix b is packed (p) or unpacked (u) + char packb = std::get<4>(GetParam()); + // matrix size m + gtint_t m = std::get<5>(GetParam()); + // matrix size n + gtint_t n = std::get<6>(GetParam()); + // matrix size k + gtint_t k = std::get<7>(GetParam()); + // specifies alpha value + T alpha = std::get<8>(GetParam()); + // specifies beta value + T beta = std::get<9>(GetParam()); + // lda, ldb, ldc increments. + // If increments are zero, then the array size matches the matrix size. + // If increments are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<10>(GetParam()); + gtint_t ldb_inc = std::get<11>(GetParam()); + gtint_t ldc_inc = std::get<12>(GetParam()); + + // Set the threshold for the errors: + double intermediate = (double)m*n*k; + double thresh = 10*intermediate*testinghelpers::getEpsilon(); + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_gemm_compute( storage, transa, transb, packa, packb, m, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); +} + +class DGemmComputeTestPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char sfm = std::get<0>(str.param); + char tsa = std::get<1>(str.param); + char tsb = std::get<2>(str.param); + char pka = std::get<3>(str.param); + char pkb = std::get<4>(str.param); + gtint_t m = std::get<5>(str.param); + gtint_t n = std::get<6>(str.param); + gtint_t k = std::get<7>(str.param); + double alpha = std::get<8>(str.param); + double beta = std::get<9>(str.param); + gtint_t lda_inc = std::get<10>(str.param); + gtint_t ldb_inc = std::get<11>(str.param); + gtint_t ldc_inc = std::get<12>(str.param); +#ifdef TEST_BLAS + std::string str_name = "dgemm_compute_"; +#elif TEST_CBLAS + std::string str_name = "cblas_dgemm_compute"; +#else //#elif TEST_BLIS_TYPED + // BLIS interface not yet implemented for pack and compute APIs. + std::string str_name = "blis_dgemm_compute"; +#endif + str_name = str_name + "_" + sfm+sfm+sfm; + str_name = str_name + "_" + tsa + tsb; + str_name = str_name + "_" + pka + pkb; + str_name = str_name + "_" + std::to_string(m); + str_name = str_name + "_" + std::to_string(n); + str_name = str_name + "_" + std::to_string(k); + std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); + str_name = str_name + "_a" + alpha_str; + std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); + str_name = str_name + "_b" + beta_str; + str_name = str_name + "_" + std::to_string(lda_inc); + str_name = str_name + "_" + std::to_string(ldb_inc); + str_name = str_name + "_" + std::to_string(ldc_inc); + return str_name; + } +}; + +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + Blackbox, + DGemmComputeTest, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS + ,'r' +#endif + ), // storage format + ::testing::Values('n', 't', 'c'), // transa + ::testing::Values('n', 't', 'c'), // transb + ::testing::Values('u', 'p'), // packa + ::testing::Values('u', 'p'), // packb + ::testing::Range(gtint_t(10), gtint_t(31), 10), // m + ::testing::Range(gtint_t(10), gtint_t(31), 10), // n + ::testing::Range(gtint_t(10), gtint_t(31), 10), // k + ::testing::Values(0.0, 1.0, -1.2, 2.1), // alpha + ::testing::Values(0.0, 1.0, -1.2, 2.1), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::DGemmComputeTestPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + DimensionsGtBlocksizes, // Dimensions > SUP Blocksizes + DGemmComputeTest, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS + ,'r' +#endif + ), // storage format + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values('u', 'p'), // packa + ::testing::Values('u', 'p'), // packb + ::testing::Values(71, 73), // m (MC - 1, MC + 1) + ::testing::Values(4079, 4081), // n (NC - 1, NC + 1) + ::testing::Values(255, 257), // k (KC - 1, KC + 1) + ::testing::Values(1.0), // alpha + ::testing::Values(1.0), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::DGemmComputeTestPrint() + ); \ No newline at end of file diff --git a/gtestsuite/testsuite/level3/gemm_compute/gemm_compute.h b/gtestsuite/testsuite/level3/gemm_compute/gemm_compute.h new file mode 100644 index 0000000000..b57691dfe3 --- /dev/null +++ b/gtestsuite/testsuite/level3/gemm_compute/gemm_compute.h @@ -0,0 +1,456 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include "blis.h" +#include "common/testing_helpers.h" + +/** + * @brief Performs the operation: + * C := op( A )*op( B ) + beta*C, + * where op( A ) is one of + * op( A ) = alpha * A or op( A ) = alpha * A**T + * op( A ) = A or op( A ) = A**T + * op( B ) is one of + * op( B ) = alpha * B or op( B ) = alpha * B**T + * op( B ) = B or op( B ) = B**T + * @param[in] transa specifies the form of op( A ) to be used in + the matrix multiplication. + * @param[in] transb specifies the form of op( B ) to be used in + the matrix multiplication. + * @param[in] packa specifies whether to reorder op( A ). + * @param[in] packb specifies whether to reorder op( B ). + * @param[in] m specifies the number of rows of the matrix + op( A ) and of the matrix C. + * @param[in] n specifies the number of columns of the matrix + op( B ) and the number of columns of the matrix C. + * @param[in] k specifies the number of columns of the matrix + op( A ) and the number of rows of the matrix op( B ). + * @param[in] ap specifies pointer which points to the first element of ap. + * @param[in] lda specifies the leading dimension of ap. + * @param[in] bp specifies pointer which points to the first element of bp. + * @param[in] ldb specifies the leading dimension of bp. + * @param[in] beta specifies the scalar beta. + * @param[in,out] cp specifies pointer which points to the first element of cp. + * @param[in] ldc specifies the leading dimension of cp. + */ + +template +static void gemm_compute_(char transa, char transb, char packa, char packb, gtint_t m, gtint_t n, gtint_t k, T* alpha, + T* ap, gtint_t lda, T* bp, gtint_t ldb, T* beta, T* cp, gtint_t ldc ) +{ + T unit_alpha = 1.0; + err_t err = BLIS_SUCCESS; + if constexpr (std::is_same::value) + { + if ( ( packa == 'P' || packa == 'p' ) && ( packb == 'P' || packb == 'p' ) ) + { + // Reorder A + char identifierA = 'A'; + gtint_t bufSizeA = sgemm_pack_get_size_( &identifierA, + &m, + &n, + &k ); + + float* aBuffer = (float*) bli_malloc_user( bufSizeA, &err ); + sgemm_pack_( &identifierA, + &transa, + &m, + &n, + &k, + &unit_alpha, + ap, + &lda, + aBuffer ); + + // Reorder B + char identifierB = 'B'; + gtint_t bufSizeB = sgemm_pack_get_size_( &identifierB, + &m, + &n, + &k ); + + float* bBuffer = (float*) bli_malloc_user( bufSizeB, &err ); + sgemm_pack_( &identifierB, + &transb, + &m, + &n, + &k, + alpha, + bp, + &ldb, + bBuffer ); + + sgemm_compute_( &packa, &packb, &m, &n, &k, aBuffer, &lda, bBuffer, &ldb, beta, cp, &ldc ); + + bli_free_user( aBuffer ); + bli_free_user( bBuffer ); + } + else if ( ( packa == 'P' || packa == 'p' ) ) + { + // Reorder A + char identifierA = 'A'; + gtint_t bufSizeA = sgemm_pack_get_size_( &identifierA, + &m, + &n, + &k ); + + float* aBuffer = (float*) bli_malloc_user( bufSizeA, &err ); + sgemm_pack_( &identifierA, + &transa, + &m, + &n, + &k, + alpha, + ap, + &lda, + aBuffer ); + + sgemm_compute_( &packa, &transb, &m, &n, &k, aBuffer, &lda, bp, &ldb, beta, cp, &ldc ); + bli_free_user( aBuffer ); + } + else if ( ( packb == 'P' || packb == 'p' ) ) + { + // Reorder B + char identifierB = 'B'; + gtint_t bufSizeB = sgemm_pack_get_size_( &identifierB, + &m, + &n, + &k ); + + float* bBuffer = (float*) bli_malloc_user( bufSizeB, &err ); + sgemm_pack_( &identifierB, + &transb, + &m, + &n, + &k, + alpha, + bp, + &ldb, + bBuffer ); + + sgemm_compute_( &transa, &packb, &m, &n, &k, ap, &lda, bBuffer, &ldb, beta, cp, &ldc ); + bli_free_user( bBuffer ); + } + else + { + sgemm_compute_( &transa, &transb, &m, &n, &k, ap, &lda, bp, &ldb, beta, cp, &ldc ); + } + } + else if constexpr (std::is_same::value) + { + if ( ( packa == 'P' || packa == 'p' ) && ( packb == 'P' || packb == 'p' ) ) + { + // Reorder A + char identifierA = 'A'; + gtint_t bufSizeA = dgemm_pack_get_size_( &identifierA, + &m, + &n, + &k ); + + double* aBuffer = (double*) bli_malloc_user( bufSizeA, &err ); + dgemm_pack_( &identifierA, + &transa, + &m, + &n, + &k, + &unit_alpha, + ap, + &lda, + aBuffer ); + + // Reorder B + char identifierB = 'B'; + gtint_t bufSizeB = dgemm_pack_get_size_( &identifierB, + &m, + &n, + &k ); + + double* bBuffer = (double*) bli_malloc_user( bufSizeB, &err ); + dgemm_pack_( &identifierB, + &transb, + &m, + &n, + &k, + alpha, + bp, + &ldb, + bBuffer ); + + dgemm_compute_( &packa, &packb, &m, &n, &k, aBuffer, &lda, bBuffer, &ldb, beta, cp, &ldc ); + bli_free_user( aBuffer ); + bli_free_user( bBuffer ); + } + else if ( ( packa == 'P' || packa == 'p' ) ) + { + // Reorder A + char identifierA = 'A'; + gtint_t bufSizeA = dgemm_pack_get_size_( &identifierA, + &m, + &n, + &k ); + + double* aBuffer = (double*) bli_malloc_user( bufSizeA, &err ); + dgemm_pack_( &identifierA, + &transa, + &m, + &n, + &k, + alpha, + ap, + &lda, + aBuffer ); + + dgemm_compute_( &packa, &transb, &m, &n, &k, aBuffer, &lda, bp, &ldb, beta, cp, &ldc ); + bli_free_user( aBuffer ); + } + else if ( ( packb == 'P' || packb == 'p' ) ) + { + // Reorder B + char identifierB = 'B'; + gtint_t bufSizeB = dgemm_pack_get_size_( &identifierB, + &m, + &n, + &k ); + + double* bBuffer = (double*) bli_malloc_user( bufSizeB, &err ); + dgemm_pack_( &identifierB, + &transb, + &m, + &n, + &k, + alpha, + bp, + &ldb, + bBuffer ); + + dgemm_compute_( &transa, &packb, &m, &n, &k, ap, &lda, bBuffer, &ldb, beta, cp, &ldc ); + bli_free_user( bBuffer ); + } + else + { + dgemm_compute_( &transa, &transb, &m, &n, &k, ap, &lda, bp, &ldb, beta, cp, &ldc ); + } + } + else + throw std::runtime_error("Error in testsuite/level3/gemm.h: Invalid typename in gemm_compute_()."); +} + +template +static void cblas_gemm_compute(char storage, char transa, char transb, char pcka, char pckb, + gtint_t m, gtint_t n, gtint_t k, T* alpha, T* ap, gtint_t lda, + T* bp, gtint_t ldb, T* beta, T* cp, gtint_t ldc) +{ + enum CBLAS_ORDER cblas_order; + enum CBLAS_TRANSPOSE cblas_transa; + enum CBLAS_TRANSPOSE cblas_transb; + + testinghelpers::char_to_cblas_order( storage, &cblas_order ); + testinghelpers::char_to_cblas_trans( transa, &cblas_transa ); + testinghelpers::char_to_cblas_trans( transb, &cblas_transb ); + + T unit_alpha = 1.0; + CBLAS_IDENTIFIER cblas_identifierA = CblasAMatrix; + CBLAS_IDENTIFIER cblas_identifierB = CblasBMatrix; + CBLAS_STORAGE cblas_packed = CblasPacked; + + err_t err = BLIS_SUCCESS; + + if constexpr (std::is_same::value) + { + if ( ( pcka == 'p' || pcka == 'P' ) && ( pckb == 'p' || pckb == 'P' ) ) + { + gtint_t bufSizeA = cblas_sgemm_pack_get_size( cblas_identifierA, + m, + n, + k ); + + T* aBuffer = (T*) bli_malloc_user( bufSizeA, &err ); + + cblas_sgemm_pack( cblas_order, cblas_identifierA, cblas_transa, + m, n, k, *alpha, ap, lda, aBuffer ); + + gtint_t bufSizeB = cblas_sgemm_pack_get_size( cblas_identifierB, + m, + n, + k ); + + T* bBuffer = (T*) bli_malloc_user( bufSizeB, &err ); + + cblas_sgemm_pack( cblas_order, cblas_identifierB, cblas_transb, + m, n, k, unit_alpha, bp, ldb, bBuffer ); + + cblas_sgemm_compute( cblas_order, cblas_packed, cblas_packed, + m, n, k, aBuffer, lda, bBuffer, ldb, *beta, cp, ldc ); + + bli_free_user( aBuffer ); + bli_free_user( bBuffer ); + } + else if ( pcka == 'p' || pcka == 'P' ) + { + gtint_t bufSizeA = cblas_sgemm_pack_get_size( cblas_identifierA, + m, + n, + k ); + + T* aBuffer = (T*) bli_malloc_user( bufSizeA, &err ); + + cblas_sgemm_pack( cblas_order, cblas_identifierA, cblas_transa, + m, n, k, *alpha, ap, lda, aBuffer ); + + + cblas_sgemm_compute( cblas_order, cblas_packed, cblas_transb, + m, n, k, aBuffer, lda, bp, ldb, *beta, cp, ldc ); + + bli_free_user( aBuffer ); + } + else if ( pckb == 'p' || pckb == 'P' ) + { + gtint_t bufSizeB = cblas_sgemm_pack_get_size( cblas_identifierB, + m, + n, + k ); + + T* bBuffer = (T*) bli_malloc_user( bufSizeB, &err ); + + cblas_sgemm_pack( cblas_order, cblas_identifierB, cblas_transb, + m, n, k, *alpha, bp, ldb, bBuffer ); + + cblas_sgemm_compute( cblas_order, cblas_transa, cblas_packed, + m, n, k, ap, lda, bBuffer, ldb, *beta, cp, ldc ); + + bli_free_user( bBuffer ); + } + else + { + cblas_sgemm_compute( cblas_order, cblas_transa, cblas_transb, + m, n, k, ap, lda, bp, ldb, *beta, cp, ldc ); + } + } + else if constexpr (std::is_same::value) + { + if ( ( pcka == 'p' || pcka == 'P' ) && ( pckb == 'p' || pckb == 'P' ) ) + { + gtint_t bufSizeA = cblas_dgemm_pack_get_size( cblas_identifierA, + m, + n, + k ); + + T* aBuffer = (T*) bli_malloc_user( bufSizeA, &err ); + + cblas_dgemm_pack( cblas_order, cblas_identifierA, cblas_transa, + m, n, k, *alpha, ap, lda, aBuffer ); + + gtint_t bufSizeB = cblas_dgemm_pack_get_size( cblas_identifierB, + m, + n, + k ); + + T* bBuffer = (T*) bli_malloc_user( bufSizeB, &err ); + + cblas_dgemm_pack( cblas_order, cblas_identifierB, cblas_transb, + m, n, k, unit_alpha, bp, ldb, bBuffer ); + + cblas_dgemm_compute( cblas_order, cblas_packed, cblas_packed, + m, n, k, aBuffer, lda, bBuffer, ldb, *beta, cp, ldc ); + + bli_free_user( aBuffer ); + bli_free_user( bBuffer ); + } + else if ( pcka == 'p' || pcka == 'P' ) + { + gtint_t bufSizeA = cblas_dgemm_pack_get_size( cblas_identifierA, + m, + n, + k ); + + T* aBuffer = (T*) bli_malloc_user( bufSizeA, &err ); + + cblas_dgemm_pack( cblas_order, cblas_identifierA, cblas_transa, + m, n, k, *alpha, ap, lda, aBuffer ); + + + cblas_dgemm_compute( cblas_order, cblas_packed, cblas_transb, + m, n, k, aBuffer, lda, bp, ldb, *beta, cp, ldc ); + + bli_free_user( aBuffer ); + } + else if ( pckb == 'p' || pckb == 'P' ) + { + gtint_t bufSizeB = cblas_dgemm_pack_get_size( cblas_identifierB, + m, + n, + k ); + + T* bBuffer = (T*) bli_malloc_user( bufSizeB, &err ); + + cblas_dgemm_pack( cblas_order, cblas_identifierB, cblas_transb, + m, n, k, *alpha, bp, ldb, bBuffer ); + + cblas_dgemm_compute( cblas_order, cblas_transa, cblas_packed, + m, n, k, ap, lda, bBuffer, ldb, *beta, cp, ldc ); + + bli_free_user( bBuffer ); + } + else + { + cblas_dgemm_compute( cblas_order, cblas_transa, cblas_transb, + m, n, k, ap, lda, bp, ldb, *beta, cp, ldc ); + } + } + else + { + throw std::runtime_error("Error in testsuite/level3/gemm_compute.h: Invalid typename in cblas_gemm_compute()."); + } +} + +template +static void gemm_compute( char storage, char transa, char transb, char packa, char packb, gtint_t m, gtint_t n, gtint_t k, T* alpha, + T* ap, gtint_t lda, T* bp, gtint_t ldb, T* beta, T* cp, gtint_t ldc ) +{ +#ifdef TEST_BLAS + if( storage == 'c' || storage == 'C' ) + gemm_compute_( transa, transb, packa, packb, m, n, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); + else + throw std::runtime_error("Error in testsuite/level3/gemm_compute.h: BLAS interface cannot be tested for row-major order."); + +#elif TEST_CBLAS + cblas_gemm_compute( storage, transa, transb, packa, packb, m, n, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); +#elif TEST_BLIS_TYPED + throw std::runtime_error("Error in testsuite/level3/gemm_compute.h: BLIS interfaces not yet implemented for pack and compute BLAS extensions."); +#else + throw std::runtime_error("Error in testsuite/level3/gemm_compute.h: No interfaces are set to be tested."); +#endif +} \ No newline at end of file diff --git a/gtestsuite/testsuite/level3/gemm_compute/gemm_compute_IIT_ERS.cpp b/gtestsuite/testsuite/level3/gemm_compute/gemm_compute_IIT_ERS.cpp new file mode 100644 index 0000000000..c70a048bca --- /dev/null +++ b/gtestsuite/testsuite/level3/gemm_compute/gemm_compute_IIT_ERS.cpp @@ -0,0 +1,222 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_gemm_compute.h" +#include "common/wrong_inputs_helpers.h" +#include "common/testing_helpers.h" +#include "inc/check_error.h" + +template +class GEMM_Compute_IIT_ERS_Test : public ::testing::Test {}; +typedef ::testing::Types TypeParam; +TYPED_TEST_SUITE(GEMM_Compute_IIT_ERS_Test, TypeParam); + +using namespace testinghelpers::IIT; + +#ifdef TEST_BLAS + +/* + Incorrect Input Testing(IIT) + + BLAS exceptions get triggered in the following cases(for GEMM Compute): + 1. When TRANSA != 'N' || TRANSA != 'T' || TRANSA != 'C' || TRANSA != 'P' (info = 1) + 2. When TRANSB != 'N' || TRANSB != 'T' || TRANSB != 'C' || TRANSB != 'P' (info = 2) + 3. When m < 0 (info = 3) + 4. When n < 0 (info = 4) + 5. When k < 0 (info = 5) + 6. When lda < max(1, thresh) (info = 7), thresh set based on TRANSA value + 7. When ldb < max(1, thresh) (info = 9), thresh set based on TRANSB value + 8. When ldc < max(1, n) (info = 12) +*/ + +// When info == 1 +TYPED_TEST(GEMM_Compute_IIT_ERS_Test, invalid_transa) +{ + using T = TypeParam; + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + // Call BLIS Gemm with a invalid value for TRANS value for A. + gemm_compute( STORAGE, 'x', TRANS, 'U', 'U', M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); + // Use bitwise comparison (no threshold). + computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); +} + +// When info == 2 +TYPED_TEST(GEMM_Compute_IIT_ERS_Test, invalid_transb) +{ + using T = TypeParam; + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + // Call BLIS Gemm with a invalid value for TRANS value for A. + gemm_compute( STORAGE, TRANS, 'x', 'U', 'U', M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); + // Use bitwise comparison (no threshold). + computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); +} + +// When info == 3 +TYPED_TEST(GEMM_Compute_IIT_ERS_Test, m_lt_zero) +{ + using T = TypeParam; + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + // Call BLIS Gemm with a invalid value for m. + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', -1, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); + // Use bitwise comparison (no threshold). + computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); +} + +// When info == 4 +TYPED_TEST(GEMM_Compute_IIT_ERS_Test, n_lt_zero) +{ + using T = TypeParam; + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + // Call BLIS Gemm with a invalid value for m. + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, -1, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); + // Use bitwise comparison (no threshold). + computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); +} + +// When info == 5 +TYPED_TEST(GEMM_Compute_IIT_ERS_Test, k_lt_zero) +{ + using T = TypeParam; + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + // Call BLIS Gemm with a invalid value for m. + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, N, -1, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); + // Use bitwise comparison (no threshold). + computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); +} + +// When info == 7 +TYPED_TEST(GEMM_Compute_IIT_ERS_Test, invalid_lda) +{ + using T = TypeParam; + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + // Call BLIS Gemm with a invalid value for m. + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, N, K, nullptr, nullptr, LDA - 1, nullptr, LDB, nullptr, nullptr, LDC ); + // Use bitwise comparison (no threshold). + computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); +} + +// When info == 9 +TYPED_TEST(GEMM_Compute_IIT_ERS_Test, invalid_ldb) +{ + using T = TypeParam; + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + // Call BLIS Gemm with a invalid value for m. + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, N, K, nullptr, nullptr, LDA, nullptr, LDB - 1, nullptr, nullptr, LDC ); + // Use bitwise comparison (no threshold). + computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); +} + +// When info == 12 +TYPED_TEST(GEMM_Compute_IIT_ERS_Test, invalid_ldc) +{ + using T = TypeParam; + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + // Call BLIS Gemm with a invalid value for m. + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC - 1 ); + // Use bitwise comparison (no threshold). + computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); +} + +/* + Early Return Scenarios(ERS) : + + The GEMM Compute API is expected to return early in the following cases: + + 1. When m == 0. + 2. When n == 0. +*/ + +// When m = 0 +TYPED_TEST(GEMM_Compute_IIT_ERS_Test, m_eq_zero) +{ + using T = TypeParam; + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + // Call BLIS Gemm with a invalid value for m. + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', 0, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); + // Use bitwise comparison (no threshold). + computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); +} + +// When n = 0 +TYPED_TEST(GEMM_Compute_IIT_ERS_Test, n_eq_zero) +{ + using T = TypeParam; + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + // Call BLIS Gemm with a invalid value for m. + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, 0, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); + // Use bitwise comparison (no threshold). + computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); +} +#endif \ No newline at end of file diff --git a/gtestsuite/testsuite/level3/gemm_compute/sgemm_compute_generic.cpp b/gtestsuite/testsuite/level3/gemm_compute/sgemm_compute_generic.cpp new file mode 100644 index 0000000000..e261f65835 --- /dev/null +++ b/gtestsuite/testsuite/level3/gemm_compute/sgemm_compute_generic.cpp @@ -0,0 +1,189 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_gemm_compute.h" + +class SGemmComputeTest : + public ::testing::TestWithParam> {}; + +TEST_P(SGemmComputeTest, RandomData) +{ +// printf("SGemmCompute_test!!\n"); + using T = float; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<1>(GetParam()); + // denotes whether matrix b is n,c,t,h + char transb = std::get<2>(GetParam()); + // denotes whether matrix a is packed (p) or unpacked (u) + char packa = std::get<3>(GetParam()); + // denotes whether matrix b is packed (p) or unpacked (u) + char packb = std::get<4>(GetParam()); + // matrix size m + gtint_t m = std::get<5>(GetParam()); + // matrix size n + gtint_t n = std::get<6>(GetParam()); + // matrix size k + gtint_t k = std::get<7>(GetParam()); + // specifies alpha value + T alpha = std::get<8>(GetParam()); + // specifies beta value + T beta = std::get<9>(GetParam()); + // lda, ldb, ldc increments. + // If increments are zero, then the array size matches the matrix size. + // If increments are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<10>(GetParam()); + gtint_t ldb_inc = std::get<11>(GetParam()); + gtint_t ldc_inc = std::get<12>(GetParam()); + + // Set the threshold for the errors: + float intermediate = (float)m*n*k; + float thresh = 10*intermediate*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_gemm_compute( storage, transa, transb, packa, packb, m, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); +} + +class SGemmComputeTestPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char sfm = std::get<0>(str.param); + char tsa = std::get<1>(str.param); + char tsb = std::get<2>(str.param); + char pka = std::get<3>(str.param); + char pkb = std::get<4>(str.param); + gtint_t m = std::get<5>(str.param); + gtint_t n = std::get<6>(str.param); + gtint_t k = std::get<7>(str.param); + float alpha = std::get<8>(str.param); + float beta = std::get<9>(str.param); + gtint_t lda_inc = std::get<10>(str.param); + gtint_t ldb_inc = std::get<11>(str.param); + gtint_t ldc_inc = std::get<12>(str.param); +#ifdef TEST_BLAS + std::string str_name = "sgemm_compute_"; +#elif TEST_CBLAS + std::string str_name = "cblas_sgemm_compute"; +#else //#elif TEST_BLIS_TYPED + // BLIS interface not yet implemented for pack and compute APIs. + std::string str_name = "blis_sgemm_compute"; +#endif + str_name = str_name + "_" + sfm+sfm+sfm; + str_name = str_name + "_" + tsa + tsb; + str_name = str_name + "_" + pka + pkb; + str_name = str_name + "_" + std::to_string(m); + str_name = str_name + "_" + std::to_string(n); + str_name = str_name + "_" + std::to_string(k); + std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); + str_name = str_name + "_a" + alpha_str; + std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); + str_name = str_name + "_b" + beta_str; + str_name = str_name + "_" + std::to_string(lda_inc); + str_name = str_name + "_" + std::to_string(ldb_inc); + str_name = str_name + "_" + std::to_string(ldc_inc); + return str_name; + } +}; + +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + Blackbox, + SGemmComputeTest, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS + ,'r' +#endif + ), // storage format + ::testing::Values('n', 't', 'c'), // transa + ::testing::Values('n', 't', 'c'), // transb + ::testing::Values('u', 'p'), // packa + ::testing::Values('u', 'p'), // packb + ::testing::Range(gtint_t(10), gtint_t(31), 10), // m + ::testing::Range(gtint_t(10), gtint_t(31), 10), // n + ::testing::Range(gtint_t(10), gtint_t(31), 10), // k + ::testing::Values(0.0, 1.0, -1.2, 2.1), // alpha + ::testing::Values(0.0, 1.0, -1.2, 2.1), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::SGemmComputeTestPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + DimensionsGtBlocksizes, // Dimensions > SUP Blocksizes + SGemmComputeTest, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS + ,'r' +#endif + ), // storage format + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values('u', 'p'), // packa + ::testing::Values('u', 'p'), // packb + ::testing::Values(143, 145), // m (MC - 1, MC + 1) + ::testing::Values(8159, 8161), // n (NC - 1, NC + 1) + ::testing::Values(511, 513), // k (KC - 1, KC + 1) + ::testing::Values(1.0), // alpha + ::testing::Values(1.0), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::SGemmComputeTestPrint() + ); \ No newline at end of file diff --git a/gtestsuite/testsuite/level3/gemm_compute/test_gemm_compute.h b/gtestsuite/testsuite/level3/gemm_compute/test_gemm_compute.h new file mode 100644 index 0000000000..7d1016941b --- /dev/null +++ b/gtestsuite/testsuite/level3/gemm_compute/test_gemm_compute.h @@ -0,0 +1,79 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include "gemm_compute.h" +#include "level3/ref_gemm_compute.h" +#include "inc/check_error.h" +#include +#include + +template +void test_gemm_compute( char storage, char trnsa, char trnsb, char pcka, char pckb, + gtint_t m, gtint_t n, gtint_t k, gtint_t lda_inc, gtint_t ldb_inc, gtint_t ldc_inc, + T alpha, T beta, double thresh ) +{ + // Compute the leading dimensions of a, b, and c. + gtint_t lda = testinghelpers::get_leading_dimension( storage, trnsa, m, k, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, trnsb, k, n, ldb_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, n, ldc_inc ); + + //---------------------------------------------------------- + // Initialize matrics with random numbers + //---------------------------------------------------------- + std::vector a = testinghelpers::get_random_matrix( -2, 8, storage, trnsa, m, k, lda ); + std::vector b = testinghelpers::get_random_matrix( -5, 2, storage, trnsb, k, n, ldb ); + std::vector c = testinghelpers::get_random_matrix( -3, 5, storage, 'n', m, n, ldc ); + + // Create a copy of c so that we can check reference results. + std::vector c_ref(c); + + //---------------------------------------------------------- + // Call BLIS function + //---------------------------------------------------------- + gemm_compute( storage, trnsa, trnsb, pcka, pckb, m, n, k, &alpha, a.data(), lda, + b.data(), ldb, &beta, c.data(), ldc ); + + //---------------------------------------------------------- + // Call reference implementation. + //---------------------------------------------------------- + testinghelpers::ref_gemm_compute( storage, trnsa, trnsb, pcka, pckb, m, n, k, alpha, + a.data(), lda, b.data(), ldb, beta, c_ref.data(), ldc ); + + //---------------------------------------------------------- + // check component-wise error. + //---------------------------------------------------------- + computediff( storage, m, n, c.data(), c_ref.data(), ldc, thresh ); +} \ No newline at end of file