Skip to content

Commit

Permalink
Host IR: add GetCurrentStream (#3605)
Browse files Browse the repository at this point in the history
# What

adds the primitive `GetCurrentStream` to Host Ir stack.

# Why

needed for 
- #3606

The idea is that if we want to use multiple stream internally, we need
before hand to capture the user stream and to set it back to being the
active stream when returning
  • Loading branch information
samnordmann authored Dec 23, 2024
1 parent cd2b3eb commit 99fb12b
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 0 deletions.
1 change: 1 addition & 0 deletions csrc/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ class Val;
f(HostUnit); \
f(PostOnStream); \
f(SetCurrentStream); \
f(GetCurrentStream); \
f(Wait); \
f(Synchronize); \
f(StartCoalescing); \
Expand Down
7 changes: 7 additions & 0 deletions csrc/host_ir/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,13 @@ void HostIrEvaluator::handle(SetCurrentStream* set_current_stream) {
setCurrentCUDAStream(getCUDAStream(set_current_stream->stream()));
}

void HostIrEvaluator::handle(GetCurrentStream* get_current_stream) {
streams_.insert(
{get_current_stream->stream(),
c10::cuda::getCurrentCUDAStream(
static_cast<c10::DeviceIndex>(my_device_index_))});
}

void HostIrEvaluator::handle(Synchronize* synchronize) {
cudaStream_t current_stream =
c10::cuda::getCurrentCUDAStream(
Expand Down
1 change: 1 addition & 0 deletions csrc/host_ir/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ class HostIrEvaluator final : public OptOutDispatch {
private:
using OptOutDispatch::handle;
void handle(SetCurrentStream* set_current_stream) override;
void handle(GetCurrentStream* get_current_stream) override;
void handle(Synchronize* synchronize) override;
void handle(PostOnStream* post_ir) override;
void handle(Communication* communication) override;
Expand Down
16 changes: 16 additions & 0 deletions csrc/host_ir/host_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,22 @@ bool SetCurrentStream::sameAs(const Statement* other) const {
return false;
}

GetCurrentStream::GetCurrentStream(IrBuilderPasskey passkey) : Expr(passkey) {
NVF_ERROR(passkey.ir_container_ != nullptr);
NVF_ERROR(passkey.ir_container_->isA<HostIrContainer>());
auto stream = IrBuilder::createInContainer<Stream>(passkey.ir_container_);
addAttribute(stream);
}

NVFUSER_DEFINE_CLONE_AND_CREATE(GetCurrentStream)

std::string GetCurrentStream::toString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << "GetCurrentStream into " << stream()->toString()
<< std::endl;
return ss.str();
}

Wait::Wait(IrBuilderPasskey passkey, Expr* expr)
: Expr(passkey, {}, {}, {expr}) {
NVF_ERROR(passkey.ir_container_ != nullptr);
Expand Down
22 changes: 22 additions & 0 deletions csrc/host_ir/host_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,28 @@ class SetCurrentStream : public Expr {
}
};

class GetCurrentStream : public Expr {
public:
using Expr::Expr;
GetCurrentStream(IrBuilderPasskey passkey);

GetCurrentStream(const GetCurrentStream& other) = delete;
GetCurrentStream& operator=(const GetCurrentStream& other) = delete;
GetCurrentStream(GetCurrentStream&& other) = delete;
GetCurrentStream& operator=(GetCurrentStream&& other) = delete;

NVFUSER_DECLARE_CLONE_AND_CREATE

std::string toString(int indent_size = 0) const override;
const char* getOpString() const override {
return "hir::GetCurrentStream";
}

Stream* stream() const {
return attributes_.at(0)->as<Stream>();
}
};

class Wait : public Expr {
public:
using Expr::Expr;
Expand Down
20 changes: 20 additions & 0 deletions tests/cpp/test_host_irs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,26 @@ TEST_F(StreamTest, HostIrDefaultStream) {
c10::cuda::getDefaultCUDAStream(0), c10::cuda::getCurrentCUDAStream(0));
}

TEST_F(StreamTest, HostIrGetCurrentStream) {
auto hic = std::make_unique<HostIrContainer>();
FusionGuard fg(hic.get());
auto get_stream = IrBuilder::create<GetCurrentStream>();
auto current_stream = get_stream->stream();
auto other_stream = IrBuilder::create<Stream>();
hic->pushBackTopLevelExprs(get_stream);
hic->pushBackTopLevelExprs(IrBuilder::create<SetCurrentStream>(other_stream));
hic->pushBackTopLevelExprs(
IrBuilder::create<SetCurrentStream>(current_stream));

auto cuda_stream = c10::cuda::getStreamFromPool();
setCurrentCUDAStream(cuda_stream);

HostIrEvaluator hie(std::move(hic));
hie.runWithInput({});

EXPECT_EQ(cuda_stream, c10::cuda::getCurrentCUDAStream(0));
}

TEST_F(StreamTest, ByIndex) {
constexpr int64_t kStreamIndex1 = 2;
constexpr int64_t kStreamIndex2 = 3;
Expand Down

0 comments on commit 99fb12b

Please sign in to comment.