Add a GetCurrentStream host IR expression.

Summary of the changes in this patch:
- csrc/dispatch.h: adds GetCurrentStream to the f(...) expression list, next
  to SetCurrentStream.
- csrc/host_ir/host_ir.{h,cpp}: declares and defines hir::GetCurrentStream.
  Its constructor requires a non-null container of the expected kind, creates
  a Stream in that container and stores it as an attribute; stream() returns
  that attribute.
- csrc/host_ir/executor.{h,cpp}: HostIrEvaluator::handle(GetCurrentStream*)
  inserts into streams_ the pair (stream(), CUDA stream that is current on
  my_device_index_ when the handler runs).
- tests/cpp/test_host_irs.cpp: StreamTest.HostIrGetCurrentStream captures the
  current stream, switches to another stream and back to the captured one, and
  expects the stream that was current before the run to be current afterwards.

NOTE(review): template argument lists and line structure were reconstructed;
the alignment of the continuation backslashes in the dispatch.h hunk could not
be recovered -- confirm against the target tree before applying.

diff --git a/csrc/dispatch.h b/csrc/dispatch.h
index 4fe0f86cc5f..77b650b88dc 100644
--- a/csrc/dispatch.h
+++ b/csrc/dispatch.h
@@ -146,6 +146,7 @@ class Val;
   f(HostUnit); \
   f(PostOnStream); \
   f(SetCurrentStream); \
+  f(GetCurrentStream); \
   f(Wait); \
   f(Synchronize); \
   f(StartCoalescing); \
diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp
index 69b5b9c704d..1b2554cdabb 100644
--- a/csrc/host_ir/executor.cpp
+++ b/csrc/host_ir/executor.cpp
@@ -274,6 +274,13 @@ void HostIrEvaluator::handle(SetCurrentStream* set_current_stream) {
   setCurrentCUDAStream(getCUDAStream(set_current_stream->stream()));
 }
 
+void HostIrEvaluator::handle(GetCurrentStream* get_current_stream) {
+  streams_.insert(
+      {get_current_stream->stream(),
+       c10::cuda::getCurrentCUDAStream(
+           static_cast<c10::DeviceIndex>(my_device_index_))});
+}
+
 void HostIrEvaluator::handle(Synchronize* synchronize) {
   cudaStream_t current_stream =
       c10::cuda::getCurrentCUDAStream(
diff --git a/csrc/host_ir/executor.h b/csrc/host_ir/executor.h
index 6f9070b810a..a51dc32aed4 100644
--- a/csrc/host_ir/executor.h
+++ b/csrc/host_ir/executor.h
@@ -112,6 +112,7 @@ class HostIrEvaluator final : public OptOutDispatch {
  private:
   using OptOutDispatch::handle;
   void handle(SetCurrentStream* set_current_stream) override;
+  void handle(GetCurrentStream* get_current_stream) override;
   void handle(Synchronize* synchronize) override;
   void handle(PostOnStream* post_ir) override;
   void handle(Communication* communication) override;
diff --git a/csrc/host_ir/host_ir.cpp b/csrc/host_ir/host_ir.cpp
index 492b2b22aab..49b33f59823 100644
--- a/csrc/host_ir/host_ir.cpp
+++ b/csrc/host_ir/host_ir.cpp
@@ -179,6 +179,22 @@ bool SetCurrentStream::sameAs(const Statement* other) const {
   return false;
 }
 
+GetCurrentStream::GetCurrentStream(IrBuilderPasskey passkey) : Expr(passkey) {
+  NVF_ERROR(passkey.ir_container_ != nullptr);
+  NVF_ERROR(passkey.ir_container_->isA<HostIrContainer>());
+  auto stream = IrBuilder::createInContainer<Stream>(passkey.ir_container_);
+  addAttribute(stream);
+}
+
+NVFUSER_DEFINE_CLONE_AND_CREATE(GetCurrentStream)
+
+std::string GetCurrentStream::toString(int indent_size) const {
+  std::stringstream ss;
+  indent(ss, indent_size) << "GetCurrentStream into " << stream()->toString()
+                          << std::endl;
+  return ss.str();
+}
+
 Wait::Wait(IrBuilderPasskey passkey, Expr* expr)
     : Expr(passkey, {}, {}, {expr}) {
   NVF_ERROR(passkey.ir_container_ != nullptr);
diff --git a/csrc/host_ir/host_ir.h b/csrc/host_ir/host_ir.h
index 587ffc43638..82d67d6f4cc 100644
--- a/csrc/host_ir/host_ir.h
+++ b/csrc/host_ir/host_ir.h
@@ -161,6 +161,28 @@ class SetCurrentStream : public Expr {
   }
 };
 
+class GetCurrentStream : public Expr {
+ public:
+  using Expr::Expr;
+  GetCurrentStream(IrBuilderPasskey passkey);
+
+  GetCurrentStream(const GetCurrentStream& other) = delete;
+  GetCurrentStream& operator=(const GetCurrentStream& other) = delete;
+  GetCurrentStream(GetCurrentStream&& other) = delete;
+  GetCurrentStream& operator=(GetCurrentStream&& other) = delete;
+
+  NVFUSER_DECLARE_CLONE_AND_CREATE
+
+  std::string toString(int indent_size = 0) const override;
+  const char* getOpString() const override {
+    return "hir::GetCurrentStream";
+  }
+
+  Stream* stream() const {
+    return attributes_.at(0)->as<Stream>();
+  }
+};
+
 class Wait : public Expr {
  public:
   using Expr::Expr;
diff --git a/tests/cpp/test_host_irs.cpp b/tests/cpp/test_host_irs.cpp
index 64aa2a0564b..e97550309e1 100644
--- a/tests/cpp/test_host_irs.cpp
+++ b/tests/cpp/test_host_irs.cpp
@@ -513,6 +513,26 @@ TEST_F(StreamTest, HostIrDefaultStream) {
       c10::cuda::getDefaultCUDAStream(0), c10::cuda::getCurrentCUDAStream(0));
 }
 
+TEST_F(StreamTest, HostIrGetCurrentStream) {
+  auto hic = std::make_unique<HostIrContainer>();
+  FusionGuard fg(hic.get());
+  auto get_stream = IrBuilder::create<GetCurrentStream>();
+  auto current_stream = get_stream->stream();
+  auto other_stream = IrBuilder::create<Stream>();
+  hic->pushBackTopLevelExprs(get_stream);
+  hic->pushBackTopLevelExprs(IrBuilder::create<SetCurrentStream>(other_stream));
+  hic->pushBackTopLevelExprs(
+      IrBuilder::create<SetCurrentStream>(current_stream));
+
+  auto cuda_stream = c10::cuda::getStreamFromPool();
+  setCurrentCUDAStream(cuda_stream);
+
+  HostIrEvaluator hie(std::move(hic));
+  hie.runWithInput({});
+
+  EXPECT_EQ(cuda_stream, c10::cuda::getCurrentCUDAStream(0));
+}
+
 TEST_F(StreamTest, ByIndex) {
   constexpr int64_t kStreamIndex1 = 2;
   constexpr int64_t kStreamIndex2 = 3;