#include <cassert>
#include <vector>

#include "glow/Base/DeviceTensorTransferManager.h"
#include "glow/Base/Type.h"
#include "glow/Support/Compiler.h"
#include "glow/Support/Memory.h"
@@ -48,6 +49,71 @@ void genericTranspose(const Tensor *src, Tensor *dest,
4849// / returned dims. For example, input {2,1,4} would result in {2,1,4,1,1,1}.
4950ShapeVector expandDimsToMax (llvm::ArrayRef<size_t > currDims);
5051
52+ namespace runtime {
53+ class DeviceManager ;
54+ }
55+
56+ // / Holds information regarding whether this Tensor exists in a device-specific
57+ // / form, either resident or specific for a device, and what device holds it.
58+ class DeviceResidencyInfo final {
59+ enum class TensorResidency {
60+ Host,
61+ Device,
62+ };
63+
64+ // A pointer to the device manager of the device on which the tensor
65+ // resides.
66+ DeviceTensorTransferManager *deviceManager_{nullptr };
67+ // / The residency status of the tensor.
68+ TensorResidency tensorResidency_{TensorResidency::Host};
69+ // A pointer to a context structure, containing the required info to access
70+ // tensor data and perform transfers.
71+ void *locationContext_{nullptr };
72+
73+ public:
74+ DeviceResidencyInfo ()
75+ : deviceManager_(nullptr ), tensorResidency_(TensorResidency::Host),
76+ locationContext_ (nullptr ) {}
77+
78+ // / Move ctor.
79+ DeviceResidencyInfo (DeviceResidencyInfo &&other) = delete;
80+
81+ // / Move assignment operator.
82+ DeviceResidencyInfo &operator =(DeviceResidencyInfo &&other) = delete ;
83+
84+ ~DeviceResidencyInfo () {
85+ // If a tensor is device resident, let its device manager free the device
86+ // buffer.
87+ if (isDeviceResident ()) {
88+ deviceManager_->releaseDeviceTensor (locationContext_);
89+ }
90+ }
91+
92+ // / Removes all device specific state.
93+ void clear () {
94+ deviceManager_ = nullptr ;
95+ locationContext_ = nullptr ;
96+ tensorResidency_ = TensorResidency::Host;
97+ }
98+
99+ // / \returns true if this Tensor is resident or specific for a device.
100+ bool isDeviceResident () const {
101+ assert ((tensorResidency_ == TensorResidency::Host || deviceManager_) &&
102+ " Device resident tensor must have an assigned device manager." );
103+ return tensorResidency_ == TensorResidency::Device;
104+ }
105+
106+ // / \returns the DeviceManager this tensor is resident on, if any.
107+ DeviceTensorTransferManager *getDeviceManager () const {
108+ return deviceManager_;
109+ }
110+
111+ // / \returns the device specific location context for a resident Tensor.
112+ void *getLocationContext () const { return locationContext_; }
113+
114+ friend class Tensor ;
115+ };
116+
51117// / A class that represents a contiguous n-dimensional array (a tensor).
52118class Tensor final {
53119public:
@@ -71,6 +137,10 @@ class Tensor final {
71137 // / The TensorPool that is managing this Tensor (if any).
72138 TensorPool *tensorPool_{nullptr };
73139
140+ /// The device residency info associated with the tensor.
141+ std::shared_ptr<DeviceResidencyInfo> residencyInfoP_{
142+ new DeviceResidencyInfo ()};
143+
74144 // / Size in bytes of the unpadded region memory. This is useful communicating
75145 // / the actual size of the data, this allows for copying only inputs and not
76146 // / padding to the device.
@@ -119,6 +189,7 @@ class Tensor final {
119189 // / Set the content of the tensor to zero. If \p resetFusedScalesOffsets, then
120190 // / fused scales/offsets will be set to 1.0/0.0 as well.
121191 void zero (bool resetFusedScalesOffsets = false ) {
192+ assert (!isDeviceResident () && " Tensor must reside on host to access data." );
122193 size_t size = actualSize ();
123194 // Quantized tensors should go to their offset.
124195 switch (type_.getElementType ()) {
@@ -298,7 +369,7 @@ class Tensor final {
298369 unownedTensor.isUnowned_ = true ;
299370 unownedTensor.type_ = Type::newShape (getType (), dims);
300371 unownedTensor.unpaddedSize_ = unpaddedSize_;
301-
372+ unownedTensor. residencyInfoP_ = residencyInfoP_;
302373 if (offsets.size () == 0 ) {
303374 assert (actualSize () == unownedTensor.actualSize () &&
304375 " The size of the unowned tensor "
@@ -321,6 +392,7 @@ class Tensor final {
321392 // / element to start a subview from.
322393 Tensor getOwnedSlice (llvm::ArrayRef<size_t > dims,
323394 llvm::ArrayRef<size_t > offsets = {}) const {
395+ assert (!isDeviceResident () && " Tensor must reside on host to access data." );
324396 return getUnowned (dims, offsets).clone ();
325397 }
326398
@@ -341,6 +413,7 @@ class Tensor final {
341413
342414 // / Assigns a new shape to the tensor and allocates a new buffer.
343415 void reset (const Type &T) {
416+ assert (!isDeviceResident () && " Tensor must reside on host to access data." );
344417 // If the new size is identical to the allocated size then there is no need
345418 // to re-allocate the buffer.
346419 if (type_ == T && getData ()) {
@@ -390,6 +463,7 @@ class Tensor final {
390463 std::swap (isUnowned_, other.isUnowned_ );
391464 std::swap (tensorPool_, other.tensorPool_ );
392465 std::swap (unpaddedSize_, other.unpaddedSize_ );
466+ std::swap (residencyInfoP_, other.residencyInfoP_ );
393467 }
394468
395469 // / Move assignment operator.
@@ -399,6 +473,7 @@ class Tensor final {
399473 std::swap (isUnowned_, other.isUnowned_ );
400474 std::swap (tensorPool_, other.tensorPool_ );
401475 std::swap (unpaddedSize_, other.unpaddedSize_ );
476+ std::swap (residencyInfoP_, other.residencyInfoP_ );
402477 return *this ;
403478 }
404479
@@ -429,6 +504,14 @@ class Tensor final {
429504 // / elements exceeding allowed error; maximum error and location found; etc.).
430505 bool isEqual (const Tensor &other, float allowedError = 0.0001 ,
431506 bool verbose = true ) const {
507+ if (isDeviceResident ()) {
508+ if (!other.isDeviceResident ()) {
509+ return false ;
510+ }
511+
512+ return getDeviceManager () == other.getDeviceManager () &&
513+ getLocationContext () == other.getLocationContext ();
514+ }
432515 return isEqualImpl (other, /* isBitwise=*/ false , allowedError, verbose);
433516 }
434517
@@ -513,6 +596,7 @@ class Tensor final {
513596
514597 // / Update the content and type of the tensor from the tensor \p t.
515598 void assign (const Tensor *t) {
599+ assert (!isDeviceResident () && " Tensor must reside on host to access data." );
516600 assert (this != t && " Copying to self" );
517601 reset (t);
518602 size_t bufferSize = type_.getSizeInBytes ();
@@ -521,6 +605,7 @@ class Tensor final {
521605
522606 // / Update the raw data of the tensor from the tensor \p t.
523607 void copyRawFrom (const Tensor *t) {
608+ assert (!isDeviceResident () && " Tensor must reside on host to access data." );
524609 assert (this != t && " Copying to self" );
525610 assert (actualSize () == t->actualSize ());
526611 assert (getElementType () == t->getElementType () && " Invalid element type" );
@@ -531,6 +616,7 @@ class Tensor final {
531616 // / Update the content of the tensor with a slice from tensor \p t. A slice
532617 // / is one index from the first dimension of the tensor.
533618 void copySlice (const Tensor *t, size_t slice) {
619+ assert (!isDeviceResident () && " Tensor must reside on host to access data." );
534620 auto dim = t->dims ().slice (1 );
535621 (void )dim;
536622 assert (dim == dims () && " Invalid slice size" );
@@ -546,6 +632,7 @@ class Tensor final {
546632 // / The copying operation may overlap the end of the tensor \p t one or more
547633 // / times. This means that the data in the input tensor may be duplicated.
548634 void copyConsecutiveSlices (const Tensor *t, size_t startSliceIdx) {
635+ assert (!isDeviceResident () && " Tensor must reside on host to access data." );
549636 auto onceSliceDim = t->dims ().slice (1 );
550637 (void )onceSliceDim;
551638 assert (onceSliceDim == dims ().slice (1 ) && " Invalid slice size" );
@@ -571,6 +658,7 @@ class Tensor final {
571658 // / and cast them to DestElemType in this.
572659 template <typename DestElemType, typename SrcElemType>
573660 void copyWithCast (const Tensor *t) {
661+ assert (!isDeviceResident () && " Tensor must reside on host to access data." );
574662 static_assert (!std::is_same<DestElemType, SrcElemType>::value,
575663 " Use copyRawFrom instead" );
576664 assert (this != t && " Copying to self" );
@@ -599,11 +687,13 @@ class Tensor final {
599687 // / Transpose the tensor \p src into the empty tensor \p dest. Shuffle the
600688 // / axis based on the list \p shuffle, where each element is the src index.
601689 void transpose (Tensor *dest, llvm::ArrayRef<unsigned_t > shuffle) const {
690+ assert (!isDeviceResident () && " Tensor must reside on host to access data." );
602691 genericTranspose (this , dest, shuffle);
603692 }
604693
605694 // / Create a new copy of the current tensor.
606695 Tensor clone () const {
696+ assert (!isDeviceResident () && " Tensor must reside on host to access data." );
607697 Tensor slice;
608698 slice.assign (this );
609699 return slice;
@@ -612,6 +702,40 @@ class Tensor final {
612702 // / Return the raw unsafe pointer to the tensor payload.
613703 char *getUnsafePtr () const { return getData (); }
614704
705+ // / \returns true if tensor data is stored on a device
706+ bool isDeviceResident () const { return residencyInfoP_->isDeviceResident (); }
707+
708+ // / Update device residency info with new device manager and context
709+ void moveToDevice (DeviceTensorTransferManager *deviceManager,
710+ void *locationContext);
711+
712+ // / If device resident, copy Tensor contents back to host memory and release
713+ // / associated device memory.
714+ void ensureOnHost ();
715+
716+ // / \returns the pointer to the device manager where the tensor resides.
717+ DeviceTensorTransferManager *getDeviceManager () const {
718+ assert (residencyInfoP_->isDeviceResident () &&
719+ " Tensor must be device resident" );
720+ return residencyInfoP_->getDeviceManager ();
721+ }
722+
723+ // / \returns the pointer to the location context of where the tensor resides.
724+ void *getLocationContext () const {
725+ assert (residencyInfoP_->isDeviceResident () &&
726+ " Tensor must be device resident" );
727+ return residencyInfoP_->getLocationContext ();
728+ }
729+
730+ // / Clears DeviceResidencyInfo.
731+ // / Note that this does not affect the associated DeviceManager or device
732+ // / memory.
733+ void clearDeviceResidency () {
734+ assert (residencyInfoP_->isDeviceResident () &&
735+ " Tensor must be device resident" );
736+ residencyInfoP_->clear ();
737+ }
738+
615739 // / \return a new handle that points and manages this tensor.
616740 template <class ElemTy = float > Handle<ElemTy> getHandle () &;
617741
@@ -623,19 +747,22 @@ class Tensor final {
623747private:
624748 // / \returns a pointer to the raw data, of type \p ElemTy.
625749 template <class ElemTy > ElemTy *getRawDataPointer () {
750+ assert (!isDeviceResident () && " Tensor must reside on host to access data." );
626751 assert (type_.isType <ElemTy>() && " Asking for the wrong ptr type." );
627752 return reinterpret_cast <ElemTy *>(data_);
628753 }
629754
630755 // / \returns a const pointer to the raw data, of type \p ElemTy.
631756 template <class ElemTy > const ElemTy *getRawDataPointer () const {
757+ assert (!isDeviceResident () && " Tensor must reside on host to access data." );
632758 assert (type_.isType <ElemTy>() && " Asking for the wrong ptr type." );
633759 return reinterpret_cast <const ElemTy *>(data_);
634760 }
635761
636762 template <class ElemTy >
637763 bool isEqualImpl (const Tensor &other, float allowedError,
638764 bool verbose) const {
765+ assert (!isDeviceResident () && " Tensor must reside on host to access data." );
639766 auto const *myData = getRawDataPointer<ElemTy>();
640767 auto const *otherData = other.getRawDataPointer <ElemTy>();
641768 double maxFoundError = 0.0 ;
@@ -668,6 +795,7 @@ class Tensor final {
668795 }
669796
670797 bool isBitwiseEqualImpl (const Tensor &other) const {
798+ assert (!isDeviceResident () && " Tensor must reside on host to access data." );
671799 auto const *myData = getUnsafePtr ();
672800 auto const *otherData = other.getUnsafePtr ();
673801 for (size_t i = 0 , e = getSizeInBytes (); i < e; i++) {
@@ -1283,11 +1411,13 @@ template <class ElemTy> class Handle final {
12831411};
12841412
12851413template <class ElemTy > Handle<ElemTy> Tensor::getHandle () & {
1414+ assert (!isDeviceResident () && " Tensor must reside on host to access data." );
12861415 assert (type_.isType <ElemTy>() && " Getting a handle to the wrong type." );
12871416 return Handle<ElemTy>(this );
12881417}
12891418
12901419template <class ElemTy > const Handle<ElemTy> Tensor::getHandle () const & {
1420+ assert (!isDeviceResident () && " Tensor must reside on host to access data." );
12911421 assert (type_.isType <ElemTy>() && " Getting a handle to the wrong type." );
12921422 return Handle<ElemTy>(const_cast <Tensor *>(this ));
12931423}
0 commit comments