forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathPythonTorchFunctionTLS.cpp
50 lines (38 loc) · 1.5 KB
/
PythonTorchFunctionTLS.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#include <ATen/PythonTorchFunctionTLS.h>
#include <c10/core/TensorImpl.h>
namespace at {
namespace impl {

// Per-thread torch-function state (the mode stack `stack_` and the `disabled_`
// flag). Because it is `thread_local`, every function below reads and writes
// only the calling thread's copy; no cross-thread synchronization happens here.
static thread_local PythonTorchFunctionTLS pythonTorchFunctionState;

// Pushes `mode` onto the calling thread's mode stack. The shared_ptr is taken
// by value and moved into the stack, so the caller's copy is the only extra
// reference created.
void PythonTorchFunctionTLS::push_onto_stack(std::shared_ptr<SafePyObject> mode) {
  pythonTorchFunctionState.stack_.push_back(std::move(mode));
}

// Removes and returns the most recently pushed mode of the calling thread.
// Fails via TORCH_CHECK ("trying to pop from empty mode stack") when the stack
// is empty.
const std::shared_ptr<SafePyObject> PythonTorchFunctionTLS::pop_stack() {
  auto& stack = pythonTorchFunctionState.stack_;
  TORCH_CHECK(!stack.empty(), "trying to pop from empty mode stack");
  // Move (not copy) the top element out before popping: ownership is
  // transferred without an extra atomic refcount increment/decrement.
  auto out = std::move(stack.back());
  stack.pop_back();
  return out;
}

// Returns a reference to the mode at position `idx` (0 = first pushed) of the
// calling thread's stack. Fails via TORCH_CHECK when `idx` is negative or
// >= stack_len().
// NOTE: the reference points into the thread-local stack; later push/pop
// calls may invalidate it, so callers should not hold it across those.
const std::shared_ptr<SafePyObject>& PythonTorchFunctionTLS::get_stack_at(int64_t idx) {
  const auto& stack = pythonTorchFunctionState.stack_;
  // Guard both bounds: a negative idx would otherwise pass the upper-bound
  // check and index out of range (undefined behaviour).
  TORCH_CHECK(idx >= 0, "Tried to get stack at negative idx");
  TORCH_CHECK(idx < static_cast<int64_t>(stack.size()), "Tried to get stack at idx that's too big");
  return stack[idx];
}

// Number of modes currently on the calling thread's stack.
int64_t PythonTorchFunctionTLS::stack_len() {
  return static_cast<int64_t>(pythonTorchFunctionState.stack_.size());
}

// Stores `disabled` in the calling thread's `disabled_` flag.
// NOTE(review): how the flag is consumed is not visible in this file.
void PythonTorchFunctionTLS::set_disabled(bool disabled) {
  pythonTorchFunctionState.disabled_ = disabled;
}

// Returns the calling thread's `disabled_` flag.
bool PythonTorchFunctionTLS::is_disabled() {
  return pythonTorchFunctionState.disabled_;
}

// Replaces the calling thread's entire state with a copy of `state`
// (copy-assignment, so the shared_ptrs on `state`'s stack are shared, not moved).
void PythonTorchFunctionTLS::set_state(const PythonTorchFunctionTLS& state) {
  pythonTorchFunctionState = state;
}

// Returns a read-only reference to the calling thread's state object; it
// reflects any later mutation made on the same thread.
const PythonTorchFunctionTLS& PythonTorchFunctionTLS::get_state() {
  return pythonTorchFunctionState;
}

// True iff the calling thread's mode stack is non-empty.
// NOTE: this only looks at the stack; it does not consult `disabled_`.
bool torch_function_mode_enabled() {
  return PythonTorchFunctionTLS::stack_len() > 0;
}

} // namespace impl
} // namespace at