Commit 04eff96

Align CPUHashMap and CUDAHashMap implementations (#393)

1 parent ef72b03 · commit 04eff96

File tree: 10 files changed (+166 -174 lines)

benchmark/classes/hash_map.py (+12 -5)
@@ -4,7 +4,7 @@
 import pandas as pd
 import torch
 
-from pyg_lib.classes import HashMap
+import pyg_lib  # noqa
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
@@ -30,6 +30,13 @@
     query2 = torch.randperm(args.num_queries, device=args.device)
     query2 = query2[:args.num_queries]
 
+    if key1.is_cpu:
+        HashMap = torch.classes.pyg.CPUHashMap
+    elif key1.is_cuda:
+        HashMap = torch.classes.pyg.CUDAHashMap
+    else:
+        raise NotImplementedError(f"Unsupported device '{args.device}'")
+
     t_init = t_get = 0
     for i in range(num_warmups + num_steps):
         torch.cuda.synchronize()
@@ -55,7 +62,7 @@
         t_start = time.perf_counter()
         hash_map = torch.full((args.num_keys, ), fill_value=-1,
                               dtype=torch.long, device=args.device)
-        hash_map[key2] = torch.arange(args.num_keys)
+        hash_map[key2] = torch.arange(args.num_keys, device=args.device)
         torch.cuda.synchronize()
         if i >= num_warmups:
             t_init += time.perf_counter() - t_start
@@ -85,7 +92,7 @@
         if i >= num_warmups:
             t_get += time.perf_counter() - t_start
 
-    print(f' Pandas Init: {t_init / num_steps:.4f}s')
-    print(f' Pandas Get: {t_get / num_steps:.4f}s')
+    print(f' Pandas Init: {t_init / num_steps:.4f}s')
+    print(f' Pandas Get: {t_get / num_steps:.4f}s')
 
-    assert out1.equal(torch.tensor(out3))
+    assert out1.equal(torch.tensor(out3))
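
For orientation, a minimal sketch of how the benchmarked class is driven from Python after this change. It assumes pyg_lib is installed so that importing it loads libpyg and registers the custom classes; the tensor values are borrowed from the C++ test further down and are purely illustrative.

import torch
import pyg_lib  # noqa  # importing pyg_lib loads libpyg and registers torch.classes.pyg.*

# Build a CPU hash map from a 1-dimensional key tensor.
key = torch.tensor([0, 10, 30, 20])
hash_map = torch.classes.pyg.CPUHashMap(key)

# 'get' maps each query to the position of its key, or -1 if the key is absent.
query = torch.tensor([30, 10, 20, 40])
out = hash_map.get(query)  # tensor([ 2,  1,  3, -1])

# 'keys' returns the stored keys in insertion order.
assert hash_map.keys().equal(key)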

pyg_lib/__init__.py (-1)
@@ -34,7 +34,6 @@ def load_library(lib_name: str) -> None:
 load_library('libpyg')
 
 import pyg_lib.ops  # noqa
-import pyg_lib.classes  # noqa
 import pyg_lib.partition  # noqa
 import pyg_lib.sampler  # noqa
 
pyg_lib/classes/__init__.py (-18): This file was deleted.

pyg_lib/csrc/classes/cpu/hash_map.cpp (+151, new file)
@@ -0,0 +1,151 @@
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <parallel_hashmap/phmap.h>
#include <torch/library.h>

namespace pyg {
namespace classes {

namespace {

#define DISPATCH_CASE_KEY(...)                          \
  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)    \
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define DISPATCH_KEY(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_KEY(__VA_ARGS__))

struct HashMapImpl {
  virtual ~HashMapImpl() = default;
  virtual at::Tensor get(const at::Tensor& query) = 0;
  virtual at::Tensor keys() = 0;
};

template <typename KeyType>
struct CPUHashMapImpl : HashMapImpl {
 public:
  using ValueType = int64_t;

  CPUHashMapImpl(const at::Tensor& key) {
    map_.reserve(key.numel());

    const auto key_data = key.data_ptr<KeyType>();

    const auto num_threads = at::get_num_threads();
    const auto grain_size =
        std::max((key.numel() + num_threads - 1) / num_threads,
                 at::internal::GRAIN_SIZE);

    at::parallel_for(0, key.numel(), grain_size, [&](int64_t beg, int64_t end) {
      for (int64_t i = beg; i < end; ++i) {
        const auto [iterator, inserted] = map_.insert({key_data[i], i});
        TORCH_CHECK(inserted, "Found duplicated key in 'HashMap'.");
      }
    });
  }

  at::Tensor get(const at::Tensor& query) override {
    const auto options =
        query.options().dtype(c10::CppTypeToScalarType<ValueType>::value);
    const auto out = at::empty({query.numel()}, options);
    const auto query_data = query.data_ptr<KeyType>();
    const auto out_data = out.data_ptr<ValueType>();

    const auto num_threads = at::get_num_threads();
    const auto grain_size =
        std::max((query.numel() + num_threads - 1) / num_threads,
                 at::internal::GRAIN_SIZE);

    at::parallel_for(0, query.numel(), grain_size, [&](int64_t b, int64_t e) {
      for (int64_t i = b; i < e; ++i) {
        const auto it = map_.find(query_data[i]);
        out_data[i] = (it != map_.end()) ? it->second : -1;
      }
    });

    return out;
  }

  at::Tensor keys() override {
    const auto size = static_cast<int64_t>(map_.size());

    at::Tensor key;
    if (std::is_same<KeyType, int16_t>::value) {
      key = at::empty({size}, at::TensorOptions().dtype(at::kShort));
    } else if (std::is_same<KeyType, int32_t>::value) {
      key = at::empty({size}, at::TensorOptions().dtype(at::kInt));
    } else {
      key = at::empty({size}, at::TensorOptions().dtype(at::kLong));
    }
    const auto key_data = key.data_ptr<KeyType>();

    for (const auto& pair : map_) {  // No efficient multi-threading possible :(
      key_data[pair.second] = pair.first;
    }

    return key;
  }

 private:
  phmap::parallel_flat_hash_map<
      KeyType,
      ValueType,
      phmap::priv::hash_default_hash<KeyType>,
      phmap::priv::hash_default_eq<KeyType>,
      phmap::priv::Allocator<std::pair<const KeyType, ValueType>>,
      12,
      std::mutex>
      map_;
};

struct CPUHashMap : torch::CustomClassHolder {
 public:
  CPUHashMap(const at::Tensor& key) {
    at::TensorArg key_arg{key, "key", 0};
    at::CheckedFrom c{"CPUHashMap.init"};
    at::checkDeviceType(c, key, at::DeviceType::CPU);
    at::checkDim(c, key_arg, 1);
    at::checkContiguous(c, key_arg);

    DISPATCH_KEY(key.scalar_type(), "cpu_hash_map_init", [&] {
      map_ = std::make_unique<CPUHashMapImpl<scalar_t>>(key);
    });
  }

  at::Tensor get(const at::Tensor& query) {
    at::TensorArg query_arg{query, "query", 0};
    at::CheckedFrom c{"CPUHashMap.get"};
    at::checkDeviceType(c, query, at::DeviceType::CPU);
    at::checkDim(c, query_arg, 1);
    at::checkContiguous(c, query_arg);

    return map_->get(query);
  }

  at::Tensor keys() { return map_->keys(); }

 private:
  std::unique_ptr<HashMapImpl> map_;
};

}  // namespace

TORCH_LIBRARY_FRAGMENT(pyg, m) {
  m.class_<CPUHashMap>("CPUHashMap")
      .def(torch::init<at::Tensor&>())
      .def("get", &CPUHashMap::get)
      .def("keys", &CPUHashMap::keys)
      .def_pickle(
          // __getstate__
          [](const c10::intrusive_ptr<CPUHashMap>& self) -> at::Tensor {
            return self->keys();
          },
          // __setstate__
          [](const at::Tensor& state) -> c10::intrusive_ptr<CPUHashMap> {
            return c10::make_intrusive<CPUHashMap>(state);
          });
}

}  // namespace classes
}  // namespace pyg
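
The def_pickle hooks at the end serialize a map as its keys() tensor and rebuild it from that tensor on load, so an instance should survive a torch.save/torch.load roundtrip roughly along these lines (a sketch under that assumption; the file name is illustrative):

import torch
import pyg_lib  # noqa

hash_map = torch.classes.pyg.CPUHashMap(torch.tensor([0, 10, 30, 20]))

torch.save(hash_map, 'hash_map.pt')  # __getstate__: stores only the keys() tensor
loaded = torch.load('hash_map.pt')   # __setstate__: rebuilds the CPUHashMap from those keys
assert loaded.keys().equal(hash_map.keys())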

pyg_lib/csrc/classes/cpu/hash_map_impl.h (-68): This file was deleted.

pyg_lib/csrc/classes/hash_map.cpp (-47): This file was deleted.

pyg_lib/csrc/classes/hash_map.h (-19): This file was deleted.

pyg_lib/csrc/classes/hash_map_impl.h (-14): This file was deleted.

test/csrc/classes/test_hash_map.cpp (+3 -2)
@@ -1,13 +1,14 @@
 #include <ATen/ATen.h>
 #include <gtest/gtest.h>
 
-#include "pyg_lib/csrc/classes/hash_map.h"
+#include "pyg_lib/csrc/classes/cpu/hash_map.cpp"
 
 TEST(HashMapTest, BasicAssertions) {
   auto options = at::TensorOptions().dtype(at::kLong);
   auto key = at::tensor({0, 10, 30, 20}, options);
 
-  auto map = pyg::classes::HashMap(key);
+  auto map = pyg::classes::CPUHashMap(key);
+  EXPECT_TRUE(at::equal(map.keys(), key));
 
   auto query = at::tensor({30, 10, 20, 40}, options);
   auto expected = at::tensor({2, 1, 3, -1}, options);
