Skip to content

Commit

Permalink
Add GMMA shape m64n40k16 (#1864)
Browse files Browse the repository at this point in the history
  • Loading branch information
tridao authored Oct 22, 2024
1 parent 08101d9 commit 5b50a8f
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 0 deletions.
10 changes: 10 additions & 0 deletions include/cute/arch/mma_sm90.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,11 @@ ss_op_selector()
else if constexpr (Tile_N % 48 == 0) {
return SM90::GMMA::MMA_64x48x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
}
#endif
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
else if constexpr (Tile_N % 40 == 0) {
return SM90::GMMA::MMA_64x40x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
}
#endif
else if constexpr (Tile_N % 32 == 0) {
return SM90::GMMA::MMA_64x32x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
Expand Down Expand Up @@ -920,6 +925,11 @@ ss_op_selector()
else if constexpr (Tile_N % 48 == 0) {
return SM90::GMMA::MMA_64x48x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
}
#endif
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
else if constexpr (Tile_N % 40 == 0) {
return SM90::GMMA::MMA_64x40x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
}
#endif
else if constexpr (Tile_N % 32 == 0) {
return SM90::GMMA::MMA_64x32x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
Expand Down
110 changes: 110 additions & 0 deletions include/cute/arch/mma_sm90_gmma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2595,6 +2595,61 @@ struct MMA_64x32x16_F32F16F16_RS

////////////////////////////////////////////////////////////////////////////////////////////////////

#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
// GMMA 64x40x16 F32+=F16*F16
template <
GMMA::Major tnspA,
GMMA::Major tnspB,
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x40x16_F32F16F16_SS
{
using DRegisters = void;
using ARegisters = uint64_t[1];
using BRegisters = uint64_t[1];
using CRegisters = float[20];

CUTE_HOST_DEVICE static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
float & d00, float & d01, float & d02, float & d03,
float & d04, float & d05, float & d06, float & d07,
float & d08, float & d09, float & d10, float & d11,
float & d12, float & d13, float & d14, float & d15,
float & d16, float & d17, float & d18, float & d19,
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
{
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
asm volatile(
"{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %22, 0;\n"
"wgmma.mma_async.sync.aligned.m64n40k16.f32.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19},"
" %20,"
" %21,"
" p, %23, %24, %25, %26;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19)
: "l"(desc_a),
"l"(desc_b),
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
#else
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
}
};
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
// GMMA 64x48x16 F32+=F16*F16
template <
Expand Down Expand Up @@ -5442,6 +5497,61 @@ struct MMA_64x32x16_F32BF16BF16_RS

////////////////////////////////////////////////////////////////////////////////////////////////////

#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
// GMMA 64x40x16 F32+=BF16*BF16
template <
GMMA::Major tnspA,
GMMA::Major tnspB,
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x40x16_F32BF16BF16_SS
{
using DRegisters = void;
using ARegisters = uint64_t[1];
using BRegisters = uint64_t[1];
using CRegisters = float[20];

CUTE_HOST_DEVICE static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
float & d00, float & d01, float & d02, float & d03,
float & d04, float & d05, float & d06, float & d07,
float & d08, float & d09, float & d10, float & d11,
float & d12, float & d13, float & d14, float & d15,
float & d16, float & d17, float & d18, float & d19,
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
{
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
asm volatile(
"{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %22, 0;\n"
"wgmma.mma_async.sync.aligned.m64n40k16.f32.bf16.bf16 "
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19},"
" %20,"
" %21,"
" p, %23, %24, %25, %26;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19)
: "l"(desc_a),
"l"(desc_b),
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
#else
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
}
};
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
// GMMA 64x48x16 F32+=BF16*BF16
template <
Expand Down
69 changes: 69 additions & 0 deletions include/cute/atom/mma_traits_sm90_gmma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,9 @@ using CLayout_64x16 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _2>>,
using CLayout_64x32 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _4>>,
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;

using CLayout_64x40 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _5>>,
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;

using CLayout_64x48 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _6>>,
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;

Expand Down Expand Up @@ -1773,6 +1776,39 @@ struct MMA_Traits<SM90_64x32x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>

#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)

template <
GMMA::Major tnspA,
GMMA::Major tnspB,
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
using SM90_64x40x16_F32F16F16_SS = SM90::GMMA::MMA_64x40x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>;

template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x40x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = float;

using FrgTypeA = GMMA::smem_desc<tnspA>;
using FrgTypeB = GMMA::smem_desc<tnspB>;

using Shape_MNK = Shape<_64,Int<40>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ABLayout< 64, 16>;
using BLayout = GMMA::ABLayout< 40, 16>;
using CLayout = GMMA::CLayout_64x40;

GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)

template <
GMMA::Major tnspA,
GMMA::Major tnspB,
Expand Down Expand Up @@ -2846,6 +2882,39 @@ struct MMA_Traits<SM90_64x32x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>

#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)

template <
GMMA::Major tnspA,
GMMA::Major tnspB,
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
using SM90_64x40x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x40x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>;

template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x40x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = bfloat16_t;
using ValTypeB = bfloat16_t;
using ValTypeC = float;

using FrgTypeA = GMMA::smem_desc<tnspA>;
using FrgTypeB = GMMA::smem_desc<tnspB>;

using Shape_MNK = Shape<_64,Int<40>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ABLayout< 64, 16>;
using BLayout = GMMA::ABLayout< 40, 16>;
using CLayout = GMMA::CLayout_64x40;

GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)

template <
GMMA::Major tnspA,
GMMA::Major tnspB,
Expand Down

0 comments on commit 5b50a8f

Please sign in to comment.