
Commit 7dec14a

lw authored and facebook-github-bot committed
Avoid defining RpcCUDAFuture subclass in TensorPipe agent (pytorch#56513)
Summary: Pull Request resolved: pytorch#56513

The RpcCUDAFuture class existed solely to support extracting DataPtrs from a Message class. However, this can be done more simply by using a vanilla CUDAFuture and extracting those DataPtrs before marking it complete, then passing them to markCompleted. This makes it possible to make the DataPtr extraction logic of CUDAFuture private again.

ghstack-source-id: 127035771

Test Plan: Unit tests

Reviewed By: mrshenli

Differential Revision: D27861064

fbshipit-source-id: b0b4df2cab7be6b4b16d5cfc888483c18fbce60e
1 parent 5ddc269 commit 7dec14a

3 files changed: 11 additions, 27 deletions

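As a rough sketch of the pattern this commit adopts (not part of the diff itself; the collectDataPtrs helper and its free-function form are made up here for illustration, while the Message, CUDAFuture, and markCompleted calls mirror the hunks below), the agent now gathers the DataPtrs backing a Message's tensors itself and hands them to a plain CUDAFuture, instead of relying on an RpcCUDAFuture subclass that overrode extractDataPtrs:

// Sketch only: collect the DataPtrs of a Message's tensors, as the agent
// now does inline in markFutureAsComplete (see the tensorpipe_agent.cpp
// hunk below). The helper name is hypothetical, not from the PR.
std::vector<std::reference_wrapper<const at::DataPtr>> collectDataPtrs(
    const torch::distributed::rpc::Message& message) {
  std::vector<std::reference_wrapper<const at::DataPtr>> data_ptrs;
  data_ptrs.reserve(message.tensors().size());
  for (const auto& tensor : message.tensors()) {
    data_ptrs.emplace_back(tensor.storage().data_ptr());
  }
  return data_ptrs;
}

// Usage sketch: a vanilla CUDAFuture completed with explicitly supplied
// DataPtrs, replacing the removed RpcCUDAFuture subclass.
auto future = std::make_shared<at::cuda::CUDAFuture>(at::AnyClassType::get());
auto data_ptrs = collectDataPtrs(message);
future->markCompleted(
    c10::IValue(c10::make_intrusive<torch::distributed::rpc::Message>(
        std::move(message))),
    std::move(data_ptrs));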

aten/src/ATen/cuda/CUDAFuture.h

Lines changed: 3 additions & 3 deletions
@@ -22,7 +22,7 @@
 namespace at {
 namespace cuda {
 
-struct TORCH_CUDA_CPP_API CUDAFuture : at::ivalue::Future {
+struct TORCH_CUDA_CPP_API CUDAFuture final : at::ivalue::Future {
  public:
   CUDAFuture(at::TypePtr type) : at::ivalue::Future(std::move(type)) {
     // Use current device to initialize currentDevice_. This is necessary
@@ -129,7 +129,8 @@ struct TORCH_CUDA_CPP_API CUDAFuture : at::ivalue::Future {
     }
   }
 
-  virtual std::vector<std::reference_wrapper<const at::DataPtr>> extractDataPtrs(
+ private:
+  std::vector<std::reference_wrapper<const at::DataPtr>> extractDataPtrs(
       const at::IValue& value) {
     at::IValue::HashAliasedIValues sub_values;
     // Prefer getSubValues() over visit() as the latter is a silent no-op for
@@ -145,7 +146,6 @@ struct TORCH_CUDA_CPP_API CUDAFuture : at::ivalue::Future {
     return data_ptrs;
   }
 
- private:
   // The device that was current when markCompleted was called, which we'll
   // restore when invoking callbacks.
   c10::DeviceIndex currentDevice_;
torch/csrc/distributed/rpc/tensorpipe_agent.cpp

Lines changed: 6 additions & 1 deletion
@@ -1321,8 +1321,13 @@ void TensorPipeAgent::markFutureAsComplete(
                    message{std::move(message)},
                    ctx{std::move(ctx)}]() mutable {
     MultiStreamGuard guard(ctx);
+    std::vector<std::reference_wrapper<const at::DataPtr>> data_ptrs;
+    for (const auto& tensor : message.tensors()) {
+      data_ptrs.emplace_back(tensor.storage().data_ptr());
+    }
     atomicFuture->jitFuture->markCompleted(
-        IValue(c10::make_intrusive<Message>(std::move(message))));
+        IValue(c10::make_intrusive<Message>(std::move(message))),
+        std::move(data_ptrs));
     // The future's callbacks may schedule further RPCs, increasing the count.
     // Thus we must decrease it after completing the future, otherwise it may
     // briefly dip to zero and trick join into thinking all work is done.

torch/csrc/distributed/rpc/tensorpipe_agent.h

Lines changed: 2 additions & 23 deletions
@@ -276,28 +276,6 @@ class TensorPipeAgent : public RpcAgent {
       const std::string& remoteName,
       const Message& message) const;
 
-#ifdef USE_CUDA_NOT_ROCM
-  // An RPC-specific CUDAFuture subclass. It overrides the extractDataPtrs
-  // function to handle and only handle RPC Messages.
-  struct TORCH_CUDA_CPP_API RpcCUDAFuture final : at::cuda::CUDAFuture {
-   public:
-    using at::cuda::CUDAFuture::CUDAFuture;
-
-   protected:
-    std::vector<std::reference_wrapper<const at::DataPtr>> extractDataPtrs(
-        const at::IValue& value) override {
-      const auto message = value.toCustomClass<Message>();
-      TORCH_INTERNAL_ASSERT(
-          message, "Passed a non-Message type to RpcCUDAFuture");
-      std::vector<std::reference_wrapper<const at::DataPtr>> data_ptrs;
-      for (const auto& tensor : message->tensors()) {
-        data_ptrs.emplace_back(tensor.storage().data_ptr());
-      }
-      return data_ptrs;
-    }
-  };
-#endif
-
   // When a request+response completes, we need to mark the future message as
   // complete. However, if its timeout has already expired, it already has an
   // error set. There is no atomic "test-and-set" way to mark a future complete
@@ -308,7 +286,8 @@ class TensorPipeAgent : public RpcAgent {
     AtomicJitFuture(bool noCuda = true) {
 #ifdef USE_CUDA_NOT_ROCM
       if (!noCuda) {
-        jitFuture = std::make_shared<RpcCUDAFuture>(at::AnyClassType::get());
+        jitFuture =
+            std::make_shared<at::cuda::CUDAFuture>(at::AnyClassType::get());
       } else {
 #else
     {
