Skip to content

Commit

Permalink
Fix SMEM index for C in CuTe examples (#1477)
Browse files Browse the repository at this point in the history
  • Loading branch information
joerowell authored Jul 10, 2024
1 parent e48c761 commit 843adf0
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion examples/cute/tutorial/sgemm_1.cu
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
static_assert(is_static<CSmemLayout>::value);

CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K
Expand Down
2 changes: 1 addition & 1 deletion examples/cute/tutorial/sgemm_2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
static_assert(is_static<CSmemLayout>::value);

CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K
Expand Down
2 changes: 1 addition & 1 deletion examples/cute/tutorial/sgemm_sm70.cu
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
static_assert(is_static<CSmemLayout>::value);

CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K
Expand Down
2 changes: 1 addition & 1 deletion examples/cute/tutorial/sgemm_sm80.cu
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
static_assert(is_static<CSmemLayout>::value);

CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K
Expand Down
2 changes: 1 addition & 1 deletion media/docs/cute/0x_gemm_tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ As is evident, these smem layouts can be almost anything. Inside the kernel, the
static_assert(is_static<CSmemLayout>::value);

CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K
Expand Down

0 comments on commit 843adf0

Please sign in to comment.