[SYCL][CUDA][Matrix] Add initial support for Tensorcore matrix ext (#4696)

Initial implementation of the new matrix extension supporting Nvidia Tensor Cores (#4695), adapted from the AMX matrix extension. Only the double data type is supported for matrix elements initially. Signed-off-by: jack.kirk <[email protected]>
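For reference, here is a minimal usage sketch (not part of the commit) showing how the new API is exercised for the double m8n8k4 shape. It condenses the row-major kernel from the test added in this commit (shown below); the kernel name usage_sketch is illustrative, and the code assumes compilation for the CUDA backend with sm_80 and -DSYCL_EXT_ONEAPI_MATRIX=3, as in the test's RUN line.

#include <CL/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

// a is MxK, b is KxN, the accumulator is MxN.
constexpr int M = 8, N = 8, K = 4;

double A[M * K], B[K * N], C[M * N], D[M * N];

int main() {
  buffer<double, 1> bufA(A, range<1>(M * K)), bufB(B, range<1>(K * N));
  buffer<double, 1> bufC(C, range<1>(M * N)), bufD(D, range<1>(M * N));
  queue q;

  q.submit([&](handler &cgh) {
    auto accA = bufA.get_access<access::mode::read_write>(cgh);
    auto accB = bufB.get_access<access::mode::read_write>(cgh);
    auto accC = bufC.get_access<access::mode::read_write>(cgh);
    auto accD = bufD.get_access<access::mode::read_write>(cgh);

    // Each matrix fragment is owned cooperatively by one 32-wide sub-group.
    cgh.parallel_for<class usage_sketch>(
        nd_range<2>({1, 32}, {1, 32}),
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();

          joint_matrix<double, matrix_use::a, M, K, matrix_layout::row_major>
              sub_a;
          joint_matrix<double, matrix_use::b, K, N, matrix_layout::row_major>
              sub_b;
          joint_matrix<double, matrix_use::accumulator, M, N,
                       matrix_layout::row_major>
              sub_c;

          joint_matrix_load(sg, sub_c, accC.get_pointer(), N); // stride = N
          joint_matrix_load(sg, sub_a, accA.get_pointer(), K); // stride = K
          joint_matrix_load(sg, sub_b, accB.get_pointer(), N); // stride = N
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);   // a*b + c
          joint_matrix_store(sg, sub_c, accD.get_pointer(), N);
        });
  });
  return 0;
}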
Showing 3 changed files with 363 additions and 0 deletions.
sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp (259 additions, 0 deletions)
@@ -0,0 +1,259 @@
#pragma once

#include <CL/sycl/detail/defines_elementary.hpp>
#include <immintrin.h>

__SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {
namespace ext {
namespace oneapi {
namespace experimental::matrix {

enum class matrix_use { a, b, accumulator };

enum class matrix_layout { row_major, col_major, packed_a, packed_b };

template <typename T, matrix_use MT, size_t Rows = sycl::dynamic_extent,
          size_t Cols = sycl::dynamic_extent,
          matrix_layout Layout = matrix_layout::row_major,
          typename Group = sycl::sub_group, typename Cond = void>
struct joint_matrix {
  joint_matrix(Group g) {}
};
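
// The primary joint_matrix template above is only a placeholder; the
// specializations below define the fragments supported by the CUDA backend,
// which for double cover the m8n8k4 shape (a: 8x4, b: 4x8, accumulator: 8x8).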

// enable_if_t is used throughout this file to disable the
// matrix_layout::packed cases, which are not compatible with the Nvidia CUDA
// backend.
template <matrix_layout Layout>
struct joint_matrix<
    double, matrix_use::a, 8, 4, Layout, sycl::sub_group,
    typename std::enable_if_t<Layout == matrix_layout::row_major ||
                              Layout == matrix_layout::col_major>> {
  double data[1];
};

template <matrix_layout Layout>
struct joint_matrix<
    double, matrix_use::b, 4, 8, Layout, sycl::sub_group,
    typename std::enable_if_t<(Layout == matrix_layout::row_major ||
                               Layout == matrix_layout::col_major)>> {
  double data[1];
};

template <matrix_layout Layout>
struct joint_matrix<
    double, matrix_use::accumulator, 8, 8, Layout, sycl::sub_group,
    typename std::enable_if_t<Layout == matrix_layout::row_major ||
                              Layout == matrix_layout::col_major>> {
  double data[2];
};

} // namespace experimental::matrix
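
// The detail helpers below map each joint_matrix operation onto the
// corresponding __dmma_m8n8k4_* builtin when compiling for the NVPTX device.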

namespace detail {
using namespace experimental;

template <typename T, matrix::matrix_use MT, size_t NumRows, size_t NumCols,
          matrix::matrix_layout Layout, access::address_space Space,
          typename Cond = void>
struct joint_matrix_load_impl {
  void load(matrix::joint_matrix<T, MT, NumRows, NumCols, Layout> &res,
            multi_ptr<T, Space> src, size_t stride);
};

template <matrix::matrix_layout Layout> constexpr int get_layout_id();

template <> constexpr int get_layout_id<matrix::matrix_layout::row_major>() {
  return 0;
}

template <> constexpr int get_layout_id<matrix::matrix_layout::col_major>() {
  return 1;
}
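
// get_layout_id maps a matrix_layout onto the integer layout argument taken by
// the __dmma_m8n8k4 load/store builtins: 0 for row_major, 1 for col_major.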

template <matrix::matrix_layout Layout, access::address_space Space>
struct joint_matrix_load_impl<
    double, matrix::matrix_use::a, 8, 4, Layout, Space,
    typename std::enable_if_t<Layout == matrix::matrix_layout::row_major ||
                              Layout == matrix::matrix_layout::col_major>> {
  void
  load(matrix::joint_matrix<double, matrix::matrix_use::a, 8, 4, Layout> &res,
       multi_ptr<double, Space> src, size_t stride) {

#ifdef __NVPTX__
#ifdef __SYCL_DEVICE_ONLY__
    __dmma_m8n8k4_ld_a(res.data, src.get(), stride, get_layout_id<Layout>());
#endif
#endif
  }
};

template <matrix::matrix_layout Layout, access::address_space Space>
struct joint_matrix_load_impl<
    double, matrix::matrix_use::b, 4, 8, Layout, Space,
    typename std::enable_if_t<Layout == matrix::matrix_layout::row_major ||
                              Layout == matrix::matrix_layout::col_major>> {
  void
  load(matrix::joint_matrix<double, matrix::matrix_use::b, 4, 8, Layout> &res,
       multi_ptr<double, Space> src, size_t stride) {
#ifdef __NVPTX__
#ifdef __SYCL_DEVICE_ONLY__
    __dmma_m8n8k4_ld_b(res.data, src.get(), stride, get_layout_id<Layout>());
#endif
#endif
  }
};

template <matrix::matrix_layout Layout, access::address_space Space>
struct joint_matrix_load_impl<
    double, matrix::matrix_use::accumulator, 8, 8, Layout, Space,
    typename std::enable_if_t<Layout == matrix::matrix_layout::row_major ||
                              Layout == matrix::matrix_layout::col_major>> {
  void load(matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8,
                                 Layout> &res,
            multi_ptr<double, Space> src, size_t stride) {

#ifdef __NVPTX__
#ifdef __SYCL_DEVICE_ONLY__
    __dmma_m8n8k4_ld_c(res.data, src.get(), stride, get_layout_id<Layout>());
#endif
#endif
  }
};
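
// For the loads above and the store below, stride is the leading dimension of
// the matrix in memory, expressed in elements (compare the strides used in the
// accompanying test).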

template <typename T, size_t NumRows, size_t NumCols,
          matrix::matrix_layout Layout, access::address_space Space,
          typename Cond = void>
struct joint_matrix_store_impl {
  void store(matrix::joint_matrix<T, matrix::matrix_use::accumulator, NumRows,
                                  NumCols, Layout> &src,
             multi_ptr<T, Space> dst, size_t stride);
};

template <matrix::matrix_layout Layout, access::address_space Space>
struct joint_matrix_store_impl<
    double, 8, 8, Layout, Space,
    typename std::enable_if_t<Layout == matrix::matrix_layout::row_major ||
                              Layout == matrix::matrix_layout::col_major>> {
  void store(matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8,
                                  Layout> &src,
             multi_ptr<double, Space> dst, size_t stride) {

#ifdef __NVPTX__
#ifdef __SYCL_DEVICE_ONLY__
    __dmma_m8n8k4_st_c_f64(dst.get(), src.data, stride,
                           get_layout_id<Layout>());
#endif
#endif
  }
};

template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
          matrix::matrix_layout LayoutA, matrix::matrix_layout LayoutB,
          matrix::matrix_layout LayoutC, typename Cond = void>
struct joint_matrix_mad_impl {
  matrix::joint_matrix<T2, matrix::matrix_use::accumulator, M, N, LayoutC>
  mad(matrix::joint_matrix<T1, matrix::matrix_use::a, M, K, LayoutA> A,
      matrix::joint_matrix<T1, matrix::matrix_use::b, K, N, LayoutB> B,
      matrix::joint_matrix<T2, matrix::matrix_use::accumulator, M, N, LayoutC>
          C);
};

template <matrix::matrix_layout LayoutA, matrix::matrix_layout LayoutB>
constexpr int get_layout_pair_id();

template <>
constexpr int get_layout_pair_id<matrix::matrix_layout::row_major,
                                 matrix::matrix_layout::row_major>() {
  return 0;
}

template <>
constexpr int get_layout_pair_id<matrix::matrix_layout::row_major,
                                 matrix::matrix_layout::col_major>() {
  return 1;
}

template <>
constexpr int get_layout_pair_id<matrix::matrix_layout::col_major,
                                 matrix::matrix_layout::row_major>() {
  return 2;
}

template <>
constexpr int get_layout_pair_id<matrix::matrix_layout::col_major,
                                 matrix::matrix_layout::col_major>() {
  return 3;
}
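
// get_layout_pair_id encodes the (LayoutA, LayoutB) combination as the single
// integer expected by the mma builtin: row/row = 0, row/col = 1, col/row = 2,
// col/col = 3.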

template <matrix::matrix_layout LayoutA, matrix::matrix_layout LayoutB,
          matrix::matrix_layout LayoutC>
struct joint_matrix_mad_impl<
    double, double, 8, 4, 8, LayoutA, LayoutB, LayoutC,
    typename std::enable_if_t<(LayoutA == matrix::matrix_layout::row_major ||
                               LayoutA == matrix::matrix_layout::col_major) &&
                              (LayoutB == matrix::matrix_layout::row_major ||
                               LayoutB == matrix::matrix_layout::col_major) &&
                              (LayoutC == matrix::matrix_layout::row_major ||
                               LayoutC == matrix::matrix_layout::col_major)>> {
  matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8, LayoutC>
  mad(matrix::joint_matrix<double, matrix::matrix_use::a, 8, 4, LayoutA> A,
      matrix::joint_matrix<double, matrix::matrix_use::b, 4, 8, LayoutB> B,
      matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8,
                           LayoutC>
          C) {
    matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8, LayoutC>
        D;

#ifdef __NVPTX__
#ifdef __SYCL_DEVICE_ONLY__
    __dmma_m8n8k4_mma_f64(D.data, A.data, B.data, C.data,
                          get_layout_pair_id<LayoutA, LayoutB>(), 0);
#endif
#endif

    return D;
  }
};

} // namespace detail

namespace experimental::matrix {

template <typename Group, typename T, matrix_use MT, size_t NumRows,
          size_t NumCols, matrix_layout Layout, access::address_space Space>
void joint_matrix_load(
    Group sg, joint_matrix<T, MT, NumRows, NumCols, Layout, Group> &res,
    multi_ptr<T, Space> src, size_t stride) {
  detail::joint_matrix_load_impl<T, MT, NumRows, NumCols, Layout, Space>{}.load(
      res, src, stride);
}

template <typename Group, typename T, size_t NumRows, size_t NumCols,
          matrix_layout Layout, access::address_space Space>
void joint_matrix_store(Group sg,
                        joint_matrix<T, matrix_use::accumulator, NumRows,
                                     NumCols, Layout, Group> &src,
                        multi_ptr<T, Space> dst, size_t stride) {
  detail::joint_matrix_store_impl<T, NumRows, NumCols, Layout, Space>{}.store(
      src, dst, stride);
}

template <typename Group, typename T1, typename T2, std::size_t M,
          std::size_t K, std::size_t N, matrix_layout LayoutA,
          matrix_layout LayoutB, matrix_layout LayoutC>
joint_matrix<T2, matrix_use::accumulator, M, N, LayoutC, Group>
joint_matrix_mad(
    Group sg, joint_matrix<T1, matrix_use::a, M, K, LayoutA, Group> A,
    joint_matrix<T1, matrix_use::b, K, N, LayoutB, Group> B,
    joint_matrix<T2, matrix_use::accumulator, M, N, LayoutC, Group> C) {
  return detail::joint_matrix_mad_impl<T1, T2, M, K, N, LayoutA, LayoutB,
                                       LayoutC>{}
      .mad(A, B, C);
}

} // namespace experimental::matrix
} // namespace oneapi
} // namespace ext
} // namespace sycl
} // __SYCL_INLINE_NAMESPACE(cl)
sycl/test/check_device_code/matrix/matrix-nvptx-double-test.cpp (101 additions, 0 deletions)
@@ -0,0 +1,101 @@
// REQUIRES: gpu, cuda

// RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s

#include <CL/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

// M, N, K define the sizes of dimensions of the three matrix types (a, b,
// accumulator) used per subgroup operation.
constexpr int M = 8; // number of rows of accumulator,
                     // number of rows of a.
constexpr int N = 8; // number of cols of accumulator,
                     // number of cols of b.
constexpr int K = 4; // number of cols of a/number of rows of b.

double A[M * K];
double B[K * N];
double C[M * N];
double D[M * N];

int main() {

  buffer<double, 1> bufA(A, range<1>(M * K));
  buffer<double, 1> bufB(B, range<1>(K * N));
  buffer<double, 1> bufC(C, range<1>(M * N));
  buffer<double, 1> bufD(D, range<1>(M * N));

  queue q;

  q.submit([&](handler &cgh) {
    auto accC = bufC.get_access<access::mode::read_write>(cgh);
    auto accA = bufA.get_access<access::mode::read_write>(cgh);
    auto accB = bufB.get_access<access::mode::read_write>(cgh);
    auto accD = bufD.get_access<access::mode::read_write>(cgh);

    cgh.parallel_for<class row_row>(
        nd_range<2>({1, 32}, {1, 32}),
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();

          joint_matrix<double, matrix_use::accumulator, M, N,
                       matrix_layout::row_major>
              sub_c;

          joint_matrix<double, matrix_use::a, M, K, matrix_layout::row_major>
              sub_a;

          joint_matrix<double, matrix_use::b, K, N, matrix_layout::row_major>
              sub_b;

          //CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i, i32 8) #{{.*}}
          joint_matrix_load(sg, sub_c, accC.get_pointer(), N);
          //CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i54, i32 4) #{{.*}}
          joint_matrix_load(sg, sub_a, accA.get_pointer(), K);
          //CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i65, i32 8) #{{.*}}
          joint_matrix_load(sg, sub_b, accB.get_pointer(), N);
          //CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double %11, double %12, double %9, double %10) #{{.*}}
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          //CHECK: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i76, double %14, double %15, i32 8) #{{.*}}
          joint_matrix_store(sg, sub_c, accD.get_pointer(), N);
        });
  });
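
  // The second kernel repeats the same sequence with col_major fragments; the
  // strides passed to the loads/stores change accordingly, since the leading
  // dimension of a column-major matrix is its number of rows.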

  q.submit([&](handler &cgh) {
    auto accC = bufC.get_access<access::mode::read_write>(cgh);
    auto accA = bufA.get_access<access::mode::read_write>(cgh);
    auto accB = bufB.get_access<access::mode::read_write>(cgh);
    auto accD = bufD.get_access<access::mode::read_write>(cgh);

    cgh.parallel_for<class col_col>(
        nd_range<2>({1, 32}, {1, 32}),
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();

          joint_matrix<double, matrix_use::accumulator, M, N,
                       matrix_layout::col_major>
              sub_c;

          joint_matrix<double, matrix_use::a, M, K, matrix_layout::col_major>
              sub_a;

          joint_matrix<double, matrix_use::b, K, N, matrix_layout::col_major>
              sub_b;

          //CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i, i32 8) #{{.*}}
          joint_matrix_load(sg, sub_c, accC.get_pointer(), M);
          //CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i54, i32 8) #{{.*}}
          joint_matrix_load(sg, sub_a, accA.get_pointer(), M);
          //CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i65, i32 4) #{{.*}}
          joint_matrix_load(sg, sub_b, accB.get_pointer(), K);
          //CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double %11, double %12, double %9, double %10) #{{.*}}
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          //CHECK: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i76, double %14, double %15, i32 8) #{{.*}}
          joint_matrix_store(sg, sub_c, accD.get_pointer(), M);
        });
  });

  return 0;
}