Skip to content

Commit

Permalink
[SYCL][CUDA][Matrix] Add initial support for Tensorcore matrix ext (#…
Browse files Browse the repository at this point in the history
…4696)

Initial Implementation based on the new matrix extension
supporting Nvidia Tensorcore, #4695, that is adapted from
the AMX matrix extension.
Only double data type matrix elements are initially supported.

Signed-off-by: jack.kirk <[email protected]>
  • Loading branch information
JackAKirk authored Nov 8, 2021
1 parent 9340c96 commit 711ba58
Show file tree
Hide file tree
Showing 3 changed files with 363 additions and 0 deletions.
259 changes: 259 additions & 0 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
#pragma once

#include <CL/sycl/detail/defines_elementary.hpp>
#include <immintrin.h>

__SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {
namespace ext {
namespace oneapi {
namespace experimental::matrix {

enum class matrix_use { a, b, accumulator };

enum class matrix_layout { row_major, col_major, packed_a, packed_b };

template <typename T, matrix_use MT, size_t Rows = sycl::dynamic_extent,
size_t Cols = sycl::dynamic_extent,
matrix_layout Layout = matrix_layout::row_major,
typename Group = sycl::sub_group, typename Cond = void>
struct joint_matrix {
joint_matrix(Group g) {}
};

// The enable_if_t usage in this file is used to disable the
// matrix_layout::packed case which is not compatible with the Nvidia cuda
// backend.
template <matrix_layout Layout>
struct joint_matrix<
double, matrix_use::a, 8, 4, Layout, sycl::sub_group,
typename std::enable_if_t<Layout == matrix_layout::row_major ||
Layout == matrix_layout::col_major>> {
double data[1];
};

template <matrix_layout Layout>
struct joint_matrix<
double, matrix_use::b, 4, 8, Layout, sycl::sub_group,
typename std::enable_if_t<(Layout == matrix_layout::row_major ||
Layout == matrix_layout::col_major)>> {
double data[1];
};

template <matrix_layout Layout>
struct joint_matrix<
double, matrix_use::accumulator, 8, 8, Layout, sycl::sub_group,
typename std::enable_if_t<Layout == matrix_layout::row_major ||
Layout == matrix_layout::col_major>> {
double data[2];
};

} // namespace experimental::matrix

namespace detail {
using namespace experimental;

template <typename T, matrix::matrix_use MT, size_t NumRows, size_t NumCols,
matrix::matrix_layout Layout, access::address_space Space,
typename Cond = void>
struct joint_matrix_load_impl {
void load(matrix::joint_matrix<T, MT, NumRows, NumCols, Layout> &res,
multi_ptr<T, Space> src, size_t stride);
};

template <matrix::matrix_layout Layout> constexpr int get_layout_id();

template <> constexpr int get_layout_id<matrix::matrix_layout::row_major>() {
return 0;
}

template <> constexpr int get_layout_id<matrix::matrix_layout::col_major>() {
return 1;
}

template <matrix::matrix_layout Layout, access::address_space Space>
struct joint_matrix_load_impl<
double, matrix::matrix_use::a, 8, 4, Layout, Space,
typename std::enable_if_t<Layout == matrix::matrix_layout::row_major ||
Layout == matrix::matrix_layout::col_major>> {
void
load(matrix::joint_matrix<double, matrix::matrix_use::a, 8, 4, Layout> &res,
multi_ptr<double, Space> src, size_t stride) {

#ifdef __NVPTX__
#ifdef __SYCL_DEVICE_ONLY__
__dmma_m8n8k4_ld_a(res.data, src.get(), stride, get_layout_id<Layout>());
#endif
#endif
}
};

template <matrix::matrix_layout Layout, access::address_space Space>
struct joint_matrix_load_impl<
double, matrix::matrix_use::b, 4, 8, Layout, Space,
typename std::enable_if_t<Layout == matrix::matrix_layout::row_major ||
Layout == matrix::matrix_layout::col_major>> {
void
load(matrix::joint_matrix<double, matrix::matrix_use::b, 4, 8, Layout> &res,
multi_ptr<double, Space> src, size_t stride) {
#ifdef __NVPTX__
#ifdef __SYCL_DEVICE_ONLY__
__dmma_m8n8k4_ld_b(res.data, src.get(), stride, get_layout_id<Layout>());
#endif
#endif
}
};

template <matrix::matrix_layout Layout, access::address_space Space>
struct joint_matrix_load_impl<
double, matrix::matrix_use::accumulator, 8, 8, Layout, Space,
typename std::enable_if_t<Layout == matrix::matrix_layout::row_major ||
Layout == matrix::matrix_layout::col_major>> {
void load(matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8,
Layout> &res,
multi_ptr<double, Space> src, size_t stride) {

#ifdef __NVPTX__
#ifdef __SYCL_DEVICE_ONLY__
__dmma_m8n8k4_ld_c(res.data, src.get(), stride, get_layout_id<Layout>());
#endif
#endif
}
};

template <typename T, size_t NumRows, size_t NumCols,
matrix::matrix_layout Layout, access::address_space Space,
typename Cond = void>
struct joint_matrix_store_impl {
void store(matrix::joint_matrix<T, matrix::matrix_use::accumulator, NumRows,
NumCols, Layout> &src,
multi_ptr<T, Space> dst, size_t stride);
};

template <matrix::matrix_layout Layout, access::address_space Space>
struct joint_matrix_store_impl<
double, 8, 8, Layout, Space,
typename std::enable_if_t<Layout == matrix::matrix_layout::row_major ||
Layout == matrix::matrix_layout::col_major>> {
void store(matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8,
Layout> &src,
multi_ptr<double, Space> dst, size_t stride) {

#ifdef __NVPTX__
#ifdef __SYCL_DEVICE_ONLY__
__dmma_m8n8k4_st_c_f64(dst.get(), src.data, stride,
get_layout_id<Layout>());
#endif
#endif
}
};

template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
matrix::matrix_layout LayoutA, matrix::matrix_layout LayoutB,
matrix::matrix_layout LayoutC, typename Cond = void>
struct joint_matrix_mad_impl {
matrix::joint_matrix<T2, matrix::matrix_use::accumulator, M, N, LayoutC>
mad(matrix::joint_matrix<T1, matrix::matrix_use::a, M, K, LayoutA> A,
matrix::joint_matrix<T1, matrix::matrix_use::b, K, N, LayoutB> B,
matrix::joint_matrix<T2, matrix::matrix_use::accumulator, M, N, LayoutC>
C);
};

template <matrix::matrix_layout LayoutA, matrix::matrix_layout LayoutB>
constexpr int get_layout_pair_id();

template <>
constexpr int get_layout_pair_id<matrix::matrix_layout::row_major,
matrix::matrix_layout::row_major>() {
return 0;
}

template <>
constexpr int get_layout_pair_id<matrix::matrix_layout::row_major,
matrix::matrix_layout::col_major>() {
return 1;
}

template <>
constexpr int get_layout_pair_id<matrix::matrix_layout::col_major,
matrix::matrix_layout::row_major>() {
return 2;
}

template <>
constexpr int get_layout_pair_id<matrix::matrix_layout::col_major,
matrix::matrix_layout::col_major>() {
return 3;
}

template <matrix::matrix_layout LayoutA, matrix::matrix_layout LayoutB,
matrix::matrix_layout LayoutC>
struct joint_matrix_mad_impl<
double, double, 8, 4, 8, LayoutA, LayoutB, LayoutC,
typename std::enable_if_t<(LayoutA == matrix::matrix_layout::row_major ||
LayoutA == matrix::matrix_layout::col_major) &&
(LayoutB == matrix::matrix_layout::row_major ||
LayoutB == matrix::matrix_layout::col_major) &&
(LayoutC == matrix::matrix_layout::row_major ||
LayoutC == matrix::matrix_layout::col_major)>> {
matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8, LayoutC>
mad(matrix::joint_matrix<double, matrix::matrix_use::a, 8, 4, LayoutA> A,
matrix::joint_matrix<double, matrix::matrix_use::b, 4, 8, LayoutB> B,
matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8,
LayoutC>
C) {
matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8, LayoutC>
D;

#ifdef __NVPTX__
#ifdef __SYCL_DEVICE_ONLY__
__dmma_m8n8k4_mma_f64(D.data, A.data, B.data, C.data,
get_layout_pair_id<LayoutA, LayoutB>(), 0);
#endif
#endif

return D;
}
};

} // namespace detail

namespace experimental::matrix {

template <typename Group, typename T, matrix_use MT, size_t NumRows,
size_t NumCols, matrix_layout Layout, access::address_space Space>
void joint_matrix_load(
Group sg, joint_matrix<T, MT, NumRows, NumCols, Layout, Group> &res,
multi_ptr<T, Space> src, size_t stride) {
detail::joint_matrix_load_impl<T, MT, NumRows, NumCols, Layout, Space>{}.load(
res, src, stride);
}

template <typename Group, typename T, size_t NumRows, size_t NumCols,
matrix_layout Layout, access::address_space Space>
void joint_matrix_store(Group sg,
joint_matrix<T, matrix_use::accumulator, NumRows,
NumCols, Layout, Group> &src,
multi_ptr<T, Space> dst, size_t stride) {
detail::joint_matrix_store_impl<T, NumRows, NumCols, Layout, Space>{}.store(
src, dst, stride);
}

template <typename Group, typename T1, typename T2, std::size_t M,
std::size_t K, std::size_t N, matrix_layout LayoutA,
matrix_layout LayoutB, matrix_layout LayoutC>
joint_matrix<T2, matrix_use::accumulator, M, N, LayoutC, Group>
joint_matrix_mad(
Group sg, joint_matrix<T1, matrix_use::a, M, K, LayoutA, Group> A,
joint_matrix<T1, matrix_use::b, K, N, LayoutB, Group> B,
joint_matrix<T2, matrix_use::accumulator, M, N, LayoutC, Group> C) {
return detail::joint_matrix_mad_impl<T1, T2, M, K, N, LayoutA, LayoutB,
LayoutC>{}
.mad(A, B, C);
}

} // namespace experimental::matrix
} // namespace oneapi
} // namespace ext
} // namespace sycl
} // __SYCL_INLINE_NAMESPACE(cl)
3 changes: 3 additions & 0 deletions sycl/include/sycl/ext/oneapi/matrix/matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@
#include <sycl/ext/oneapi/matrix/matrix-jit.hpp>
#include <sycl/ext/oneapi/matrix/static-query.hpp>
#endif
#if (SYCL_EXT_ONEAPI_MATRIX == 3)
#include <sycl/ext/oneapi/matrix/matrix-tensorcore.hpp>
#endif
101 changes: 101 additions & 0 deletions sycl/test/check_device_code/matrix/matrix-nvptx-double-test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// REQUIRES: gpu, cuda

// RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s

#include <CL/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

// M, N, K define the sizes of dimensions of the three matrix types (a, b,
// accumulator) used per subgroup operation.
constexpr int M = 8; // number of rows of accumulator,
// number of cols of b.
constexpr int N = 8; // number of cols of accumulator,
// number of rows of a.
constexpr int K = 4; // number of cols of a/number of rows of b.

double A[M * K];
double B[K * N];
double C[M * N];
double D[M * N];

int main() {

buffer<double, 1> bufA(A, range<1>(M * K));
buffer<double, 1> bufB(B, range<1>(K * N));
buffer<double, 1> bufC(C, range<1>(M * N));
buffer<double, 1> bufD(D, range<1>(M * N));

queue q;

q.submit([&](handler &cgh) {
auto accC = bufC.get_access<access::mode::read_write>(cgh);
auto accA = bufA.get_access<access::mode::read_write>(cgh);
auto accB = bufB.get_access<access::mode::read_write>(cgh);
auto accD = bufD.get_access<access::mode::read_write>(cgh);

cgh.parallel_for<class row_row>(
nd_range<2>({1, 32}, {1, 32}), [=
](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
sycl::sub_group sg = item.get_sub_group();

joint_matrix<double, matrix_use::accumulator, M, N,
matrix_layout::row_major>
sub_c;

joint_matrix<double, matrix_use::a, M, K, matrix_layout::row_major>
sub_a;

joint_matrix<double, matrix_use::b, K, N, matrix_layout::row_major>
sub_b;

//CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i, i32 8) #{{.*}}
joint_matrix_load(sg, sub_c, accC.get_pointer(), N);
//CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i54, i32 4) #{{.*}}
joint_matrix_load(sg, sub_a, accA.get_pointer(), K);
//CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i65, i32 8) #{{.*}}
joint_matrix_load(sg, sub_b, accB.get_pointer(), N);
//CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double %11, double %12, double %9, double %10) #{{.*}}
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
//CHECK: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i76, double %14, double %15, i32 8) #{{.*}}
joint_matrix_store(sg, sub_c, accD.get_pointer(), N);
});
});

q.submit([&](handler &cgh) {
auto accC = bufC.get_access<access::mode::read_write>(cgh);
auto accA = bufA.get_access<access::mode::read_write>(cgh);
auto accB = bufB.get_access<access::mode::read_write>(cgh);
auto accD = bufD.get_access<access::mode::read_write>(cgh);

cgh.parallel_for<class col_col>(
nd_range<2>({1, 32}, {1, 32}), [=
](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
sycl::sub_group sg = item.get_sub_group();

joint_matrix<double, matrix_use::accumulator, M, N,
matrix_layout::col_major>
sub_c;

joint_matrix<double, matrix_use::a, M, K, matrix_layout::col_major>
sub_a;

joint_matrix<double, matrix_use::b, K, N, matrix_layout::col_major>
sub_b;

//CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i, i32 8) #{{.*}}
joint_matrix_load(sg, sub_c, accC.get_pointer(), M);
//CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i54, i32 8) #{{.*}}
joint_matrix_load(sg, sub_a, accA.get_pointer(), M);
//CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i65, i32 4) #{{.*}}
joint_matrix_load(sg, sub_b, accB.get_pointer(), K);
//CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double %11, double %12, double %9, double %10) #{{.*}}
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
//CHECK: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i76, double %14, double %15, i32 8) #{{.*}}
joint_matrix_store(sg, sub_c, accD.get_pointer(), M);
});
});

return 0;
};

0 comments on commit 711ba58

Please sign in to comment.