-
Notifications
You must be signed in to change notification settings - Fork 55
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Host IR: make stream synchronization non blocking #3608
Merged
samnordmann
merged 4 commits into
NVIDIA:main
from
samnordmann:host_irs/non_blocking_stream_synchronize
Dec 23, 2024
Merged
Host IR: make stream synchronization non blocking #3608
samnordmann
merged 4 commits into
NVIDIA:main
from
samnordmann:host_irs/non_blocking_stream_synchronize
Dec 23, 2024
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
samnordmann
commented
Dec 18, 2024
@@ -196,6 +197,8 @@ Communicator::Communicator( | |||
return; | |||
} | |||
|
|||
NVFUSER_CUDA_RT_SAFE_CALL(cudaSetDevice(local_rank_)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's been a long time I suspected this was going to be required at some point. Anyway, this is a recommended (if not required) practice. Without it, cudaEventRecord
throws cudaErrorInvalidResourceHandle
in a multi-GPU scenario.
Merged
2 tasks
!test |
!test |
!test |
wujingyue
approved these changes
Dec 19, 2024
Closed
2 tasks
samnordmann
added a commit
that referenced
this pull request
Jan 13, 2025
…lap: AG+GEMM layout (#3606) Stacked on top of - [x] #3608 - [x] #3605 # What Lower a MatmulOp sharded on the first inner axis into a pipelined AG+GEMM algorithm achieving fine grained overlap. We introduce a new parallel type `Stream` to account for this scheduling. More precisely, this patch enables lowering the fusion: ``` TensorView* a = makeContigTensor(4); //[S, DIDx(D), M/(S*d), K] TensorView* b = makeContigTensor(2); //[K, N] TensorView* c = matmul(a, b); //[S, D, M/(S*D), N] fusion->addInput(a); fusion->addInput(b); fusion->addOutput(c); auto mesh = DeviceMesh::createForNumDevices(D); a->setDeviceMesh(mesh); b->setDeviceMesh(mesh); c->setDeviceMesh(mesh); a->axis(1)->parallelize(ParallelType::DIDx); c->axis(0)->parallelize(ParallelType::Stream); ``` to the Host Ir program (obtained from dump, using `NVFUSER_DUMP=host_ir`) ``` %HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[iS4{i5}, iS5{i6}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T2_g_float[iStream6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})) : GetCurrentStream into Stream 0 T3_g_float[iS11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iS11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i0 * i2 ) * i3 ) * i4 ), zero_init=false, resets_to_zero=fals e) T2_g_float[iStream6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T2_g_float[iStream6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i0 * i2 ) * i3 ) * i6 ), zero_init=fals e, resets_to_zero=false) FOR i104 in iS0{i0}: SetCurrentStream to Stream ( i104 % numberOfStreams ) T4_l_float[ideviceIdx.x15{i2}, iS16{i3}, iS17{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = select( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = i104 ) T5_l_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = select( T3_g_float[iS11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS11{i0}, index = i104 ) Communication 46 (type=Allgather, team=(0 1 2 3 4 5 6 7), input=T4_l_float[ideviceIdx.x15{i2}, iS16{i3}, iS17{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T5_l_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})) Wait Communication 46 T6_l_float[iS21{i2}, iS22{i3}, iS23{i6}] (DeviceMesh{0 1 2 3 4 5 6 7}) = select( T2_g_float[iStream6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStream6{i0}, index = i104 ) T6_l_float[iS21{i2}, iS22{i3}, iS23{i6}] (DeviceMesh{0 1 2 3 4 5 6 7}) = matmul(T5_l_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[iS4{i5}, iS5{i6}] (DeviceMesh{0 1 2 3 4 5 6 7})) SetCurrentStream to Stream 0 Synchronize Stream ( i104 % numberOfStreams ) } // %HostIrContainer ``` The nsight profile shows that we do achieve overlap, in a way that is comparable to the Aten overlap experiments 
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
What
Make stream synchronization non-blocking from the CPU point of view
Why
Needed for achieving overlap in
before this patch:


after this patch
How
Before this patch, the host IR
Synchronize
would callc10::synchronize()
on the cuda stream, which makes the CPU blocks until stream completion. With this patch, we synchronize the current stream with a given stream through acudaEvent
and the APIcudaStreamWaitEvent
.