
Commit eb48b31

mortzur authored and facebook-github-bot committed
Device Resident Tensors - API & Framework (pytorch#3745)
Summary: Taking over pytorch#3671, but spinning out the API and Glow-core level changes associated with the DRT plan in pytorch#3629. This does not implement DRT support on any device.

Documentation: See pytorch#3629.

Pull Request resolved: pytorch#3745

Test Plan: Ran tests, and added two simple new sanity checks to DeviceManagerTest: the first, `DeviceResidentTensors`, should run only for backends that support resident tensors (none currently); the second, `CanHandleDeviceResidentTensors`, should run on all devices.

Differential Revision: D18378905

Pulled By: nickgg

fbshipit-source-id: 887c290dae5a6b9b75e9b41a415958d499bc5402
1 parent ec46f24 commit eb48b31

File tree

21 files changed: +464 -66 lines changed

include/glow/Backends/DeviceManager.h

Lines changed: 29 additions & 1 deletion
@@ -17,6 +17,7 @@
 #define GLOW_BACKENDS_DEVICEMANAGER_H

 #include "glow/Backend/CompiledFunction.h"
+#include "glow/Base/DeviceTensorTransferManager.h"
 #include "glow/ExecutionContext/ExecutionContext.h"
 #include "glow/Graph/Graph.h"
 #include "glow/Runtime/RuntimeTypes.h"
@@ -42,7 +43,7 @@ using ReadyCBTy = std::function<void(const Module *, Error)>;
 using FunctionMapTy = std::map<std::string, CompiledFunction *>;

 /// Interface managing a specific instance of a device.
-class DeviceManager {
+class DeviceManager : public DeviceTensorTransferManager {
 protected:
   /// Configuration object for the device.
   DeviceConfig config_;
@@ -162,6 +163,33 @@ class DeviceManager {
   /// \returns the DeviceInfo for this device containing peak limits for
   /// compute and bandwidths (used in partitioning).
   virtual DeviceInfo getDeviceInfo() const { return DeviceInfo(); }
+
+  /// Copies the contents of \p tensor from the host to the \p location
+  /// address on this device. Updates the tensor residency info.
+  virtual void transferToDevice(Tensor &tensor, void *locationContext,
+                                std::function<void(Error)> resultCB =
+                                    [](Error) {}) {
+    DCHECK("Not Implemented");
+    resultCB(MAKE_ERR(ErrorValue::ErrorCode::DEVICE_FEATURE_NOT_SUPPORTED,
+                      "Direct transfer not supported on this device"));
+  }
+
+  /// Copies the device buffer associated with \p tensor to the host.
+  /// The tensor must be resident on this device. If \p release is true,
+  /// frees the device memory. Updates the tensor residency info.
+  virtual void transferFromDevice(Tensor &tensor, bool release = true,
+                                  std::function<void(Error)> resultCB =
+                                      [](Error) {}) {
+    DCHECK("Not Implemented");
+    resultCB(MAKE_ERR(ErrorValue::ErrorCode::DEVICE_FEATURE_NOT_SUPPORTED,
+                      "Direct transfer not supported on this device"));
+  }
+
+  /// Releases the device buffer associated with \p tensor.
+  virtual bool releaseDeviceTensor(void *locationContext) {
+    DCHECK("Not Implemented");
+    return false;
+  }
 };

 } // namespace runtime
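
The three defaults above make DRT opt-in: a backend advertises support by overriding them. As a rough sketch only (not from this commit, which implements DRT on no device), a hypothetical backend might wire them up as below; the class name and the copyHostToDevice/copyDeviceToHost/freeDeviceAllocation helpers are invented for illustration, and the backend's other required DeviceManager overrides (addNetwork, runFunction, etc.) are elided, so this class is not instantiable as written.

class HypotheticalDeviceManager : public DeviceManager {
  // Assumed backend-specific helpers; signatures invented for this sketch.
  void copyHostToDevice(void *dstCtx, const void *src, size_t bytes);
  void copyDeviceToHost(void *dst, const void *srcCtx, size_t bytes);
  bool freeDeviceAllocation(void *locationContext);

public:
  void transferToDevice(Tensor &tensor, void *locationContext,
                        std::function<void(Error)> resultCB =
                            [](Error) {}) override {
    // Copy the host payload into the device allocation described by
    // locationContext, then record residency on the tensor.
    copyHostToDevice(locationContext, tensor.getUnsafePtr(),
                     tensor.getSizeInBytes());
    tensor.moveToDevice(this, locationContext);
    resultCB(Error::success());
  }

  void transferFromDevice(Tensor &tensor, bool release = true,
                          std::function<void(Error)> resultCB =
                              [](Error) {}) override {
    void *ctx = tensor.getLocationContext();
    tensor.clearDeviceResidency(); // host buffer is accessible again
    copyDeviceToHost(tensor.getUnsafePtr(), ctx, tensor.getSizeInBytes());
    if (release) {
      releaseDeviceTensor(ctx);
    }
    resultCB(Error::success());
  }

  bool releaseDeviceTensor(void *locationContext) override {
    // Free the device-side allocation behind locationContext.
    return freeDeviceAllocation(locationContext);
  }
};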
include/glow/Base/DeviceTensorTransferManager.h

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+/**
+ * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef GLOW_BASE_DEVICETENSORTRANSFERMANAGER_H
+#define GLOW_BASE_DEVICETENSORTRANSFERMANAGER_H
+
+#include "glow/Base/Tensor.h"
+#include "glow/Support/Error.h"
+
+#include <functional>
+
+namespace glow {
+
+class Tensor;
+
+class DeviceTensorTransferManager {
+public:
+  virtual ~DeviceTensorTransferManager() {}
+  /// Copies the contents of \p tensor from the host to the \p location address
+  /// on this device. Updates the tensor residency info.
+  virtual void transferToDevice(Tensor &tensor, void *locationContext = nullptr,
+                                std::function<void(Error)> resultCB =
+                                    [](Error) {}) = 0;
+
+  /// Copies the device buffer associated with \p tensor to the host.
+  /// The tensor must be resident on this device. If \p release is true, frees
+  /// the device memory. Updates the tensor residency info.
+  virtual void transferFromDevice(Tensor &tensor, bool release = true,
+                                  std::function<void(Error)> resultCB =
+                                      [](Error) {}) = 0;
+
+  /// Releases the device buffer associated with \p tensor.
+  virtual bool releaseDeviceTensor(void *locationContext) = 0;
+};
+
+} // namespace glow
+
+#endif // GLOW_BASE_DEVICETENSORTRANSFERMANAGER_H
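
To make the contract concrete, here is a minimal sketch of a conforming implementation that fakes device memory with a heap byte buffer, the sort of stand-in a unit test could use. The class name and the buffering scheme are assumptions, not code from this commit.

#include <cstring> // std::memcpy

class InMemoryTransferManager : public glow::DeviceTensorTransferManager {
public:
  void transferToDevice(glow::Tensor &tensor, void *locationContext = nullptr,
                        std::function<void(glow::Error)> resultCB =
                            [](glow::Error) {}) override {
    // "Device" memory is just a host heap buffer in this fake; the buffer
    // pointer doubles as the location context.
    auto *buf = new char[tensor.getSizeInBytes()];
    std::memcpy(buf, tensor.getUnsafePtr(), tensor.getSizeInBytes());
    tensor.moveToDevice(this, buf);
    resultCB(glow::Error::success());
  }

  void transferFromDevice(glow::Tensor &tensor, bool release = true,
                          std::function<void(glow::Error)> resultCB =
                              [](glow::Error) {}) override {
    auto *buf = static_cast<char *>(tensor.getLocationContext());
    tensor.clearDeviceResidency(); // host payload becomes accessible again
    std::memcpy(tensor.getUnsafePtr(), buf, tensor.getSizeInBytes());
    if (release) {
      delete[] buf;
    }
    resultCB(glow::Error::success());
  }

  bool releaseDeviceTensor(void *locationContext) override {
    delete[] static_cast<char *>(locationContext);
    return true;
  }
};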

include/glow/Base/Tensor.h

Lines changed: 131 additions & 1 deletion
@@ -20,6 +20,7 @@
 #include <cassert>
 #include <vector>

+#include "glow/Base/DeviceTensorTransferManager.h"
 #include "glow/Base/Type.h"
 #include "glow/Support/Compiler.h"
 #include "glow/Support/Memory.h"
@@ -48,6 +49,71 @@ void genericTranspose(const Tensor *src, Tensor *dest,
 /// returned dims. For example, input {2,1,4} would result in {2,1,4,1,1,1}.
 ShapeVector expandDimsToMax(llvm::ArrayRef<size_t> currDims);

+namespace runtime {
+class DeviceManager;
+}
+
+/// Holds information regarding whether this Tensor exists in a device-specific
+/// form, either resident or specific for a device, and what device holds it.
+class DeviceResidencyInfo final {
+  enum class TensorResidency {
+    Host,
+    Device,
+  };
+
+  /// A pointer to the device manager of the device on which the tensor
+  /// resides.
+  DeviceTensorTransferManager *deviceManager_{nullptr};
+  /// The residency status of the tensor.
+  TensorResidency tensorResidency_{TensorResidency::Host};
+  /// A pointer to a context structure, containing the required info to access
+  /// tensor data and perform transfers.
+  void *locationContext_{nullptr};
+
+public:
+  DeviceResidencyInfo()
+      : deviceManager_(nullptr), tensorResidency_(TensorResidency::Host),
+        locationContext_(nullptr) {}
+
+  /// Deleted move ctor.
+  DeviceResidencyInfo(DeviceResidencyInfo &&other) = delete;
+
+  /// Deleted move assignment operator.
+  DeviceResidencyInfo &operator=(DeviceResidencyInfo &&other) = delete;
+
+  ~DeviceResidencyInfo() {
+    // If a tensor is device resident, let its device manager free the device
+    // buffer.
+    if (isDeviceResident()) {
+      deviceManager_->releaseDeviceTensor(locationContext_);
+    }
+  }
+
+  /// Removes all device specific state.
+  void clear() {
+    deviceManager_ = nullptr;
+    locationContext_ = nullptr;
+    tensorResidency_ = TensorResidency::Host;
+  }
+
+  /// \returns true if this Tensor is resident or specific for a device.
+  bool isDeviceResident() const {
+    assert((tensorResidency_ == TensorResidency::Host || deviceManager_) &&
+           "Device resident tensor must have an assigned device manager.");
+    return tensorResidency_ == TensorResidency::Device;
+  }
+
+  /// \returns the DeviceManager this tensor is resident on, if any.
+  DeviceTensorTransferManager *getDeviceManager() const {
+    return deviceManager_;
+  }
+
+  /// \returns the device specific location context for a resident Tensor.
+  void *getLocationContext() const { return locationContext_; }
+
+  friend class Tensor;
+};
+
 /// A class that represents a contiguous n-dimensional array (a tensor).
 class Tensor final {
 public:
@@ -71,6 +137,10 @@
   /// The TensorPool that is managing this Tensor (if any).
   TensorPool *tensorPool_{nullptr};

+  /// The device residency info associated with the tensor.
+  std::shared_ptr<DeviceResidencyInfo> residencyInfoP_{
+      new DeviceResidencyInfo()};
+
   /// Size in bytes of the unpadded region memory. This is useful communicating
   /// the actual size of the data, this allows for copying only inputs and not
   /// padding to the device.
@@ -119,6 +189,7 @@
   /// Set the content of the tensor to zero. If \p resetFusedScalesOffsets, then
   /// fused scales/offsets will be set to 1.0/0.0 as well.
   void zero(bool resetFusedScalesOffsets = false) {
+    assert(!isDeviceResident() && "Tensor must reside on host to access data.");
     size_t size = actualSize();
     // Quantized tensors should go to their offset.
     switch (type_.getElementType()) {
@@ -298,7 +369,7 @@
     unownedTensor.isUnowned_ = true;
     unownedTensor.type_ = Type::newShape(getType(), dims);
     unownedTensor.unpaddedSize_ = unpaddedSize_;
-
+    unownedTensor.residencyInfoP_ = residencyInfoP_;
     if (offsets.size() == 0) {
       assert(actualSize() == unownedTensor.actualSize() &&
              "The size of the unowned tensor "
@@ -321,6 +392,7 @@
   /// element to start a subview from.
   Tensor getOwnedSlice(llvm::ArrayRef<size_t> dims,
                        llvm::ArrayRef<size_t> offsets = {}) const {
+    assert(!isDeviceResident() && "Tensor must reside on host to access data.");
     return getUnowned(dims, offsets).clone();
   }

@@ -341,6 +413,7 @@

   /// Assigns a new shape to the tensor and allocates a new buffer.
   void reset(const Type &T) {
+    assert(!isDeviceResident() && "Tensor must reside on host to access data.");
     // If the new size is identical to the allocated size then there is no need
     // to re-allocate the buffer.
     if (type_ == T && getData()) {
@@ -390,6 +463,7 @@
     std::swap(isUnowned_, other.isUnowned_);
     std::swap(tensorPool_, other.tensorPool_);
     std::swap(unpaddedSize_, other.unpaddedSize_);
+    std::swap(residencyInfoP_, other.residencyInfoP_);
   }

   /// Move assignment operator.
@@ -399,6 +473,7 @@
     std::swap(isUnowned_, other.isUnowned_);
     std::swap(tensorPool_, other.tensorPool_);
     std::swap(unpaddedSize_, other.unpaddedSize_);
+    std::swap(residencyInfoP_, other.residencyInfoP_);
     return *this;
   }

@@ -429,6 +504,14 @@
   /// elements exceeding allowed error; maximum error and location found; etc.).
   bool isEqual(const Tensor &other, float allowedError = 0.0001,
                bool verbose = true) const {
+    if (isDeviceResident()) {
+      if (!other.isDeviceResident()) {
+        return false;
+      }
+
+      return getDeviceManager() == other.getDeviceManager() &&
+             getLocationContext() == other.getLocationContext();
+    }
     return isEqualImpl(other, /*isBitwise=*/false, allowedError, verbose);
   }
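
In other words, equality for device-resident tensors is identity of the (device manager, location context) pair rather than a byte-wise comparison of device memory, and a resident tensor never equals a host-resident one. A contrived sketch (tm and ctx are hypothetical placeholders as before; sharing one ctx between two tensors is done purely to show the identity-based comparison):

void residentEquality(glow::DeviceTensorTransferManager &tm, void *ctx) {
  glow::Tensor a(glow::ElemKind::FloatTy, {4});
  glow::Tensor b(glow::ElemKind::FloatTy, {4});
  a.moveToDevice(&tm, ctx);
  assert(!a.isEqual(b)); // device-resident vs. host-resident: never equal
  b.moveToDevice(&tm, ctx);
  assert(a.isEqual(b));  // same manager and same location context
}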

@@ -513,6 +596,7 @@

   /// Update the content and type of the tensor from the tensor \p t.
   void assign(const Tensor *t) {
+    assert(!isDeviceResident() && "Tensor must reside on host to access data.");
     assert(this != t && "Copying to self");
     reset(t);
     size_t bufferSize = type_.getSizeInBytes();
@@ -521,6 +605,7 @@

   /// Update the raw data of the tensor from the tensor \p t.
   void copyRawFrom(const Tensor *t) {
+    assert(!isDeviceResident() && "Tensor must reside on host to access data.");
     assert(this != t && "Copying to self");
     assert(actualSize() == t->actualSize());
     assert(getElementType() == t->getElementType() && "Invalid element type");
@@ -531,6 +616,7 @@
   /// Update the content of the tensor with a slice from tensor \p t. A slice
   /// is one index from the first dimension of the tensor.
   void copySlice(const Tensor *t, size_t slice) {
+    assert(!isDeviceResident() && "Tensor must reside on host to access data.");
     auto dim = t->dims().slice(1);
     (void)dim;
     assert(dim == dims() && "Invalid slice size");
@@ -546,6 +632,7 @@
   /// The copying operation may overlap the end of the tensor \p t one or more
   /// times. This means that the data in the input tensor may be duplicated.
   void copyConsecutiveSlices(const Tensor *t, size_t startSliceIdx) {
+    assert(!isDeviceResident() && "Tensor must reside on host to access data.");
     auto onceSliceDim = t->dims().slice(1);
     (void)onceSliceDim;
     assert(onceSliceDim == dims().slice(1) && "Invalid slice size");
@@ -571,6 +658,7 @@
   /// and cast them to DestElemType in this.
   template <typename DestElemType, typename SrcElemType>
   void copyWithCast(const Tensor *t) {
+    assert(!isDeviceResident() && "Tensor must reside on host to access data.");
     static_assert(!std::is_same<DestElemType, SrcElemType>::value,
                   "Use copyRawFrom instead");
     assert(this != t && "Copying to self");
@@ -599,11 +687,13 @@
   /// Transpose the tensor \p src into the empty tensor \p dest. Shuffle the
   /// axis based on the list \p shuffle, where each element is the src index.
   void transpose(Tensor *dest, llvm::ArrayRef<unsigned_t> shuffle) const {
+    assert(!isDeviceResident() && "Tensor must reside on host to access data.");
     genericTranspose(this, dest, shuffle);
   }

   /// Create a new copy of the current tensor.
   Tensor clone() const {
+    assert(!isDeviceResident() && "Tensor must reside on host to access data.");
     Tensor slice;
     slice.assign(this);
     return slice;
@@ -612,6 +702,40 @@
   /// Return the raw unsafe pointer to the tensor payload.
   char *getUnsafePtr() const { return getData(); }

+  /// \returns true if tensor data is stored on a device
+  bool isDeviceResident() const { return residencyInfoP_->isDeviceResident(); }
+
+  /// Update device residency info with new device manager and context
+  void moveToDevice(DeviceTensorTransferManager *deviceManager,
+                    void *locationContext);
+
+  /// If device resident, copy Tensor contents back to host memory and release
+  /// associated device memory.
+  void ensureOnHost();
+
+  /// \returns the pointer to the device manager where the tensor resides.
+  DeviceTensorTransferManager *getDeviceManager() const {
+    assert(residencyInfoP_->isDeviceResident() &&
+           "Tensor must be device resident");
+    return residencyInfoP_->getDeviceManager();
+  }
+
+  /// \returns the pointer to the location context of where the tensor resides.
+  void *getLocationContext() const {
+    assert(residencyInfoP_->isDeviceResident() &&
+           "Tensor must be device resident");
+    return residencyInfoP_->getLocationContext();
+  }
+
+  /// Clears DeviceResidencyInfo.
+  /// Note that this does not affect the associated DeviceManager or device
+  /// memory.
+  void clearDeviceResidency() {
+    assert(residencyInfoP_->isDeviceResident() &&
+           "Tensor must be device resident");
+    residencyInfoP_->clear();
+  }
+
   /// \return a new handle that points and manages this tensor.
   template <class ElemTy = float> Handle<ElemTy> getHandle() &;
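
Taken together, these methods give a tensor a simple residency lifecycle. A sketch of the intended round trip, assuming some DeviceTensorTransferManager implementation tm and a device-specific location context ctx (both hypothetical):

void roundTrip(glow::DeviceTensorTransferManager &tm, void *ctx) {
  glow::Tensor t(glow::ElemKind::FloatTy, {4});
  assert(!t.isDeviceResident()); // tensors start host resident
  t.moveToDevice(&tm, ctx);      // hand the payload to the device
  assert(t.isDeviceResident());
  // Host-side accessors (getHandle(), zero(), clone(), ...) would now assert.
  t.ensureOnHost();              // copy back and release device memory
  assert(!t.isDeviceResident());
}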

@@ -623,19 +747,22 @@
 private:
   /// \returns a pointer to the raw data, of type \p ElemTy.
   template <class ElemTy> ElemTy *getRawDataPointer() {
+    assert(!isDeviceResident() && "Tensor must reside on host to access data.");
     assert(type_.isType<ElemTy>() && "Asking for the wrong ptr type.");
     return reinterpret_cast<ElemTy *>(data_);
   }

   /// \returns a const pointer to the raw data, of type \p ElemTy.
   template <class ElemTy> const ElemTy *getRawDataPointer() const {
+    assert(!isDeviceResident() && "Tensor must reside on host to access data.");
     assert(type_.isType<ElemTy>() && "Asking for the wrong ptr type.");
     return reinterpret_cast<const ElemTy *>(data_);
   }

   template <class ElemTy>
   bool isEqualImpl(const Tensor &other, float allowedError,
                    bool verbose) const {
+    assert(!isDeviceResident() && "Tensor must reside on host to access data.");
     auto const *myData = getRawDataPointer<ElemTy>();
     auto const *otherData = other.getRawDataPointer<ElemTy>();
     double maxFoundError = 0.0;
@@ -668,6 +795,7 @@
   }

   bool isBitwiseEqualImpl(const Tensor &other) const {
+    assert(!isDeviceResident() && "Tensor must reside on host to access data.");
     auto const *myData = getUnsafePtr();
     auto const *otherData = other.getUnsafePtr();
     for (size_t i = 0, e = getSizeInBytes(); i < e; i++) {
@@ -1283,11 +1411,13 @@
 };

 template <class ElemTy> Handle<ElemTy> Tensor::getHandle() & {
+  assert(!isDeviceResident() && "Tensor must reside on host to access data.");
   assert(type_.isType<ElemTy>() && "Getting a handle to the wrong type.");
   return Handle<ElemTy>(this);
 }

 template <class ElemTy> const Handle<ElemTy> Tensor::getHandle() const & {
+  assert(!isDeviceResident() && "Tensor must reside on host to access data.");
   assert(type_.isType<ElemTy>() && "Getting a handle to the wrong type.");
   return Handle<ElemTy>(const_cast<Tensor *>(this));
 }
