|
1 | 1 | #pragma once |
2 | | - |
3 | | -/** |
4 | | - * To register your own tensor types, do in a header file: |
5 | | - * AT_DECLARE_TENSOR_TYPE(MY_TENSOR) |
6 | | - * and in one (!) cpp file: |
7 | | - * AT_DEFINE_TENSOR_TYPE(MY_TENSOR) |
8 | | - * Both must be in the same namespace. |
9 | | - */ |
10 | | - |
11 | | -#include "ATen/core/TensorTypeId.h" |
12 | | -#include "c10/macros/Macros.h" |
13 | | - |
14 | | -#include <atomic> |
15 | | -#include <mutex> |
16 | | -#include <unordered_set> |
17 | | - |
18 | | -namespace at { |
19 | | - |
20 | | -class CAFFE2_API TensorTypeIdCreator final { |
21 | | - public: |
22 | | - TensorTypeIdCreator(); |
23 | | - |
24 | | - at::TensorTypeId create(); |
25 | | - |
26 | | - static constexpr at::TensorTypeId undefined() noexcept { |
27 | | - return TensorTypeId(0); |
28 | | - } |
29 | | - |
30 | | - private: |
31 | | - std::atomic<details::_tensorTypeId_underlyingType> last_id_; |
32 | | - |
33 | | - C10_DISABLE_COPY_AND_ASSIGN(TensorTypeIdCreator); |
34 | | -}; |
35 | | - |
36 | | -class CAFFE2_API TensorTypeIdRegistry final { |
37 | | - public: |
38 | | - TensorTypeIdRegistry(); |
39 | | - |
40 | | - void registerId(at::TensorTypeId id); |
41 | | - void deregisterId(at::TensorTypeId id); |
42 | | - |
43 | | - private: |
44 | | - std::unordered_set<at::TensorTypeId> registeredTypeIds_; |
45 | | - std::mutex mutex_; |
46 | | - |
47 | | - C10_DISABLE_COPY_AND_ASSIGN(TensorTypeIdRegistry); |
48 | | -}; |
49 | | - |
50 | | -class CAFFE2_API TensorTypeIds final { |
51 | | - public: |
52 | | - static TensorTypeIds& singleton(); |
53 | | - |
54 | | - at::TensorTypeId createAndRegister(); |
55 | | - void deregister(at::TensorTypeId id); |
56 | | - |
57 | | - static constexpr at::TensorTypeId undefined() noexcept; |
58 | | - |
59 | | - private: |
60 | | - TensorTypeIds(); |
61 | | - |
62 | | - TensorTypeIdCreator creator_; |
63 | | - TensorTypeIdRegistry registry_; |
64 | | - |
65 | | - C10_DISABLE_COPY_AND_ASSIGN(TensorTypeIds); |
66 | | -}; |
67 | | - |
68 | | -inline constexpr at::TensorTypeId TensorTypeIds::undefined() noexcept { |
69 | | - return TensorTypeIdCreator::undefined(); |
70 | | -} |
71 | | - |
72 | | -class CAFFE2_API TensorTypeIdRegistrar final { |
73 | | - public: |
74 | | - TensorTypeIdRegistrar(); |
75 | | - ~TensorTypeIdRegistrar(); |
76 | | - |
77 | | - at::TensorTypeId id() const noexcept; |
78 | | - |
79 | | - private: |
80 | | - at::TensorTypeId id_; |
81 | | - |
82 | | - C10_DISABLE_COPY_AND_ASSIGN(TensorTypeIdRegistrar); |
83 | | -}; |
84 | | - |
85 | | -inline at::TensorTypeId TensorTypeIdRegistrar::id() const noexcept { |
86 | | - return id_; |
87 | | -} |
88 | | - |
89 | | -#define AT_DECLARE_TENSOR_TYPE(TensorName) \ |
90 | | - CAFFE2_API at::TensorTypeId TensorName() |
91 | | - |
92 | | -#define AT_DEFINE_TENSOR_TYPE(TensorName) \ |
93 | | - at::TensorTypeId TensorName() { \ |
94 | | - static TensorTypeIdRegistrar registration_raii; \ |
95 | | - return registration_raii.id(); \ |
96 | | - } |
97 | | - |
98 | | -AT_DECLARE_TENSOR_TYPE(UndefinedTensorId); |
99 | | -AT_DECLARE_TENSOR_TYPE(CPUTensorId); // PyTorch/Caffe2 supported |
100 | | -AT_DECLARE_TENSOR_TYPE(CUDATensorId); // PyTorch/Caffe2 supported |
101 | | -AT_DECLARE_TENSOR_TYPE(SparseCPUTensorId); // PyTorch only |
102 | | -AT_DECLARE_TENSOR_TYPE(SparseCUDATensorId); // PyTorch only |
103 | | -AT_DECLARE_TENSOR_TYPE(MKLDNNTensorId); // Caffe2 only |
104 | | -AT_DECLARE_TENSOR_TYPE(OpenGLTensorId); // Caffe2 only |
105 | | -AT_DECLARE_TENSOR_TYPE(OpenCLTensorId); // Caffe2 only |
106 | | -AT_DECLARE_TENSOR_TYPE(IDEEPTensorId); // Caffe2 only |
107 | | -AT_DECLARE_TENSOR_TYPE(HIPTensorId); // Caffe2 only |
108 | | - |
109 | | -} // namespace at |
| 2 | +#include "c10/util/TensorTypeIdRegistration.h" |
0 commit comments