diff --git a/sycl/test-e2e/Matrix/common.hpp b/sycl/test-e2e/Matrix/common.hpp index 688cf2e5e12e2..d1e39b2a77330 100644 --- a/sycl/test-e2e/Matrix/common.hpp +++ b/sycl/test-e2e/Matrix/common.hpp @@ -81,6 +81,15 @@ void matrix_vnni(unsigned int rows, unsigned int cols, T *src, T *dest, } } +template +void matrix_transpose(unsigned int rows, unsigned int cols, T *dst, T *src) { + for (unsigned int i = 0; i < rows; i++) { + for (unsigned int j = 0; j < cols; j++) { + dst[i + j * rows] = src[i * cols + j]; + } + } +} + template void matrix_fill(unsigned int rows, unsigned int cols, T *src, T val) { for (unsigned int i = 0; i < rows; i++) { @@ -128,11 +137,12 @@ void matrix_copy(unsigned int rows, unsigned int cols, T *src, T *dst) { } } -template +template bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) { for (int i = 0; i < rows; i++) { for (int j = 0; j < cols; j++) { - if constexpr (std::is_same_v || std::is_same_v) { + if constexpr (!exact && (std::is_same_v || + std::is_same_v)) { float diff = std::fabs(src[i * cols + j] - (T1)ref[i * cols + j]); if (diff > FLOAT_EPSILON || std::isnan(src[i * cols + j])) { std::cout << "Incorrect result in matrix. " @@ -142,9 +152,10 @@ bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) { << ", Epsilon: " << FLOAT_EPSILON << "\n"; return false; } - } else if constexpr (std::is_same_v) { + } else if constexpr (exact || std::is_same_v) { if (src[i * cols + j] != ref[i * cols + j]) { - std::cout << "Incorrect result in matrix. Ref: " << ref[i * cols + j] + std::cout << "Incorrect result in matrix." << "i: " << i + << ", j: " << j << ", Ref: " << ref[i * cols + j] << ", Val: " << src[i * cols + j] << "\n"; return false; } diff --git a/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp index 44429fec6a65d..3564472c5d958 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp @@ -1,70 +1,102 @@ -#include -#include - using namespace sycl; using namespace sycl::ext::oneapi::experimental::matrix; -constexpr size_t TM = 8; -constexpr size_t TK = 16; +template +void matrix_load_and_store(T1 *input, T1 *out_col_major, T1 *out_row_major, + queue q) { + size_t M = NUM_ROWS; + size_t N = NUM_COLS; -template -void matrix_load_store(T1 *C, queue q) { - size_t M = NUM_ROWS_C; - size_t N = NUM_COLS_C; + static_assert((NUM_ROWS % TM) == 0); + static_assert((NUM_COLS % TN) == 0); size_t NDRangeM = M / TM; size_t NDRangeN = N / TN; - auto pC = address_space_cast(C); + auto p_input = address_space_cast(input); + + auto p_out_col_major = + address_space_cast(out_col_major); + auto p_out_row_major = + address_space_cast(out_row_major); q.submit([&](handler &cgh) { cgh.parallel_for( nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), - [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] - - { - // The submatrix API has to be accessed by all the workitems in - // a subgroup these functions will be called once by the - // subgroup no code divergence between the workitems + [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { const auto global_idx = spmd_item.get_global_id(0); const auto global_idy = spmd_item.get_global_id(1); const auto sg_startx = global_idx - spmd_item.get_local_id(0); const auto sg_starty = global_idy - spmd_item.get_local_id(1); sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_c; - // for transposeC - // which TN x TM in N x M: - // M x N => TM x N => TM x TN => TN x TM - // m=sg_startx - // sg_starty/SG_SZ - // linear_index = M * (sg_starty/SG_SZ *TN) + TM *sg_startx - joint_matrix_load(sg, sub_c, - pC + M * (sg_starty / SG_SZ * TN) + TM * sg_startx, - M, layout::col_major); - joint_matrix_store( - sg, sub_c, pC + M * (sg_starty / SG_SZ * TN) + TM * sg_startx, M, - layout::col_major); + joint_matrix sub_matrix; + + auto row_major_offset = + (sg_startx * TM) * N + (sg_starty / SG_SZ * TN); + auto col_major_offset = + (sg_startx * TM) + (sg_starty / SG_SZ * TN) * M; + + joint_matrix_load(sg, sub_matrix, p_input + col_major_offset, M, + layout::col_major); + + joint_matrix_store(sg, sub_matrix, + p_out_col_major + row_major_offset, N, + layout::row_major); + + joint_matrix_store(sg, sub_matrix, + p_out_row_major + col_major_offset, M, + layout::col_major); }); // parallel for }).wait(); } -int main() { - static constexpr size_t MATRIX_M = 1024; - static constexpr size_t MATRIX_N = 1024; +template void run_matrix_test() { + static constexpr size_t MATRIX_M = TM * 16; + static constexpr size_t MATRIX_N = TN * 16; queue q; - float *C = malloc_shared(MATRIX_M * MATRIX_N, q); - float *D = malloc_shared(MATRIX_M * MATRIX_N, q); + float *input = malloc_shared(MATRIX_M * MATRIX_N, q); + float *out_col_major = malloc_shared(MATRIX_M * MATRIX_N, q); + float *out_row_major = malloc_shared(MATRIX_M * MATRIX_N, q); + float *ref_col_major = malloc_shared(MATRIX_M * MATRIX_N, q); - matrix_rand(MATRIX_M, MATRIX_N, C, (float)5.0); - matrix_copy(MATRIX_M, MATRIX_N, C, D); + // input is column majot matrix so it is of NxM shape + matrix_rand(MATRIX_N, MATRIX_M, input, (float)5.0); + matrix_fill(MATRIX_M, MATRIX_N, out_col_major, (float)0); + matrix_fill(MATRIX_N, MATRIX_M, out_row_major, (float)0); + matrix_transpose(MATRIX_N, MATRIX_M, ref_col_major, input); - matrix_load_store(C, q); + matrix_load_and_store(input, out_col_major, + out_row_major, q); - bool res = matrix_compare(MATRIX_M, MATRIX_N, C, D); + // we use exact comparison as no low precision calculation is used in this + // test + std::cout << "compare results for TM " << TM << "\n"; + bool res = matrix_compare(MATRIX_M, MATRIX_N, + out_col_major, ref_col_major) && + matrix_compare(MATRIX_N, MATRIX_M, + out_row_major, input); + free(input, q); + free(out_col_major, q); + free(out_row_major, q); + free(ref_col_major, q); + assert(res); +} + +int main() { + run_matrix_test<8>(); + run_matrix_test<7>(); + run_matrix_test<6>(); + run_matrix_test<5>(); + run_matrix_test<4>(); + run_matrix_test<3>(); + run_matrix_test<2>(); + run_matrix_test<1>(); - std::cout << (res ? "passed" : "failed") << std::endl; - return !res; + std::cout << "Passed\n"; + return 0; }