|
| 1 | +#include <ATen/ATen.h> |
| 2 | +#include <ATen/Parallel.h> |
| 3 | +#include <parallel_hashmap/phmap.h> |
| 4 | +#include <torch/library.h> |
| 5 | + |
| 6 | +namespace pyg { |
| 7 | +namespace classes { |
| 8 | + |
| 9 | +namespace { |
| 10 | + |
| 11 | +#define DISPATCH_CASE_KEY(...) \ |
| 12 | + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ |
| 13 | + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ |
| 14 | + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) |
| 15 | + |
| 16 | +#define DISPATCH_KEY(TYPE, NAME, ...) \ |
| 17 | + AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_KEY(__VA_ARGS__)) |
| 18 | + |
| 19 | +struct HashMapImpl { |
| 20 | + virtual ~HashMapImpl() = default; |
| 21 | + virtual at::Tensor get(const at::Tensor& query) = 0; |
| 22 | + virtual at::Tensor keys() = 0; |
| 23 | +}; |
| 24 | + |
| 25 | +template <typename KeyType> |
| 26 | +struct CPUHashMapImpl : HashMapImpl { |
| 27 | + public: |
| 28 | + using ValueType = int64_t; |
| 29 | + |
| 30 | + CPUHashMapImpl(const at::Tensor& key) { |
| 31 | + map_.reserve(key.numel()); |
| 32 | + |
| 33 | + const auto key_data = key.data_ptr<KeyType>(); |
| 34 | + |
| 35 | + const auto num_threads = at::get_num_threads(); |
| 36 | + const auto grain_size = |
| 37 | + std::max((key.numel() + num_threads - 1) / num_threads, |
| 38 | + at::internal::GRAIN_SIZE); |
| 39 | + |
| 40 | + at::parallel_for(0, key.numel(), grain_size, [&](int64_t beg, int64_t end) { |
| 41 | + for (int64_t i = beg; i < end; ++i) { |
| 42 | + const auto [iterator, inserted] = map_.insert({key_data[i], i}); |
| 43 | + TORCH_CHECK(inserted, "Found duplicated key in 'HashMap'."); |
| 44 | + } |
| 45 | + }); |
| 46 | + } |
| 47 | + |
| 48 | + at::Tensor get(const at::Tensor& query) override { |
| 49 | + const auto options = |
| 50 | + query.options().dtype(c10::CppTypeToScalarType<ValueType>::value); |
| 51 | + const auto out = at::empty({query.numel()}, options); |
| 52 | + const auto query_data = query.data_ptr<KeyType>(); |
| 53 | + const auto out_data = out.data_ptr<ValueType>(); |
| 54 | + |
| 55 | + const auto num_threads = at::get_num_threads(); |
| 56 | + const auto grain_size = |
| 57 | + std::max((query.numel() + num_threads - 1) / num_threads, |
| 58 | + at::internal::GRAIN_SIZE); |
| 59 | + |
| 60 | + at::parallel_for(0, query.numel(), grain_size, [&](int64_t b, int64_t e) { |
| 61 | + for (int64_t i = b; i < e; ++i) { |
| 62 | + const auto it = map_.find(query_data[i]); |
| 63 | + out_data[i] = (it != map_.end()) ? it->second : -1; |
| 64 | + } |
| 65 | + }); |
| 66 | + |
| 67 | + return out; |
| 68 | + } |
| 69 | + |
| 70 | + at::Tensor keys() override { |
| 71 | + const auto size = static_cast<int64_t>(map_.size()); |
| 72 | + |
| 73 | + at::Tensor key; |
| 74 | + if (std::is_same<KeyType, int16_t>::value) { |
| 75 | + key = at::empty({size}, at::TensorOptions().dtype(at::kShort)); |
| 76 | + } else if (std::is_same<KeyType, int32_t>::value) { |
| 77 | + key = at::empty({size}, at::TensorOptions().dtype(at::kInt)); |
| 78 | + } else { |
| 79 | + key = at::empty({size}, at::TensorOptions().dtype(at::kLong)); |
| 80 | + } |
| 81 | + const auto key_data = key.data_ptr<KeyType>(); |
| 82 | + |
| 83 | + for (const auto& pair : map_) { // No efficient multi-threading possible :( |
| 84 | + key_data[pair.second] = pair.first; |
| 85 | + } |
| 86 | + |
| 87 | + return key; |
| 88 | + } |
| 89 | + |
| 90 | + private: |
| 91 | + phmap::parallel_flat_hash_map< |
| 92 | + KeyType, |
| 93 | + ValueType, |
| 94 | + phmap::priv::hash_default_hash<KeyType>, |
| 95 | + phmap::priv::hash_default_eq<KeyType>, |
| 96 | + phmap::priv::Allocator<std::pair<const KeyType, ValueType>>, |
| 97 | + 12, |
| 98 | + std::mutex> |
| 99 | + map_; |
| 100 | +}; |
| 101 | + |
| 102 | +struct CPUHashMap : torch::CustomClassHolder { |
| 103 | + public: |
| 104 | + CPUHashMap(const at::Tensor& key) { |
| 105 | + at::TensorArg key_arg{key, "key", 0}; |
| 106 | + at::CheckedFrom c{"CPUHashMap.init"}; |
| 107 | + at::checkDeviceType(c, key, at::DeviceType::CPU); |
| 108 | + at::checkDim(c, key_arg, 1); |
| 109 | + at::checkContiguous(c, key_arg); |
| 110 | + |
| 111 | + DISPATCH_KEY(key.scalar_type(), "cpu_hash_map_init", [&] { |
| 112 | + map_ = std::make_unique<CPUHashMapImpl<scalar_t>>(key); |
| 113 | + }); |
| 114 | + } |
| 115 | + |
| 116 | + at::Tensor get(const at::Tensor& query) { |
| 117 | + at::TensorArg query_arg{query, "query", 0}; |
| 118 | + at::CheckedFrom c{"CPUHashMap.get"}; |
| 119 | + at::checkDeviceType(c, query, at::DeviceType::CPU); |
| 120 | + at::checkDim(c, query_arg, 1); |
| 121 | + at::checkContiguous(c, query_arg); |
| 122 | + |
| 123 | + return map_->get(query); |
| 124 | + } |
| 125 | + |
| 126 | + at::Tensor keys() { return map_->keys(); } |
| 127 | + |
| 128 | + private: |
| 129 | + std::unique_ptr<HashMapImpl> map_; |
| 130 | +}; |
| 131 | + |
| 132 | +} // namespace |
| 133 | + |
| 134 | +TORCH_LIBRARY_FRAGMENT(pyg, m) { |
| 135 | + m.class_<CPUHashMap>("CPUHashMap") |
| 136 | + .def(torch::init<at::Tensor&>()) |
| 137 | + .def("get", &CPUHashMap::get) |
| 138 | + .def("keys", &CPUHashMap::keys) |
| 139 | + .def_pickle( |
| 140 | + // __getstate__ |
| 141 | + [](const c10::intrusive_ptr<CPUHashMap>& self) -> at::Tensor { |
| 142 | + return self->keys(); |
| 143 | + }, |
| 144 | + // __setstate__ |
| 145 | + [](const at::Tensor& state) -> c10::intrusive_ptr<CPUHashMap> { |
| 146 | + return c10::make_intrusive<CPUHashMap>(state); |
| 147 | + }); |
| 148 | +} |
| 149 | + |
| 150 | +} // namespace classes |
| 151 | +} // namespace pyg |
0 commit comments