diff --git a/csrc/device_lower/pass/inline_ptx.cpp b/csrc/device_lower/pass/inline_ptx.cpp index 31afc58a775..6b7579328cc 100644 --- a/csrc/device_lower/pass/inline_ptx.cpp +++ b/csrc/device_lower/pass/inline_ptx.cpp @@ -240,7 +240,24 @@ class LowerToInlinePtx : public kir::ExprMutator { /*volatile=*/true, /*memory=*/false, /*readable_outputs=*/{0}})); + + // The above call is asynchronous, so we need to wait to prevent a data race + // TODO: Why is it safe to not always use zero here? + CircularBufferOptions cb_opts = + mma->inA()->as()->view()->circularBufferOptions(); + auto* commit = IrBuilder::create(AsyncOpType::WgMma); + auto* wait = IrBuilder::create( + AsyncOpType::WgMma, + /*keep_stages=*/cb_opts.stage - cb_opts.prefetch - 1); + + registerInsertBefore(mma, commit); + registerInsertBefore(mma, wait); registerRemove(mma); + + // These are needed for actually converting the nodes above into kir::Asm + // nodes properly + handle(commit); + handle(wait); } void handle(MmaOp* mma) final {