-
Notifications
You must be signed in to change notification settings - Fork 8
/
utils.py
218 lines (180 loc) · 6.85 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import logging
from argparse import Namespace
import torch
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
# Maps each registry name (leading "--" stripped) to its metadata dict;
# populated by setup_registry().
REGISTRIES = dict()
def setup_registry(registry_name: str, base_class=None, default=None, required=False):
    """Create a named plugin registry keyed by a ``--option`` style name.

    Args:
        registry_name: option name starting with ``"--"``. It is converted to an
            attribute name (``"--my-opt"`` -> ``"my_opt"``) that ``build_x``
            reads from ``args``.
        base_class: if given, every registered class must be a subclass of it.
        default: stored in ``REGISTRIES`` next to the registry; this function
            does not otherwise use it.
        required: if True, ``build_x`` raises ``ValueError`` when ``args`` has
            no value for the option.

    Returns:
        ``(build_x, register_x, REGISTRY)``.
        - ``build_x(args, *extra_args, **extra_kwargs)`` instantiates the chosen
          class. It uses ``cls.build_<registry_name>`` when the class defines
          it, otherwise it calls the class itself.
        - ``register_x(name)`` is a class decorator.
        - ``REGISTRY`` maps names to classes.
        Calling this again with an already-registered name returns the tuple
        created by the first call (previously it returned ``None``).

    Raises:
        ValueError: from ``register_x`` on a duplicate name, a duplicate class
            name, or a class that does not extend ``base_class``.
    """
    assert registry_name.startswith("--")
    registry_name = registry_name[2:].replace("-", "_")

    # maintain a registry of all registries; reuse an existing one so callers
    # that unpack the result keep working
    if registry_name in REGISTRIES:
        existing = REGISTRIES[registry_name]
        return existing["build_x"], existing["register_x"], existing["registry"]

    REGISTRY = {}
    REGISTRY_CLASS_NAMES = set()

    def build_x(args, *extra_args, **extra_kwargs):
        choice = getattr(args, registry_name, None)
        if choice is None:
            if required:
                raise ValueError("{} is required!".format(registry_name))
            return None
        if choice not in REGISTRY:
            raise KeyError(
                "Unknown {} '{}'; registered: {}".format(
                    registry_name, choice, ", ".join(map(str, REGISTRY))
                )
            )
        cls = REGISTRY[choice]
        builder = getattr(cls, "build_" + registry_name, cls)
        return builder(args, *extra_args, **extra_kwargs)

    def register_x(name):
        def register_x_cls(cls):
            if name in REGISTRY:
                raise ValueError(
                    "Cannot register duplicate {} ({})".format(registry_name, name)
                )
            if cls.__name__ in REGISTRY_CLASS_NAMES:
                raise ValueError(
                    "Cannot register {} with duplicate class name ({})".format(
                        registry_name, cls.__name__
                    )
                )
            if base_class is not None and not issubclass(cls, base_class):
                raise ValueError(
                    "{} must extend {}".format(cls.__name__, base_class.__name__)
                )
            REGISTRY[name] = cls
            # Record the class name so the duplicate-class-name check above works.
            REGISTRY_CLASS_NAMES.add(cls.__name__)
            return cls

        return register_x_cls

    REGISTRIES[registry_name] = {
        "registry": REGISTRY,
        "default": default,
        "build_x": build_x,
        "register_x": register_x,
    }
    return build_x, register_x, REGISTRY
def item(tensor):
    """Unwrap ``tensor`` into a plain Python value.

    - Objects that expose ``.item()`` (e.g. torch tensors) are converted with it.
    - Other indexable objects yield their first element.
    - Anything else is returned unchanged.
    """
    if hasattr(tensor, "item"):
        return tensor.item()
    return tensor[0] if hasattr(tensor, "__getitem__") else tensor
def get_rng_state():
    """Capture torch's global RNG state, plus the CUDA one when CUDA is available.

    Returns:
        dict with key ``"torch_rng_state"`` and, if CUDA is available,
        ``"cuda_rng_state"``.
    """
    snapshot = dict(torch_rng_state=torch.get_rng_state())
    if torch.cuda.is_available():
        snapshot.update(cuda_rng_state=torch.cuda.get_rng_state())
    return snapshot
def set_rng_state(state):
    """Restore a previously captured torch RNG state.

    Args:
        state: mapping with a ``"torch_rng_state"`` entry and, optionally, a
            ``"cuda_rng_state"`` entry.

    The CUDA state is restored only when CUDA is available *and* the snapshot
    contains it. A snapshot taken on a CPU-only host has no
    ``"cuda_rng_state"`` key, so it can now be restored on a CUDA host instead
    of raising ``KeyError``.
    """
    torch.set_rng_state(state["torch_rng_state"])
    if torch.cuda.is_available() and "cuda_rng_state" in state:
        torch.cuda.set_rng_state(state["cuda_rng_state"])
class set_torch_seed(object):
    """Seed torch's RNGs and restore the previous RNG state on context exit.

    Seeding happens in ``__init__`` (not ``__enter__``), so constructing the
    object already changes the global RNG state. When CUDA is available the
    cuDNN ``deterministic`` flag is forced on for the duration and restored by
    ``__exit__``, so the change does not leak out of the ``with`` block.
    """

    def __init__(self, seed):
        assert isinstance(seed, int)
        self.rng_state = get_rng_state()
        # Previous cuDNN flag value; None means we did not touch it.
        self._prev_cudnn_deterministic = None
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            self._prev_cudnn_deterministic = torch.backends.cudnn.deterministic
            torch.backends.cudnn.deterministic = True
            # When torch.distributed is not available it exposes no other API,
            # so is_initialized() must only be called after is_available().
            if torch.distributed.is_available() and torch.distributed.is_initialized():
                torch.cuda.manual_seed_all(seed)

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        set_rng_state(self.rng_state)
        prev = getattr(self, "_prev_cudnn_deterministic", None)
        if prev is not None:
            torch.backends.cudnn.deterministic = prev
def has_parameters(module):
    """Return True if ``module.parameters()`` yields at least one item."""
    sentinel = object()
    return next(module.parameters(), sentinel) is not sentinel
class CudaEnvironment(object):
    """Snapshot of the current CUDA device's name, compute capability and memory."""

    def __init__(self):
        device_index = torch.cuda.current_device()
        props = torch.cuda.get_device_properties(f"cuda:{device_index}")
        self.name = props.name
        self.major, self.minor = props.major, props.minor
        # total_memory is divided by 1024**3 to report GiB.
        self.total_memory_in_GB = props.total_memory / 1024 ** 3

    @staticmethod
    def pretty_print_cuda_env_list(cuda_env_list):
        """
        Log one line per CudaEnvironment in ``cuda_env_list``, framed by a banner
        """
        title = f"CUDA environments for all {len(cuda_env_list)} workers"
        stars = "*" * (40 - len(title) // 2)
        banner = f"{stars}{title}{stars}"
        logger.info(banner)
        for rank, env in enumerate(cuda_env_list):
            logger.info(
                f"rank {rank:3d} "
                f"capabilities = {env.major:2d}.{env.minor:<2d} ; "
                f"total memory = {env.total_memory_in_GB:.3f} GB ; "
                f"name = {env.name:40s}"
            )
        logger.info(banner)
def apply_to_sample(f, sample):
    """Apply ``f`` to every tensor nested inside ``sample``.

    Recurses through dict values, lists, tuples and sets, and rebuilds each
    container as a plain ``dict`` / ``list`` / ``tuple`` / ``set``. Non-tensor
    leaves are returned unchanged.

    Args:
        f: callable that takes a tensor and returns its replacement.
        sample: a tensor, a (nested) container of tensors, or any other object.

    Returns:
        The transformed structure. An empty top-level ``sample`` (anything
        with ``len() == 0``, including a tensor whose first dimension is 0)
        yields ``{}``. A 0-d (scalar) tensor is passed to ``f`` rather than
        crashing on ``len()``.
    """

    def _is_empty(x):
        if torch.is_tensor(x):
            # len() raises TypeError on 0-d tensors; a scalar is never "empty".
            return x.dim() > 0 and x.size(0) == 0
        return hasattr(x, "__len__") and len(x) == 0

    if _is_empty(sample):
        return {}

    def _apply(x):
        if torch.is_tensor(x):
            return f(x)
        if isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        if isinstance(x, list):
            return [_apply(elem) for elem in x]
        if isinstance(x, tuple):
            return tuple(_apply(elem) for elem in x)
        if isinstance(x, set):
            return {_apply(elem) for elem in x}
        return x

    return _apply(sample)
def move_to_cuda(sample, device=None):
    """Move every tensor nested inside ``sample`` to a CUDA device.

    Args:
        sample: a tensor or nested dict/list/tuple/set of tensors; traversal is
            delegated to ``apply_to_sample``.
        device: target device. ``None`` selects the current CUDA device. An
            explicit ``0`` is honoured; with ``device or ...`` it was falsy and
            was silently replaced by the current device.
    """
    if device is None:
        device = torch.cuda.current_device()

    def _move_to_cuda(tensor):
        # non_blocking is ignored if tensor is not pinned, so we can always set
        # to True (see github.com/PyTorchLightning/pytorch-lightning/issues/620)
        return tensor.to(device=device, non_blocking=True)

    return apply_to_sample(_move_to_cuda, sample)
def move_to_cpu(sample):
    """Copy every tensor nested inside ``sample`` to CPU memory.

    float16 and bfloat16 tensors are upcast to float32 first, because PyTorch
    has poor support for half-precision tensors on CPU.
    """

    def _to_cpu(t):
        if t.dtype in (torch.float16, torch.bfloat16):
            t = t.float()
        return t.cpu()

    return apply_to_sample(_to_cpu, sample)
@torch.no_grad()
def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor:
    """Scale gradients in place so their global L2 norm is at most ``max_norm``.

    Parameters with an ``expert`` attribute are left out of the norm
    computation but are still scaled by the same clipping coefficient.

    Args:
        params: a tensor or an iterable of tensors whose ``.grad`` is clipped.
            ``None`` entries and tensors without a gradient are skipped.
        max_norm: maximum allowed norm; no clipping happens when ``<= 0``.
        aggregate_norm_fn: optional callable applied to the locally computed
            norm before the clipping coefficient is derived. Presumably used to
            combine norms across workers; confirm with callers.

    Returns:
        The total (float32) gradient norm, after ``aggregate_norm_fn`` but
        before clipping. It is a zero tensor when no non-expert gradient
        exists.
    """

    def grad_exists(p):
        return p is not None and getattr(p, "grad", None) is not None

    if isinstance(params, torch.Tensor):
        params = [params]
    params = list(params)
    grads = [
        p.grad.detach() for p in params if grad_exists(p) and not hasattr(p, "expert")
    ]
    expert_grads = [
        p.grad.detach() for p in params if grad_exists(p) and hasattr(p, "expert")
    ]

    if len(grads) == 0:
        # torch.stack([]) would raise; report a zero norm instead, on the
        # device of the first tensor we were given (CPU if there is none).
        ref = next((p for p in params if torch.is_tensor(p)), None)
        device = ref.device if ref is not None else torch.device("cpu")
        total_norm = torch.zeros((), dtype=torch.float32, device=device)
    elif len(grads) == 1:
        total_norm = torch.norm(grads[0], p=2, dtype=torch.float32)
    else:
        # Combine per-tensor norms on the device the gradients live on, so
        # this also works on CPU-only hosts.
        device = grads[0].device
        total_norm = torch.norm(
            torch.stack(
                [torch.norm(g, p=2, dtype=torch.float32).to(device) for g in grads]
            )
        )

    if aggregate_norm_fn is not None:
        total_norm = aggregate_norm_fn(total_norm)

    if max_norm > 0:
        max_norm = float(max_norm)
        clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1)
        for g in grads + expert_grads:
            # Gradients may sit on a different device than the norm tensor.
            g.mul_(clip_coef.to(g.device))
    return total_norm