Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Host IR: add GetCurrentStream #3605

Merged
Merged 5 commits on Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions csrc/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ class Val;
f(HostUnit); \
f(PostOnStream); \
f(SetCurrentStream); \
f(GetCurrentStream); \
f(Wait); \
f(Synchronize); \
f(StartCoalescing); \
Expand Down
7 changes: 7 additions & 0 deletions csrc/host_ir/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,13 @@ void HostIrEvaluator::handle(SetCurrentStream* set_current_stream) {
setCurrentCUDAStream(getCUDAStream(set_current_stream->stream()));
}

// Binds the Stream val owned by `get_current_stream` to the CUDA stream that is
// current on device `my_device_index_` at the moment this node is evaluated.
//
// insert_or_assign is used instead of insert: insert leaves an already-present
// entry untouched, so evaluating the same GetCurrentStream node a second time
// (after the ambient stream changed) would silently keep the stale stream.
// NOTE(review): assumes `streams_` is a std::map/std::unordered_map-style
// container (its declaration is outside this view) — confirm.
void HostIrEvaluator::handle(GetCurrentStream* get_current_stream) {
  streams_.insert_or_assign(
      get_current_stream->stream(),
      c10::cuda::getCurrentCUDAStream(
          static_cast<c10::DeviceIndex>(my_device_index_)));
}

void HostIrEvaluator::handle(Synchronize* synchronize) {
cudaStream_t current_stream =
c10::cuda::getCurrentCUDAStream(
Expand Down
1 change: 1 addition & 0 deletions csrc/host_ir/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ class HostIrEvaluator final : public OptOutDispatch {
private:
using OptOutDispatch::handle;
void handle(SetCurrentStream* set_current_stream) override;
void handle(GetCurrentStream* get_current_stream) override;
void handle(Synchronize* synchronize) override;
void handle(PostOnStream* post_ir) override;
void handle(Communication* communication) override;
Expand Down
16 changes: 16 additions & 0 deletions csrc/host_ir/host_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,22 @@ bool SetCurrentStream::sameAs(const Statement* other) const {
return false;
}

// Builds a GetCurrentStream expression. The passkey must reference a non-null
// container that is a HostIrContainer. A new Stream is created in that
// container and registered as this expression's first attribute; it is the
// value later returned by stream().
GetCurrentStream::GetCurrentStream(IrBuilderPasskey passkey) : Expr(passkey) {
  NVF_ERROR(passkey.ir_container_ != nullptr);
  NVF_ERROR(passkey.ir_container_->isA<HostIrContainer>());
  // Create the Stream val and attach it in one step.
  addAttribute(IrBuilder::createInContainer<Stream>(passkey.ir_container_));
}

// Project macro (defined outside this view). Presumably emits the clone/create
// member definitions for GetCurrentStream — confirm against the macro.
NVFUSER_DEFINE_CLONE_AND_CREATE(GetCurrentStream)

// Renders this expression as one indented line of the form
// "GetCurrentStream into <stream>", terminated by a newline.
std::string GetCurrentStream::toString(int indent_size) const {
  // Textual form of the Stream val this expression writes to.
  const std::string target = stream()->toString();

  std::stringstream buffer;
  indent(buffer, indent_size)
      << "GetCurrentStream into " << target << std::endl;
  return buffer.str();
}

Wait::Wait(IrBuilderPasskey passkey, Expr* expr)
: Expr(passkey, {}, {}, {expr}) {
NVF_ERROR(passkey.ir_container_ != nullptr);
Expand Down
22 changes: 22 additions & 0 deletions csrc/host_ir/host_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,28 @@ class SetCurrentStream : public Expr {
}
};

// Host IR expression that carries a single Stream attribute, exposed through
// stream(). Printed as "GetCurrentStream into <stream>".
class GetCurrentStream : public Expr {
 public:
  using Expr::Expr;
  GetCurrentStream(IrBuilderPasskey passkey);

  // Neither copyable nor movable.
  GetCurrentStream(const GetCurrentStream&) = delete;
  GetCurrentStream& operator=(const GetCurrentStream&) = delete;
  GetCurrentStream(GetCurrentStream&&) = delete;
  GetCurrentStream& operator=(GetCurrentStream&&) = delete;

  NVFUSER_DECLARE_CLONE_AND_CREATE

  std::string toString(int indent_size = 0) const override;

  const char* getOpString() const override {
    return "hir::GetCurrentStream";
  }

  // Returns attribute 0 viewed as a Stream.
  Stream* stream() const {
    auto* first_attribute = attributes_.at(0);
    return first_attribute->as<Stream>();
  }
};

class Wait : public Expr {
public:
using Expr::Expr;
Expand Down
20 changes: 20 additions & 0 deletions tests/cpp/test_host_irs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,26 @@ TEST_F(StreamTest, HostIrDefaultStream) {
c10::cuda::getDefaultCUDAStream(0), c10::cuda::getCurrentCUDAStream(0));
}

// Program under test: GetCurrentStream, then SetCurrentStream to a different
// stream, then SetCurrentStream back to the one obtained by GetCurrentStream.
// After evaluation the current CUDA stream must equal the stream that was
// current before the evaluator ran.
TEST_F(StreamTest, HostIrGetCurrentStream) {
  auto container = std::make_unique<HostIrContainer>();
  FusionGuard guard(container.get());

  // Build the IR nodes.
  auto* capture_expr = IrBuilder::create<GetCurrentStream>();
  auto* captured_stream = capture_expr->stream();
  auto* detour_stream = IrBuilder::create<Stream>();

  // Assemble the top-level program: capture -> detour -> restore.
  container->pushBackTopLevelExprs(capture_expr);
  container->pushBackTopLevelExprs(
      IrBuilder::create<SetCurrentStream>(detour_stream));
  container->pushBackTopLevelExprs(
      IrBuilder::create<SetCurrentStream>(captured_stream));

  // Make a pool stream current before the evaluator runs.
  auto cuda_stream = c10::cuda::getStreamFromPool();
  setCurrentCUDAStream(cuda_stream);

  HostIrEvaluator evaluator(std::move(container));
  evaluator.runWithInput({});

  EXPECT_EQ(cuda_stream, c10::cuda::getCurrentCUDAStream(0));
}

TEST_F(StreamTest, ByIndex) {
constexpr int64_t kStreamIndex1 = 2;
constexpr int64_t kStreamIndex2 = 3;
Expand Down
Loading