Commit 0911ee9

lw authored and facebook-github-bot committed
Split CUDAFuture into a .h and a .cpp file (pytorch#56514)
Summary:
Pull Request resolved: pytorch#56514

rohan-varma mentioned that having CUDAFuture entirely defined in a header meant having to rebuild a whole lot of things whenever it changed. In fact there's no reason not to use a .cpp file, so here I do so.

ghstack-source-id: 127035765

Test Plan: Unit tests

Reviewed By: rohan-varma, mrshenli

Differential Revision: D27861071

fbshipit-source-id: c209d54af9b52d3ad781db1b61f6fca02c637f32
1 parent 7dec14a commit 0911ee9
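
For context, the change applies the standard C++ header/source split: declarations stay in the .h file, definitions move to a .cpp file, so editing the implementation recompiles a single translation unit instead of every file that includes the header. A minimal sketch of that pattern, using a hypothetical Counter class that is not part of this commit:

// counter.h (hypothetical) -- declarations only; changing counter.cpp no
// longer forces every includer of this header to recompile.
#pragma once

class Counter {
 public:
  explicit Counter(int start);
  int next();

 private:
  int value_;
};

// counter.cpp (hypothetical) -- definitions live in exactly one translation unit.
#include "counter.h"

Counter::Counter(int start) : value_(start) {}

int Counter::next() {
  return value_++;
}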

File tree

BUILD.bazel
aten/src/ATen/cuda/CUDAFuture.cpp
aten/src/ATen/cuda/CUDAFuture.h

3 files changed: +156 −120 lines changed


BUILD.bazel

Lines changed: 1 addition & 0 deletions

@@ -341,6 +341,7 @@ filegroup(
     "aten/src/ATen/cuda/CUDABlas.cpp",
     "aten/src/ATen/cuda/CUDASolver.cpp",
     "aten/src/ATen/cuda/CUDAContext.cpp",
+    "aten/src/ATen/cuda/CUDAFuture.cpp",
     "aten/src/ATen/cuda/CUDAGeneratorImpl.cpp",
     "aten/src/ATen/cuda/CUDAGraph.cpp",
     "aten/src/ATen/cuda/CuSparseHandlePool.cpp",

aten/src/ATen/cuda/CUDAFuture.cpp

Lines changed: 150 additions & 0 deletions (new file)

#include <ATen/cuda/CUDAFuture.h>

#include <functional>
#include <memory>
#include <mutex>
#include <utility>
#include <vector>

#include <ATen/core/ivalue.h>
#include <ATen/core/ivalue_inl.h>
#include <ATen/core/jit_type.h>
#include <ATen/cuda/CUDAEvent.h>
#include <ATen/cuda/CUDAMultiStreamGuard.h>
#include <c10/core/Allocator.h>
#include <c10/core/Device.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/macros/Export.h>
#include <c10/util/intrusive_ptr.h>

namespace at {
namespace cuda {

namespace {

std::vector<std::reference_wrapper<const at::DataPtr>> extractDataPtrs(
    const at::IValue& value) {
  at::IValue::HashAliasedIValues sub_values;
  // Prefer getSubValues() over visit() as the latter is a silent no-op for
  // some unsupported types, whereas the former at least fails loudly.
  value.getSubValues(sub_values);

  std::vector<std::reference_wrapper<const at::DataPtr>> data_ptrs;
  for (const at::IValue& sub_value : sub_values) {
    if (sub_value.isTensor()) {
      data_ptrs.emplace_back(sub_value.toTensor().storage().data_ptr());
    }
  }
  return data_ptrs;
}

} // namespace

CUDAFuture::CUDAFuture(at::TypePtr type) : at::ivalue::Future(std::move(type)) {
  // Use current device to initialize currentDevice_. This is necessary
  // because preMarkCompletedHook won't be called when the Future contains
  // an error. Uninitialized currentDevice_ could lead to crash when used
  // in CUDAGuard.
  currentDevice_ = c10::cuda::current_device();
}

c10::intrusive_ptr<ivalue::Future> CUDAFuture::createInstance(
    at::TypePtr type) {
  return c10::make_intrusive<CUDAFuture>(std::move(type));
}

/**
 * The dataPtrs field contains storage pointers of all tensors in the IValue.
 * This method records CUDAEvents on participating devices and uses those
 * CUDAEvents to synchronize streams when calling postWaitHook().
 * If dataPtrs does not have a value, this method will try to inspect the
 * given IValue by walking through all subvalues and extracting data pointers
 * from CUDA tensors.
 */
void CUDAFuture::preMarkCompletedHook(
    const at::IValue& value,
    c10::optional<std::vector<std::reference_wrapper<const at::DataPtr>>>
        dataPtrs) {
  // Start by performing all steps that can throw, before setting any field.
  std::vector<std::reference_wrapper<const at::DataPtr>> actualDataPtrs =
      dataPtrs.has_value() ? std::move(*dataPtrs) : extractDataPtrs(value);

  currentDevice_ = c10::cuda::current_device();

  // Extract them once and cache them for later uses.
  dataPtrs_ = std::move(actualDataPtrs);

  std::vector<bool> isCudaDeviceUsed(c10::cuda::device_count(), false);
  for (const at::DataPtr& data_ptr : dataPtrs_) {
    if (data_ptr.device().is_cuda()) {
      isCudaDeviceUsed[data_ptr.device().index()] = true;
    }
  }

  for (c10::DeviceIndex idx = 0; idx < isCudaDeviceUsed.size(); idx++) {
    if (isCudaDeviceUsed[idx]) {
      at::cuda::CUDAEvent cudaEvent;
      cudaEvent.record(at::cuda::getCurrentCUDAStream(idx));
      cudaEvents_.push_back(std::move(cudaEvent));
    }
  }
}

std::function<void(void)> CUDAFuture::wrapCallback(
    std::function<void(void)> callback) {
  return [this, callback{std::move(callback)}]() {
    // We'd love to get a stream for all devices, even those that are not used
    // by the value, because the callback could use those other devices, but
    // unfortunately this could cause a deadlock with NCCL. See
    // https://github.com/pytorch/pytorch/pull/48500#issuecomment-735395414
    // In general, if some devices haven't been used yet, by getting a stream
    // for them we'd initialize them, and in addition to causing NCCL to
    // misbehaving this also ends up using memory on those devices, which the
    // user might not want.
    std::vector<at::cuda::CUDAStream> streams;
    for (at::cuda::CUDAEvent& cudaEvent : cudaEvents_) {
      c10::DeviceIndex idx = cudaEvent.device_index();
      // FIXME Should we find a way to allow to change the priority of
      // streams?
      at::cuda::CUDAStream stream =
          at::cuda::getStreamFromPool(/*isHighPriority=*/false, idx);
      cudaEvent.block(stream);
      streams.push_back(stream);
    }

    // Use the dedicated callback stream to run callback.
    at::cuda::CUDAMultiStreamGuard streamGuard(streams);

    // Do not free the underlying data storage of value_ before its
    // usage on the stream finishes.
    for (const at::DataPtr& data_ptr : dataPtrs_) {
      if (data_ptr.device().is_cuda()) {
        c10::cuda::CUDACachingAllocator::recordStream(
            data_ptr,
            at::cuda::getCurrentCUDAStream(data_ptr.device().index()));
      }
    }

    c10::cuda::CUDAGuard deviceGuard(currentDevice_);

    callback();
  };
}

void CUDAFuture::postWaitHook(const at::IValue& value) {
  for (at::cuda::CUDAEvent& cudaEvent : cudaEvents_) {
    cudaEvent.block(at::cuda::getCurrentCUDAStream(cudaEvent.device_index()));
  }

  for (const at::DataPtr& data_ptr : dataPtrs_) {
    if (data_ptr.device().is_cuda()) {
      c10::cuda::CUDACachingAllocator::recordStream(
          data_ptr, at::cuda::getCurrentCUDAStream(data_ptr.device().index()));
    }
  }
}

} // namespace cuda
} // namespace at
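
The moved code synchronizes on-device rather than on the CPU: it records a CUDAEvent on each stream that produced the value, then makes consumer streams block on those events before the callback or wait continues. A minimal standalone sketch of that record/block pattern, assuming a CUDA-enabled libtorch build and device 0 (the function name syncExample and the stream names are illustrative only, not part of this commit):

#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>

void syncExample() {
  // Producer stream: where the work that produced the value was enqueued.
  at::cuda::CUDAStream producer = at::cuda::getCurrentCUDAStream(/*device=*/0);

  // Record an event after the producer's pending work.
  at::cuda::CUDAEvent event;
  event.record(producer);

  // Consumer stream: block it on the event so later kernels wait for the
  // producer's work, without stalling the CPU (this mirrors what
  // wrapCallback() and postWaitHook() do above).
  at::cuda::CUDAStream consumer =
      at::cuda::getStreamFromPool(/*isHighPriority=*/false, /*device=*/0);
  event.block(consumer);
}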

aten/src/ATen/cuda/CUDAFuture.h

Lines changed: 5 additions & 120 deletions

@@ -1,21 +1,12 @@
 #pragma once
 
 #include <functional>
-#include <memory>
-#include <mutex>
-#include <utility>
 #include <vector>
 
 #include <ATen/core/ivalue.h>
 #include <ATen/core/ivalue_inl.h>
-#include <ATen/core/jit_type.h>
 #include <ATen/cuda/CUDAEvent.h>
-#include <ATen/cuda/CUDAMultiStreamGuard.h>
-#include <c10/core/Allocator.h>
 #include <c10/core/Device.h>
-#include <c10/cuda/CUDACachingAllocator.h>
-#include <c10/cuda/CUDAFunctions.h>
-#include <c10/cuda/CUDAStream.h>
 #include <c10/macros/Export.h>
 #include <c10/util/intrusive_ptr.h>
 
@@ -24,128 +15,22 @@ namespace cuda {
 
 struct TORCH_CUDA_CPP_API CUDAFuture final : at::ivalue::Future {
  public:
-  CUDAFuture(at::TypePtr type) : at::ivalue::Future(std::move(type)) {
-    // Use current device to initialize currentDevice_. This is necessary
-    // because preMarkCompletedHook won't be called when the Future contains
-    // an error. Uninitialized currentDevice_ could lead to crash when used
-    // in CUDAGuard.
-    currentDevice_ = c10::cuda::current_device();
-  }
+  explicit CUDAFuture(at::TypePtr type);
 
-  c10::intrusive_ptr<Future> createInstance(at::TypePtr type) override {
-    return c10::make_intrusive<CUDAFuture>(std::move(type));
-  }
+  c10::intrusive_ptr<Future> createInstance(at::TypePtr type) override;
 
  protected:
-  /**
-   * The dataPtrs field contains storage pointers of all tensors in the IValue.
-   * This method records CUDAEvents on participating devices and uses those
-   * CUDAEvents to synchronize streams when calling postWaitHook().
-   * If dataPtrs does not have a value, this method will try to inspect the
-   * given IValue by walking through all subvalues and extracting data pointers
-   * from CUDA tensors.
-   */
   void preMarkCompletedHook(
       const at::IValue& value,
       c10::optional<std::vector<std::reference_wrapper<const at::DataPtr>>>
-          dataPtrs) override {
-    // Start by performing all steps that can throw, before setting any field.
-    std::vector<std::reference_wrapper<const at::DataPtr>> actualDataPtrs =
-        dataPtrs.has_value() ? std::move(*dataPtrs) : extractDataPtrs(value);
-
-    currentDevice_ = c10::cuda::current_device();
-
-    // Extract them once and cache them for later uses.
-    dataPtrs_ = std::move(actualDataPtrs);
-
-    std::vector<bool> isCudaDeviceUsed(c10::cuda::device_count(), false);
-    for (const at::DataPtr& data_ptr : dataPtrs_) {
-      if (data_ptr.device().is_cuda()) {
-        isCudaDeviceUsed[data_ptr.device().index()] = true;
-      }
-    }
-
-    for (c10::DeviceIndex idx = 0; idx < isCudaDeviceUsed.size(); idx++) {
-      if (isCudaDeviceUsed[idx]) {
-        at::cuda::CUDAEvent cudaEvent;
-        cudaEvent.record(at::cuda::getCurrentCUDAStream(idx));
-        cudaEvents_.push_back(std::move(cudaEvent));
-      }
-    }
-  }
+          dataPtrs) override;
 
   std::function<void(void)> wrapCallback(
-      std::function<void(void)> callback) override {
-    return [this, callback{std::move(callback)}]() {
-      // We'd love to get a stream for all devices, even those that are not used
-      // by the value, because the callback could use those other devices, but
-      // unfortunately this could cause a deadlock with NCCL. See
-      // https://github.com/pytorch/pytorch/pull/48500#issuecomment-735395414
-      // In general, if some devices haven't been used yet, by getting a stream
-      // for them we'd initialize them, and in addition to causing NCCL to
-      // misbehaving this also ends up using memory on those devices, which the
-      // user might not want.
-      std::vector<at::cuda::CUDAStream> streams;
-      for (at::cuda::CUDAEvent& cudaEvent : cudaEvents_) {
-        c10::DeviceIndex idx = cudaEvent.device_index();
-        // FIXME Should we find a way to allow to change the priority of
-        // streams?
-        at::cuda::CUDAStream stream =
-            at::cuda::getStreamFromPool(/*isHighPriority=*/false, idx);
-        cudaEvent.block(stream);
-        streams.push_back(stream);
-      }
-
-      // Use the dedicated callback stream to run callback.
-      at::cuda::CUDAMultiStreamGuard streamGuard(streams);
-
-      // Do not free the underlying data storage of value_ before its
-      // usage on the stream finishes.
-      for (const at::DataPtr& data_ptr : dataPtrs_) {
-        if (data_ptr.device().is_cuda()) {
-          c10::cuda::CUDACachingAllocator::recordStream(
-              data_ptr,
-              at::cuda::getCurrentCUDAStream(data_ptr.device().index()));
-        }
-      }
+      std::function<void(void)> callback) override;
 
-      c10::cuda::CUDAGuard deviceGuard(currentDevice_);
-
-      callback();
-    };
-  }
-
-  void postWaitHook(const at::IValue& value) override {
-    for (at::cuda::CUDAEvent& cudaEvent : cudaEvents_) {
-      cudaEvent.block(at::cuda::getCurrentCUDAStream(cudaEvent.device_index()));
-    }
-
-    for (const at::DataPtr& data_ptr : dataPtrs_) {
-      if (data_ptr.device().is_cuda()) {
-        c10::cuda::CUDACachingAllocator::recordStream(
-            data_ptr,
-            at::cuda::getCurrentCUDAStream(data_ptr.device().index()));
-      }
-    }
-  }
+  void postWaitHook(const at::IValue& value) override;
 
  private:
-  std::vector<std::reference_wrapper<const at::DataPtr>> extractDataPtrs(
-      const at::IValue& value) {
-    at::IValue::HashAliasedIValues sub_values;
-    // Prefer getSubValues() over visit() as the latter is a silent no-op for
-    // some unsupported types, whereas the former at least fails loudly.
-    value.getSubValues(sub_values);
-
-    std::vector<std::reference_wrapper<const at::DataPtr>> data_ptrs;
-    for (const at::IValue& sub_value : sub_values) {
-      if (sub_value.isTensor()) {
-        data_ptrs.emplace_back(sub_value.toTensor().storage().data_ptr());
-      }
-    }
-    return data_ptrs;
-  }
-
   // The device that was current when markCompleted was called, which we'll
   // restore when invoking callbacks.
   c10::DeviceIndex currentDevice_;
