@@ -1,21 +1,12 @@
 #pragma once
 
 #include <functional>
-#include <memory>
-#include <mutex>
-#include <utility>
 #include <vector>
 
 #include <ATen/core/ivalue.h>
 #include <ATen/core/ivalue_inl.h>
-#include <ATen/core/jit_type.h>
 #include <ATen/cuda/CUDAEvent.h>
-#include <ATen/cuda/CUDAMultiStreamGuard.h>
-#include <c10/core/Allocator.h>
 #include <c10/core/Device.h>
-#include <c10/cuda/CUDACachingAllocator.h>
-#include <c10/cuda/CUDAFunctions.h>
-#include <c10/cuda/CUDAStream.h>
 #include <c10/macros/Export.h>
 #include <c10/util/intrusive_ptr.h>
 
@@ -24,128 +15,22 @@ namespace cuda {
 
 struct TORCH_CUDA_CPP_API CUDAFuture final : at::ivalue::Future {
  public:
-  CUDAFuture(at::TypePtr type) : at::ivalue::Future(std::move(type)) {
-    // Use current device to initialize currentDevice_. This is necessary
-    // because preMarkCompletedHook won't be called when the Future contains
-    // an error. Uninitialized currentDevice_ could lead to a crash when used
-    // in CUDAGuard.
-    currentDevice_ = c10::cuda::current_device();
-  }
+  explicit CUDAFuture(at::TypePtr type);
 
-  c10::intrusive_ptr<Future> createInstance(at::TypePtr type) override {
-    return c10::make_intrusive<CUDAFuture>(std::move(type));
-  }
+  c10::intrusive_ptr<Future> createInstance(at::TypePtr type) override;
 
  protected:
-  /**
-   * The dataPtrs field contains storage pointers of all tensors in the IValue.
-   * This method records CUDAEvents on participating devices and uses those
-   * CUDAEvents to synchronize streams when calling postWaitHook().
-   * If dataPtrs does not have a value, this method will try to inspect the
-   * given IValue by walking through all subvalues and extracting data pointers
-   * from CUDA tensors.
-   */
   void preMarkCompletedHook(
       const at::IValue& value,
       c10::optional<std::vector<std::reference_wrapper<const at::DataPtr>>>
-          dataPtrs) override {
-    // Start by performing all steps that can throw, before setting any field.
-    std::vector<std::reference_wrapper<const at::DataPtr>> actualDataPtrs =
-        dataPtrs.has_value() ? std::move(*dataPtrs) : extractDataPtrs(value);
-
-    currentDevice_ = c10::cuda::current_device();
-
-    // Extract them once and cache them for later use.
-    dataPtrs_ = std::move(actualDataPtrs);
-
-    std::vector<bool> isCudaDeviceUsed(c10::cuda::device_count(), false);
-    for (const at::DataPtr& data_ptr : dataPtrs_) {
-      if (data_ptr.device().is_cuda()) {
-        isCudaDeviceUsed[data_ptr.device().index()] = true;
-      }
-    }
-
-    for (c10::DeviceIndex idx = 0; idx < isCudaDeviceUsed.size(); idx++) {
-      if (isCudaDeviceUsed[idx]) {
-        at::cuda::CUDAEvent cudaEvent;
-        cudaEvent.record(at::cuda::getCurrentCUDAStream(idx));
-        cudaEvents_.push_back(std::move(cudaEvent));
-      }
-    }
-  }
+          dataPtrs) override;
 
   std::function<void(void)> wrapCallback(
-      std::function<void(void)> callback) override {
-    return [this, callback{std::move(callback)}]() {
-      // We'd love to get a stream for all devices, even those that are not used
-      // by the value, because the callback could use those other devices, but
-      // unfortunately this could cause a deadlock with NCCL. See
-      // https://github.com/pytorch/pytorch/pull/48500#issuecomment-735395414
-      // In general, if some devices haven't been used yet, by getting a stream
-      // for them we'd initialize them, and in addition to causing NCCL to
-      // misbehave, this also ends up using memory on those devices, which the
-      // user might not want.
-      std::vector<at::cuda::CUDAStream> streams;
-      for (at::cuda::CUDAEvent& cudaEvent : cudaEvents_) {
-        c10::DeviceIndex idx = cudaEvent.device_index();
-        // FIXME Should we find a way to allow changing the priority of
-        // streams?
-        at::cuda::CUDAStream stream =
-            at::cuda::getStreamFromPool(/*isHighPriority=*/false, idx);
-        cudaEvent.block(stream);
-        streams.push_back(stream);
-      }
-
-      // Use the dedicated callback stream to run callback.
-      at::cuda::CUDAMultiStreamGuard streamGuard(streams);
-
-      // Do not free the underlying data storage of value_ before its
-      // usage on the stream finishes.
-      for (const at::DataPtr& data_ptr : dataPtrs_) {
-        if (data_ptr.device().is_cuda()) {
-          c10::cuda::CUDACachingAllocator::recordStream(
-              data_ptr,
-              at::cuda::getCurrentCUDAStream(data_ptr.device().index()));
-        }
-      }
+      std::function<void(void)> callback) override;
 
-      c10::cuda::CUDAGuard deviceGuard(currentDevice_);
-
-      callback();
-    };
-  }
-
-  void postWaitHook(const at::IValue& value) override {
-    for (at::cuda::CUDAEvent& cudaEvent : cudaEvents_) {
-      cudaEvent.block(at::cuda::getCurrentCUDAStream(cudaEvent.device_index()));
-    }
-
-    for (const at::DataPtr& data_ptr : dataPtrs_) {
-      if (data_ptr.device().is_cuda()) {
-        c10::cuda::CUDACachingAllocator::recordStream(
-            data_ptr,
-            at::cuda::getCurrentCUDAStream(data_ptr.device().index()));
-      }
-    }
-  }
+  void postWaitHook(const at::IValue& value) override;
 
  private:
-  std::vector<std::reference_wrapper<const at::DataPtr>> extractDataPtrs(
-      const at::IValue& value) {
-    at::IValue::HashAliasedIValues sub_values;
-    // Prefer getSubValues() over visit() as the latter is a silent no-op for
-    // some unsupported types, whereas the former at least fails loudly.
-    value.getSubValues(sub_values);
-
-    std::vector<std::reference_wrapper<const at::DataPtr>> data_ptrs;
-    for (const at::IValue& sub_value : sub_values) {
-      if (sub_value.isTensor()) {
-        data_ptrs.emplace_back(sub_value.toTensor().storage().data_ptr());
-      }
-    }
-    return data_ptrs;
-  }
-
   // The device that was current when markCompleted was called, which we'll
   // restore when invoking callbacks.
   c10::DeviceIndex currentDevice_;
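
To make the refactor concrete, here is a minimal sketch of the matching out-of-line definitions. The companion file name (CUDAFuture.cpp) and its include list are assumptions; the constructor and createInstance bodies are the ones removed from the header above, and the remaining hooks would move over the same way.

// Sketch of the assumed companion implementation file, CUDAFuture.cpp.
#include <ATen/cuda/CUDAFuture.h>

#include <utility>

#include <c10/cuda/CUDAFunctions.h>

namespace at {
namespace cuda {

CUDAFuture::CUDAFuture(at::TypePtr type)
    : at::ivalue::Future(std::move(type)) {
  // Use current device to initialize currentDevice_. This is necessary
  // because preMarkCompletedHook won't be called when the Future contains
  // an error.
  currentDevice_ = c10::cuda::current_device();
}

c10::intrusive_ptr<at::ivalue::Future> CUDAFuture::createInstance(
    at::TypePtr type) {
  return c10::make_intrusive<CUDAFuture>(std::move(type));
}

// preMarkCompletedHook, wrapCallback, postWaitHook, and the extractDataPtrs
// helper would be defined here with the bodies removed from the header above
// (where exactly extractDataPtrs ends up is not shown in this diff).

} // namespace cuda
} // namespace at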
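And a hypothetical usage sketch, not part of this commit, showing what the hooks buy you. The names fut and t are illustrative, and it assumes a CUDA build where the tensor is produced on the current stream.

// Hypothetical usage (illustrative only; not from this commit).
auto fut = c10::make_intrusive<at::cuda::CUDAFuture>(at::TensorType::get());

at::Tensor t = at::ones({2, 2}, at::kCUDA); // queued on the current stream
// preMarkCompletedHook records a CUDAEvent on each device used by the value.
fut->markCompleted(t);

// wrapCallback runs the callback on pool streams that have been made to
// block on those events, so the tensor can be consumed without a manual sync.
fut->addCallback([fut]() {
  at::Tensor result = fut->value().toTensor();
});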