Skip to content

Commit

Permalink
Host IR: make stream synchronization non blocking (#3608)
Browse files Browse the repository at this point in the history
# What

Make stream synchronization non-blocking from the CPU point of view

# Why

Needed for achieving overlap in 
- #3606

before this patch:
![Screenshot 2024-12-18 at 12 08
25](https://github.com/user-attachments/assets/f5c84282-ea85-4cb8-8a60-538cd91cfa1c)
after this patch
![Screenshot 2024-12-18 at 12 08
05](https://github.com/user-attachments/assets/25537a5d-3e33-4ff8-baf4-4f013c1ed230)


# How 

Before this patch, the host IR `Synchronize` would call
`c10::synchronize()` on the cuda stream, which makes the CPU blocks
until stream completion. With this patch, we synchronize the current
stream with a given stream through a `cudaEvent` and the API
`cudaStreamWaitEvent`.
  • Loading branch information
samnordmann authored Dec 23, 2024
1 parent 410e48f commit cd2b3eb
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 2 deletions.
17 changes: 15 additions & 2 deletions csrc/host_ir/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ HostIrEvaluator::HostIrEvaluator(
HostIrEvaluatorParams params)
: container_(std::move(container)),
communicator_(communicator),
params_(params) {
params_(params),
my_device_index_(communicator_ ? communicator_->deviceId() : 0) {
const DeviceIdxType device_index =
(communicator_ != nullptr && communicator_->is_available())
? communicator_->deviceId()
Expand Down Expand Up @@ -274,7 +275,19 @@ void HostIrEvaluator::handle(SetCurrentStream* set_current_stream) {
}

void HostIrEvaluator::handle(Synchronize* synchronize) {
getCUDAStream(synchronize->stream()).synchronize();
cudaStream_t current_stream =
c10::cuda::getCurrentCUDAStream(
static_cast<c10::DeviceIndex>(my_device_index_))
.stream();
cudaStream_t stream_to_sync = getCUDAStream(synchronize->stream()).stream();

cudaEvent_t event = {};
NVFUSER_CUDA_RT_SAFE_CALL(
cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(event, stream_to_sync));
NVFUSER_CUDA_RT_SAFE_CALL(
cudaStreamWaitEvent(current_stream, event, cudaEventWaitDefault));
NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(event));
}

void HostIrEvaluator::handle(PostOnStream* post_ir) {
Expand Down
1 change: 1 addition & 0 deletions csrc/host_ir/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ class HostIrEvaluator final : public OptOutDispatch {
using StreamKey = std::variant<int64_t, Stream*>;
std::unordered_map<StreamKey, c10::cuda::CUDAStream> streams_;
std::unordered_map<Expr*, c10::intrusive_ptr<c10d::Work>> works_;
const int64_t my_device_index_;
};

} // namespace hir
Expand Down
3 changes: 3 additions & 0 deletions csrc/multidevice/communicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <cuda_utils.h>
#include <multidevice/communicator.h>
#include <options.h>

Expand Down Expand Up @@ -196,6 +197,8 @@ Communicator::Communicator(
return;
}

NVFUSER_CUDA_RT_SAFE_CALL(cudaSetDevice(local_rank_));

#ifdef NVFUSER_DISTRIBUTED
c10d::TCPStoreOptions store_opts;
{
Expand Down

0 comments on commit cd2b3eb

Please sign in to comment.