@@ -38,7 +38,8 @@ RtpProcessGroup::RtpProcessGroup(RtpProcessGroupType type) {
3838}
3939
4040void RtpProcessGroup::broadcast (std::vector<torch::Tensor>& input, int rootRank) {
41- std::vector<BufferPtr> buffers;
41+ ScopedCUDAStreamContext stream_ctx (device_);
42+ std::vector<BufferPtr> buffers;
4243 for (auto & tensor : input) {
4344 buffers.push_back (torchTensor2Buffer (tensor));
4445 }
@@ -67,32 +68,36 @@ ReduceOp getReduceOp(c10d::ReduceOp reduce_op) {
6768
6869std::vector<torch::Tensor> RtpProcessGroup::all_reduce (std::vector<torch::Tensor>& input) {
6970 RTP_LLM_CHECK_WITH_INFO (input.size () == 1 , " AllReduce input size must be 1 , but got %d" , input.size ());
70- auto tensor = input[0 ];
71- auto dest_tensor = torch::empty_like (tensor);
72- ReduceOp reduce_op = ReduceOp::Sum;
71+ ScopedCUDAStreamContext stream_ctx (device_);
72+ auto tensor = input[0 ];
73+ auto dest_tensor = torch::empty_like (tensor);
74+ ReduceOp reduce_op = ReduceOp::Sum;
7375 device_->allReduce ({torchTensor2Buffer (tensor), reduce_op, false , mode_, torchTensor2Buffer (dest_tensor)});
7476 check_cuda_error ();
7577 return {dest_tensor};
7678}
7779
7880void RtpProcessGroup::send (std::vector<torch::Tensor>& input, int dst_rank) {
7981 RTP_LLM_CHECK_WITH_INFO (input.size () == 1 , " Send input size must be 1 , but got %d" , input.size ());
80- BatchSendRecvParams params;
82+ ScopedCUDAStreamContext stream_ctx (device_);
83+ BatchSendRecvParams params;
8184 params.p2p_params .push_back ({SendRecvType::kSend , torchTensor2Buffer (input[0 ]), dst_rank});
8285 device_->batchSendRecv (params, mode_);
8386 check_cuda_error ();
8487}
8588
8689void RtpProcessGroup::recv (std::vector<torch::Tensor>& input, int src_rank) {
8790 RTP_LLM_CHECK_WITH_INFO (input.size () == 1 , " Send input size must be 1 , but got %d" , input.size ());
88- BatchSendRecvParams params;
91+ ScopedCUDAStreamContext stream_ctx (device_);
92+ BatchSendRecvParams params;
8993 params.p2p_params .push_back ({SendRecvType::kRecv , torchTensor2Buffer (input[0 ]), src_rank});
9094 device_->batchSendRecv (params, mode_);
9195 check_cuda_error ();
9296}
9397
9498std::vector<torch::Tensor> RtpProcessGroup::all_gather (std::vector<torch::Tensor>& input) {
9599 RTP_LLM_CHECK_WITH_INFO (input.size () == 1 , " AllGather input size must be 1 , but got %d" , input.size ());
100+ ScopedCUDAStreamContext stream_ctx (device_);
96101 auto output = torch::empty ({input[0 ].size (0 ), input[0 ].size (1 ) * world_size_}, input[0 ].options ());
97102 device_->allGather ({{torchTensor2Buffer (output)}, mode_, {torchTensor2Buffer (input[0 ])}, false });
98103 check_cuda_error ();
0 commit comments