diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp index bfad4af23d3..b0e4b751c8a 100644 --- a/csrc/scheduler/hopper_multi_matmul.cpp +++ b/csrc/scheduler/hopper_multi_matmul.cpp @@ -523,13 +523,10 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() { dc->setMemoryType(MemoryType::Local); d_smem->setMemoryType(MemoryType::Shared); - // Set LoadStoreOp - // TODO: extend support when mma is not cast to half - NVF_CHECK( - dataTypeSize(dc->dtype()) == 2, - "We support use_smem_epilogue on Hopper only when the output is 16-bit"); + auto store_with_stmatrix = dataTypeSize(dc->dtype()) == 2; if (store_with_stmatrix) { + // Set LoadStoreOp d_smem->definition()->as<LoadStoreOp>()->setOpType( LoadStoreOpType::StMatrix); }