Skip to content

Commit a697d21

Browse files
committed
fix: run RtpProcessGroup collectives and p2p ops on the current torch CUDA stream
1 parent fcf20d2 commit a697d21

File tree

3 files changed

+276
-93
lines changed

3 files changed

+276
-93
lines changed

rtp_llm/models_py/bindings/common/RtpProcessGroup.cc

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ RtpProcessGroup::RtpProcessGroup(RtpProcessGroupType type) {
3838
}
3939

4040
void RtpProcessGroup::broadcast(std::vector<torch::Tensor>& input, int rootRank) {
41-
std::vector<BufferPtr> buffers;
41+
ScopedCUDAStreamContext stream_ctx(device_);
42+
std::vector<BufferPtr> buffers;
4243
for (auto& tensor : input) {
4344
buffers.push_back(torchTensor2Buffer(tensor));
4445
}
@@ -67,32 +68,36 @@ ReduceOp getReduceOp(c10d::ReduceOp reduce_op) {
6768

6869
// Sum-all-reduces the single input tensor across the process group and
// returns the result in a freshly allocated tensor (the input is not
// modified in place).
//
// A ScopedCUDAStreamContext temporarily points device_ at the torch CUDA
// stream that is current on this thread, so the collective is issued on the
// same stream as surrounding torch work; the previous stream is restored
// when the guard leaves scope.
//
// input:   must hold exactly one tensor (checked below).
// returns: a one-element vector containing the reduced tensor.
std::vector<torch::Tensor> RtpProcessGroup::all_reduce(std::vector<torch::Tensor>& input) {
    // input.size() is size_t; cast so the argument matches the %d specifier
    // (passing size_t through %d is undefined behavior on LP64 targets).
    RTP_LLM_CHECK_WITH_INFO(
        input.size() == 1, "AllReduce input size must be 1 , but got %d", static_cast<int>(input.size()));
    ScopedCUDAStreamContext stream_ctx(device_);
    auto     tensor      = input[0];
    auto     dest_tensor = torch::empty_like(tensor);
    ReduceOp reduce_op   = ReduceOp::Sum;
    device_->allReduce({torchTensor2Buffer(tensor), reduce_op, false, mode_, torchTensor2Buffer(dest_tensor)});
    check_cuda_error();
    return {dest_tensor};
}
7779

7880
// Sends the single input tensor to rank dst_rank via the device's batched
// send/recv path, on the torch CUDA stream current for this thread (the
// scoped guard redirects device_'s stream and restores it on scope exit).
//
// input:    must hold exactly one tensor (checked below).
// dst_rank: destination rank within the process group.
void RtpProcessGroup::send(std::vector<torch::Tensor>& input, int dst_rank) {
    // Cast size_t -> int to match the %d specifier (avoids printf-family UB).
    RTP_LLM_CHECK_WITH_INFO(
        input.size() == 1, "Send input size must be 1 , but got %d", static_cast<int>(input.size()));
    ScopedCUDAStreamContext stream_ctx(device_);
    BatchSendRecvParams params;
    params.p2p_params.push_back({SendRecvType::kSend, torchTensor2Buffer(input[0]), dst_rank});
    device_->batchSendRecv(params, mode_);
    check_cuda_error();
}
8588

8689
// Receives into the single provided tensor from rank src_rank via the
// device's batched send/recv path, on the torch CUDA stream current for this
// thread (the scoped guard redirects device_'s stream and restores it on
// scope exit). The received data is written into input[0]'s storage.
//
// input:    must hold exactly one (pre-allocated) tensor (checked below).
// src_rank: source rank within the process group.
void RtpProcessGroup::recv(std::vector<torch::Tensor>& input, int src_rank) {
    // Fixed copy-paste: the message previously said "Send" inside recv.
    // Cast size_t -> int to match the %d specifier (avoids printf-family UB).
    RTP_LLM_CHECK_WITH_INFO(
        input.size() == 1, "Recv input size must be 1 , but got %d", static_cast<int>(input.size()));
    ScopedCUDAStreamContext stream_ctx(device_);
    BatchSendRecvParams params;
    params.p2p_params.push_back({SendRecvType::kRecv, torchTensor2Buffer(input[0]), src_rank});
    device_->batchSendRecv(params, mode_);
    check_cuda_error();
}
9397

9498
std::vector<torch::Tensor> RtpProcessGroup::all_gather(std::vector<torch::Tensor>& input) {
9599
RTP_LLM_CHECK_WITH_INFO(input.size() == 1, "AllGather input size must be 1 , but got %d", input.size());
100+
ScopedCUDAStreamContext stream_ctx(device_);
96101
auto output = torch::empty({input[0].size(0), input[0].size(1) * world_size_}, input[0].options());
97102
device_->allGather({{torchTensor2Buffer(output)}, mode_, {torchTensor2Buffer(input[0])}, false});
98103
check_cuda_error();

rtp_llm/models_py/bindings/common/RtpProcessGroup.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <torch/csrc/distributed/c10d/Types.hpp>
44
#include <torch/csrc/distributed/c10d/Backend.hpp>
5+
#include <ATen/cuda/CUDAContext.h>
56
#include <vector>
67
#include "rtp_llm/cpp/core/Types.h" // for ParallelMode
78
#include "rtp_llm/models_py/bindings/common/Torch_ext.h" // for DefaultDeviceType
@@ -15,6 +16,28 @@ enum class RtpProcessGroupType {
1516
CP_GROUP = 3,
1617
};
1718

19+
class ScopedCUDAStreamContext {
20+
public:
21+
explicit ScopedCUDAStreamContext(DefaultDeviceType* device): device_(device) {
22+
original_stream_ = device_->getStream();
23+
current_stream_ = at::cuda::getCurrentCUDAStream(at::cuda::current_device()).stream();
24+
device_->setStream(current_stream_);
25+
}
26+
27+
~ScopedCUDAStreamContext() {
28+
device_->setStream(original_stream_);
29+
}
30+
ScopedCUDAStreamContext(const ScopedCUDAStreamContext&) = delete;
31+
ScopedCUDAStreamContext& operator=(const ScopedCUDAStreamContext&) = delete;
32+
ScopedCUDAStreamContext(ScopedCUDAStreamContext&&) = delete;
33+
ScopedCUDAStreamContext& operator=(ScopedCUDAStreamContext&&) = delete;
34+
35+
private:
36+
DefaultDeviceType* device_;
37+
cudaStream_t original_stream_;
38+
cudaStream_t current_stream_;
39+
};
40+
1841
class RtpProcessGroup {
1942
public:
2043
RtpProcessGroup(RtpProcessGroupType type);

0 commit comments

Comments
 (0)