Skip to content

Commit 219001d

Browse files
committed
rework DefaultSemiRingConfiguration struct API for clarity
1 parent a40f7bc commit 219001d

File tree

262 files changed

+35748
-35743
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

262 files changed

+35748
-35743
lines changed

bench/device/gen_simt.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,9 @@
9595
using WarpShape = cutlass::gemm::GemmShape<{13}, {14}, {12}>;
9696
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
9797
98-
using Config = typename cuasr::gemm::device::DefaultSemiRingConfiguration< //
99-
precision, precision, precision, precision, OpClass, //
100-
cuasr::{0}<precision>, cuasr::{1}<precision>, SmArch>;
98+
using Config = typename cuasr::gemm::device::DefaultSemiRingConfiguration<
99+
precision, precision, precision, precision,
100+
cuasr::{0}<precision>, cuasr::{1}<precision>, OpClass, SmArch>;
101101
102102
using AddOp = Config::AdditionOp;
103103
using MultOp = Config::MultiplicationOp;

bench/device/sm50_simt_binary_or_binary_and_dsrgemm_nn_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_binary_or_binary_and_dsrgemm_nn_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_binary_or_binary_and_dsrgemm_nt_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_binary_or_binary_and_dsrgemm_nt_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_binary_or_binary_and_dsrgemm_tn_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_binary_or_binary_and_dsrgemm_tn_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_binary_or_binary_and_dsrgemm_tt_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_binary_or_binary_and_dsrgemm_tt_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_binary_or_binary_and_ssrgemm_nn_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_binary_or_binary_and_ssrgemm_nn_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_binary_or_binary_and_ssrgemm_nt_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_binary_or_binary_and_ssrgemm_nt_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_binary_or_binary_and_ssrgemm_tn_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_binary_or_binary_and_ssrgemm_tn_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_binary_or_binary_and_ssrgemm_tt_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_binary_or_binary_and_ssrgemm_tt_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_minimum_dsrgemm_nn_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_minimum_dsrgemm_nn_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_minimum_dsrgemm_nt_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_minimum_dsrgemm_nt_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_minimum_dsrgemm_tn_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_minimum_dsrgemm_tn_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_minimum_dsrgemm_tt_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_minimum_dsrgemm_tt_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_minimum_ssrgemm_nn_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_minimum_ssrgemm_nn_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_minimum_ssrgemm_nt_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_minimum_ssrgemm_nt_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_minimum_ssrgemm_tn_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_minimum_ssrgemm_tn_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_minimum_ssrgemm_tt_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_minimum_ssrgemm_tt_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_multiplies_dsrgemm_nn_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_multiplies_dsrgemm_nn_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_multiplies_dsrgemm_nt_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_multiplies_dsrgemm_nt_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_multiplies_dsrgemm_tn_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_multiplies_dsrgemm_tn_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_multiplies_dsrgemm_tt_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_multiplies_dsrgemm_tt_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_multiplies_ssrgemm_nn_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_multiplies_ssrgemm_nn_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_multiplies_ssrgemm_nt_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_multiplies_ssrgemm_nt_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_multiplies_ssrgemm_tn_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_multiplies_ssrgemm_tn_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_multiplies_ssrgemm_tt_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_multiplies_ssrgemm_tt_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_plus_dsrgemm_nn_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_plus_dsrgemm_nn_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_plus_dsrgemm_nt_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_plus_dsrgemm_nt_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_plus_dsrgemm_tn_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_plus_dsrgemm_tn_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_plus_dsrgemm_tt_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_plus_dsrgemm_tt_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_plus_ssrgemm_nn_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_plus_ssrgemm_nn_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_plus_ssrgemm_nt_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_plus_ssrgemm_nt_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_plus_ssrgemm_tn_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_plus_ssrgemm_tn_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_plus_ssrgemm_tt_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_maximum_plus_ssrgemm_tt_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_maximum_dsrgemm_nn_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_maximum_dsrgemm_nn_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_maximum_dsrgemm_nt_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_maximum_dsrgemm_nt_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_maximum_dsrgemm_tn_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_maximum_dsrgemm_tn_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_maximum_dsrgemm_tt_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_maximum_dsrgemm_tt_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_maximum_ssrgemm_nn_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_maximum_ssrgemm_nn_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_maximum_ssrgemm_nt_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_maximum_ssrgemm_nt_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_maximum_ssrgemm_tn_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_maximum_ssrgemm_tn_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_maximum_ssrgemm_tt_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_maximum_ssrgemm_tt_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_multiplies_dsrgemm_nn_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_multiplies_dsrgemm_nn_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_multiplies_dsrgemm_nt_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_multiplies_dsrgemm_nt_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_multiplies_dsrgemm_tn_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_multiplies_dsrgemm_tn_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_multiplies_dsrgemm_tt_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_multiplies_dsrgemm_tt_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_multiplies_ssrgemm_nn_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_multiplies_ssrgemm_nn_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_multiplies_ssrgemm_nt_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_multiplies_ssrgemm_nt_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_multiplies_ssrgemm_tn_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_multiplies_ssrgemm_tn_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_multiplies_ssrgemm_tt_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_multiplies_ssrgemm_tt_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_plus_dsrgemm_nn_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_plus_dsrgemm_nn_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_plus_dsrgemm_nt_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_plus_dsrgemm_nt_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_plus_dsrgemm_tn_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_plus_dsrgemm_tn_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_plus_dsrgemm_tt_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_plus_dsrgemm_tt_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_plus_ssrgemm_nn_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_plus_ssrgemm_nn_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_plus_ssrgemm_nt_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_plus_ssrgemm_nt_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_plus_ssrgemm_tn_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_plus_ssrgemm_tn_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_plus_ssrgemm_tt_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_minimum_plus_ssrgemm_tt_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_plus_multiplies_dsrgemm_nn_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_plus_multiplies_dsrgemm_nn_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_plus_multiplies_dsrgemm_nt_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_plus_multiplies_dsrgemm_nt_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_plus_multiplies_dsrgemm_tn_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_plus_multiplies_dsrgemm_tn_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_plus_multiplies_dsrgemm_tt_n.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_plus_multiplies_dsrgemm_tt_t.cu

+108-108
Large diffs are not rendered by default.

bench/device/sm50_simt_plus_multiplies_ssrgemm_nn_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_plus_multiplies_ssrgemm_nn_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_plus_multiplies_ssrgemm_nt_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_plus_multiplies_ssrgemm_nt_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_plus_multiplies_ssrgemm_tn_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_plus_multiplies_ssrgemm_tn_t.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_plus_multiplies_ssrgemm_tt_n.cu

+171-171
Large diffs are not rendered by default.

bench/device/sm50_simt_plus_multiplies_ssrgemm_tt_t.cu

+171-171
Large diffs are not rendered by default.

include/cuasr/gemm/device/default_srgemm_configuration.h

+14-9
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,15 @@ template <
3030
typename ElementB,
3131
typename ElementC,
3232
typename ElementAccumulator,
33-
typename OperatorClass,
3433
typename AdditionOp,
3534
typename MultiplicationOp,
35+
typename OperatorClass,
3636
typename ArchTag
3737
>
3838
struct DefaultSemiRingConfiguration;
3939

40+
////////////////////////////////////////////////////////////////////////////////
41+
/////////////////////////////////// SM 50 //////////////////////////////////////
4042
////////////////////////////////////////////////////////////////////////////////
4143

4244
// Plus-Times semi-ring GEMM configuration
@@ -50,9 +52,9 @@ struct DefaultSemiRingConfiguration<
5052
Element,
5153
Element,
5254
Element,
53-
cutlass::arch::OpClassSimt,
5455
cuasr::plus<Element>,
5556
cuasr::multiplies<Element>,
57+
cutlass::arch::OpClassSimt,
5658
ArchTag> {
5759

5860
static int constexpr kAlignmentA = 1;
@@ -80,9 +82,9 @@ struct DefaultSemiRingConfiguration<
8082
Element,
8183
Element,
8284
Element,
83-
cutlass::arch::OpClassSimt,
8485
cuasr::minimum<Element>,
8586
cuasr::plus<Element>,
87+
cutlass::arch::OpClassSimt,
8688
ArchTag> {
8789

8890
static int constexpr kAlignmentA = 1;
@@ -110,9 +112,9 @@ struct DefaultSemiRingConfiguration<
110112
Element,
111113
Element,
112114
Element,
113-
cutlass::arch::OpClassSimt,
114115
cuasr::maximum<Element>,
115116
cuasr::plus<Element>,
117+
cutlass::arch::OpClassSimt,
116118
ArchTag> {
117119

118120
static int constexpr kAlignmentA = 1;
@@ -139,9 +141,9 @@ struct DefaultSemiRingConfiguration<
139141
Element,
140142
Element,
141143
Element,
142-
cutlass::arch::OpClassSimt,
143144
cuasr::maximum<Element>,
144145
cuasr::minimum<Element>,
146+
cutlass::arch::OpClassSimt,
145147
ArchTag> {
146148

147149
static int constexpr kAlignmentA = 1;
@@ -168,9 +170,9 @@ struct DefaultSemiRingConfiguration<
168170
Element,
169171
Element,
170172
Element,
171-
cutlass::arch::OpClassSimt,
172173
cuasr::minimum<Element>,
173174
cuasr::maximum<Element>,
175+
cutlass::arch::OpClassSimt,
174176
ArchTag> {
175177

176178
static int constexpr kAlignmentA = 1;
@@ -197,9 +199,9 @@ struct DefaultSemiRingConfiguration<
197199
Element,
198200
Element,
199201
Element,
200-
cutlass::arch::OpClassSimt,
201202
cuasr::minimum<Element>,
202203
cuasr::multiplies<Element>,
204+
cutlass::arch::OpClassSimt,
203205
ArchTag> {
204206

205207
static int constexpr kAlignmentA = 1;
@@ -226,9 +228,9 @@ struct DefaultSemiRingConfiguration<
226228
Element,
227229
Element,
228230
Element,
229-
cutlass::arch::OpClassSimt,
230231
cuasr::maximum<Element>,
231232
cuasr::multiplies<Element>,
233+
cutlass::arch::OpClassSimt,
232234
ArchTag> {
233235

234236
static int constexpr kAlignmentA = 1;
@@ -255,9 +257,9 @@ struct DefaultSemiRingConfiguration<
255257
Element,
256258
Element,
257259
Element,
258-
cutlass::arch::OpClassSimt,
259260
cuasr::binary_or<Element>,
260261
cuasr::binary_and<Element>,
262+
cutlass::arch::OpClassSimt,
261263
ArchTag> {
262264

263265
static int constexpr kAlignmentA = 1;
@@ -275,6 +277,9 @@ struct DefaultSemiRingConfiguration<
275277
};
276278

277279
////////////////////////////////////////////////////////////////////////////////
280+
/////////////////////////////////// SM 80 //////////////////////////////////////
281+
////////////////////////////////////////////////////////////////////////////////
282+
278283

279284
} // namespace device
280285
} // namespace gemm

include/cuasr/gemm/device/srgemm.h

+7-7
Original file line numberDiff line numberDiff line change
@@ -51,34 +51,34 @@ template <
5151
/// Threadblock-level tile size (concept: GemmShape)
5252
typename ThreadblockShape_ = typename DefaultSemiRingConfiguration<
5353
ElementA_, ElementB_, ElementC_, ElementAccumulator_,
54-
OperatorClass_, AdditionOp_, MultiplicationOp_, ArchTag_>::ThreadblockShape,
54+
AdditionOp_, MultiplicationOp_, OperatorClass_, ArchTag_>::ThreadblockShape,
5555
/// Warp-level tile size (concept: GemmShape)
5656
typename WarpShape_ = typename DefaultSemiRingConfiguration<
5757
ElementA_, ElementB_, ElementC_, ElementAccumulator_,
58-
OperatorClass_, AdditionOp_, MultiplicationOp_, ArchTag_>::WarpShape,
58+
AdditionOp_, MultiplicationOp_, OperatorClass_, ArchTag_>::WarpShape,
5959
/// Instruction-level tile size (concept: GemmShape)
6060
typename InstructionShape_ = typename DefaultSemiRingConfiguration<
6161
ElementA_, ElementB_, ElementC_, ElementAccumulator_,
62-
OperatorClass_, AdditionOp_, MultiplicationOp_, ArchTag_>::InstructionShape,
62+
AdditionOp_, MultiplicationOp_, OperatorClass_, ArchTag_>::InstructionShape,
6363
/// Epilogue output operator
6464
typename EpilogueOutputOp_ = typename DefaultSemiRingConfiguration<
6565
ElementA_, ElementB_, ElementC_, ElementAccumulator_,
66-
OperatorClass_, AdditionOp_, MultiplicationOp_, ArchTag_>::EpilogueOutputOp,
66+
AdditionOp_, MultiplicationOp_, OperatorClass_, ArchTag_>::EpilogueOutputOp,
6767
/// Threadblock-level swizzling operator
6868
typename ThreadblockSwizzle_ =
6969
typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
7070
/// Number of stages used in the pipelined mainloop
7171
int Stages = DefaultSemiRingConfiguration<
7272
ElementA_, ElementB_, ElementC_, ElementAccumulator_,
73-
OperatorClass_, AdditionOp_, MultiplicationOp_, ArchTag_>::kStages,
73+
AdditionOp_, MultiplicationOp_, OperatorClass_, ArchTag_>::kStages,
7474
/// Access granularity of A matrix in units of elements
7575
int AlignmentA = DefaultSemiRingConfiguration<
7676
ElementA_, ElementB_, ElementC_, ElementAccumulator_,
77-
OperatorClass_, AdditionOp_, MultiplicationOp_, ArchTag_>::kAlignmentA,
77+
AdditionOp_, MultiplicationOp_, OperatorClass_, ArchTag_>::kAlignmentA,
7878
/// Access granularity of B matrix in units of elements
7979
int AlignmentB = DefaultSemiRingConfiguration<
8080
ElementA_, ElementB_, ElementC_, ElementAccumulator_,
81-
OperatorClass_, AdditionOp_, MultiplicationOp_, ArchTag_>::kAlignmentB,
81+
AdditionOp_, MultiplicationOp_, OperatorClass_, ArchTag_>::kAlignmentB,
8282
/// If true, kernel supports split-K with serial reduction
8383
bool SplitKSerial = false
8484
>

include/cuasr/gemm/device/srgemm_splitk_parallel.h

+8-8
Original file line numberDiff line numberDiff line change
@@ -59,24 +59,24 @@ template <
5959
/// Threadblock-level tile size (concept: GemmShape)
6060
typename ThreadblockShape_ = typename DefaultSemiRingConfiguration<
6161
ElementA_, ElementB_, ElementC_, ElementAccumulator_,
62-
OperatorClass_, AdditionOp_, MultiplicationOp_, ArchTag_>::ThreadblockShape,
62+
AdditionOp_, MultiplicationOp_, OperatorClass_, ArchTag_>::ThreadblockShape,
6363
/// Warp-level tile size (concept: GemmShape)
6464
typename WarpShape_ = typename DefaultSemiRingConfiguration<
6565
ElementA_, ElementB_, ElementC_, ElementAccumulator_,
66-
OperatorClass_, AdditionOp_, MultiplicationOp_, ArchTag_>::WarpShape,
66+
AdditionOp_, MultiplicationOp_, OperatorClass_, ArchTag_>::WarpShape,
6767
/// Instruction-level tile size (concept: GemmShape)
6868
typename InstructionShape_ = typename DefaultSemiRingConfiguration<
6969
ElementA_, ElementB_, ElementC_, ElementAccumulator_,
70-
OperatorClass_, AdditionOp_, MultiplicationOp_, ArchTag_>::InstructionShape,
70+
AdditionOp_, MultiplicationOp_, OperatorClass_, ArchTag_>::InstructionShape,
7171
/// Epilogue output operator
7272
typename EpilogueOutputOp_ = typename DefaultSemiRingConfiguration<
7373
ElementA_, ElementB_, ElementC_, ElementAccumulator_,
74-
OperatorClass_, AdditionOp_, MultiplicationOp_, ArchTag_>::EpilogueOutputOp,
74+
AdditionOp_, MultiplicationOp_, OperatorClass_, ArchTag_>::EpilogueOutputOp,
7575
/// Epilogue conversion operator
7676
typename ConvertScaledOp_ = cutlass::epilogue::thread::Convert<
7777
ElementAccumulator_, DefaultSemiRingConfiguration<
7878
ElementA_, ElementB_, ElementC_, ElementAccumulator_,
79-
OperatorClass_, AdditionOp_, MultiplicationOp_, ArchTag_>::EpilogueOutputOp::kCount,
79+
AdditionOp_, MultiplicationOp_, OperatorClass_, ArchTag_>::EpilogueOutputOp::kCount,
8080
ElementAccumulator_>,
8181
/// Reduction operator
8282
typename ReductionOp_ = cuasr::reduction::thread::SemiringReduce<
@@ -88,15 +88,15 @@ template <
8888
/// Number of stages used in the pipelined mainloop
8989
int Stages = DefaultSemiRingConfiguration<
9090
ElementA_, ElementB_, ElementC_, ElementAccumulator_,
91-
OperatorClass_, AdditionOp_, MultiplicationOp_, ArchTag_>::kStages,
91+
AdditionOp_, MultiplicationOp_, OperatorClass_, ArchTag_>::kStages,
9292
/// Access granularity of A matrix in units of elements
9393
int kAlignmentA = DefaultSemiRingConfiguration<
9494
ElementA_, ElementB_, ElementC_, ElementAccumulator_,
95-
OperatorClass_, AdditionOp_, MultiplicationOp_, ArchTag_>::kAlignmentA,
95+
AdditionOp_, MultiplicationOp_, OperatorClass_, ArchTag_>::kAlignmentA,
9696
/// Access granularity of B matrix in units of elements
9797
int kAlignmentB = DefaultSemiRingConfiguration<
9898
ElementA_, ElementB_, ElementC_, ElementAccumulator_,
99-
OperatorClass_, AdditionOp_, MultiplicationOp_, ArchTag_>::kAlignmentB
99+
AdditionOp_, MultiplicationOp_, OperatorClass_, ArchTag_>::kAlignmentB
100100
>
101101
class SrgemmSplitKParallel {
102102
public:

test/device/gen_simt.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,9 @@
103103
using WarpShape = cutlass::gemm::GemmShape<{13}, {14}, {12}>;
104104
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
105105
106-
using Config = typename cuasr::gemm::device::DefaultSemiRingConfiguration< //
107-
precision, precision, precision, precision, OpClass, //
108-
cuasr::{0}<precision>, cuasr::{1}<precision>, SmArch>;
106+
using Config = typename cuasr::gemm::device::DefaultSemiRingConfiguration<
107+
precision, precision, precision, precision,
108+
cuasr::{0}<precision>, cuasr::{1}<precision>, OpClass, SmArch>;
109109
110110
using AddOp = Config::AdditionOp;
111111
using MultOp = Config::MultiplicationOp;

0 commit comments

Comments
 (0)