4
4
// This source code is licensed under the BSD-style license found in the
5
5
// LICENSE file in the root directory of this source tree.
6
6
7
+ #pragma once
8
+
7
9
#include < torch/types.h>
8
10
#include < memory>
9
11
#include < mutex>
@@ -27,7 +29,7 @@ class Cache {
27
29
public:
28
30
using element_type = std::unique_ptr<T, D>;
29
31
30
- Cache (int capacity) : capacity_(capacity) {}
32
+ explicit Cache (int capacity) : capacity_(capacity) {} // Stores the requested capacity in capacity_; `explicit` prevents an int from implicitly converting to a Cache.
31
33
32
34
// Adds an object to the cache if the cache has capacity. Returns true
33
35
// if object was added and false otherwise.
@@ -56,8 +58,9 @@ bool Cache<T, D>::addIfCacheHasCapacity(element_type&& obj) {
56
58
template <typename T, typename D>
57
59
typename Cache<T, D>::element_type Cache<T, D>::get() {
58
60
std::scoped_lock lock (mutex_);
59
- if (cache_.empty ())
61
+ if (cache_.empty ()) {
60
62
return nullptr ;
63
+ }
61
64
62
65
element_type obj = std::move (cache_.back ());
63
66
cache_.pop_back ();
@@ -92,7 +95,15 @@ class PerGpuCache {
92
95
std::vector<std::unique_ptr<Cache<T, D>>> cache_;
93
96
};
94
97
95
- torch::DeviceIndex getNonNegativeDeviceIndex (const torch::Device& device) {
98
+ // Note: this function is inline for convenience, not performance. Because the
99
+ // rest of this file is template functions, they must all be defined in this
100
+ // header. This function is not a template function, and should, in principle,
101
+ // be defined in a .cpp file to preserve the One Definition Rule. That's
102
+ // annoying for such a small amount of code, so we just inline it. If this file
103
+ // grows, and there are more such functions, we should break them out into a
104
+ // .cpp file.
105
+ inline torch::DeviceIndex getNonNegativeDeviceIndex (
106
+ const torch::Device& device) {
96
107
torch::DeviceIndex deviceIndex = device.index ();
97
108
// For single GPU machines libtorch returns -1 for the device index. So for
98
109
// that case we set the device index to 0. That's used in per-gpu cache
0 commit comments