Commit ec0e81e

Update typing in other scheduler classes. Fix spacing in cosine typing. Add some basic scheduler unit tests
1 parent 8038db6

File tree

9 files changed: +619 -87 lines changed

tests/test_scheduler.py

Lines changed: 534 additions & 0 deletions
Large diffs are not rendered by default.
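
Since the 534-line test file is collapsed above, here is only a rough sketch of the kind of basic scheduler unit test the commit message describes, written against the public CosineLRScheduler API visible in the diff below. The test name and values are hypothetical, not the actual contents of tests/test_scheduler.py.

import math

import torch

from timm.scheduler.cosine_lr import CosineLRScheduler


def test_cosine_scheduler_warmup():
    # One dummy parameter gives the optimizer a single param group to schedule.
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=0.1)
    scheduler = CosineLRScheduler(
        optimizer,
        t_initial=10,
        warmup_t=2,
        warmup_lr_init=0.01,
        warmup_prefix=True,
    )
    # Epoch 0 should apply the warmup initial LR.
    scheduler.step(0)
    assert math.isclose(optimizer.param_groups[0]['lr'], 0.01)
    # With warmup_prefix, the cosine cycle starts at the base LR once warmup ends.
    scheduler.step(2)
    assert math.isclose(optimizer.param_groups[0]['lr'], 0.1)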

timm/scheduler/cosine_lr.py

Lines changed: 10 additions & 10 deletions

@@ -35,16 +35,16 @@ def __init__(
             cycle_mul: float = 1.,
             cycle_decay: float = 1.,
             cycle_limit: int = 1,
-            warmup_t: int=0,
-            warmup_lr_init: float=0.,
-            warmup_prefix: bool=False,
-            t_in_epochs: bool=True,
-            noise_range_t: Union[List[int], Tuple[int, int], int, None]=None,
-            noise_pct: float=0.67,
-            noise_std: float=1.0,
-            noise_seed: int=42,
-            k_decay: float=1.0,
-            initialize: bool=True,
+            warmup_t: int = 0,
+            warmup_lr_init: float = 0.,
+            warmup_prefix: bool = False,
+            t_in_epochs: bool = True,
+            noise_range_t: Union[List[int], Tuple[int, int], int, None] = None,
+            noise_pct: float = 0.67,
+            noise_std: float = 1.0,
+            noise_seed: int = 42,
+            k_decay: float = 1.0,
+            initialize: bool = True,
     ) -> None:
         super().__init__(
             optimizer,

timm/scheduler/multistep_lr.py

Lines changed: 10 additions & 10 deletions

@@ -5,7 +5,7 @@
 import torch
 import bisect
 from timm.scheduler.scheduler import Scheduler
-from typing import List
+from typing import List, Tuple, Union

 class MultiStepLRScheduler(Scheduler):
     """

@@ -16,15 +16,15 @@ def __init__(
             optimizer: torch.optim.Optimizer,
             decay_t: List[int],
             decay_rate: float = 1.,
-            warmup_t=0,
-            warmup_lr_init=0,
-            warmup_prefix=True,
-            t_in_epochs=True,
-            noise_range_t=None,
-            noise_pct=0.67,
-            noise_std=1.0,
-            noise_seed=42,
-            initialize=True,
+            warmup_t: int = 0,
+            warmup_lr_init: float = 0.,
+            warmup_prefix: bool = True,
+            t_in_epochs: bool = True,
+            noise_range_t: Union[List[int], Tuple[int, int], int, None] = None,
+            noise_pct: float = 0.67,
+            noise_std: float = 1.0,
+            noise_seed: int = 42,
+            initialize: bool = True,
     ) -> None:
         super().__init__(
             optimizer,

timm/scheduler/plateau_lr.py

Lines changed: 23 additions & 25 deletions

@@ -5,7 +5,7 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import torch
-from typing import List, Optional
+from typing import Any, Dict, List, Optional, Tuple, Union

 from .scheduler import Scheduler

@@ -15,23 +15,22 @@ class PlateauLRScheduler(Scheduler):

     def __init__(
             self,
-            optimizer,
-            decay_rate=0.1,
-            patience_t=10,
-            verbose=True,
-            threshold=1e-4,
-            cooldown_t=0,
-            warmup_t=0,
-            warmup_lr_init=0,
-            lr_min=0,
-            mode='max',
-            noise_range_t=None,
-            noise_type='normal',
-            noise_pct=0.67,
-            noise_std=1.0,
-            noise_seed=None,
-            initialize=True,
-    ):
+            optimizer: torch.optim.Optimizer,
+            decay_rate: float = 0.1,
+            patience_t: int = 10,
+            threshold: float = 1e-4,
+            cooldown_t: int = 0,
+            warmup_t: int = 0,
+            warmup_lr_init: float = 0.,
+            lr_min: float = 0.,
+            mode: str = 'max',
+            noise_range_t: Union[List[int], Tuple[int, int], int, None] = None,
+            noise_type: str = 'normal',
+            noise_pct: float = 0.67,
+            noise_std: float = 1.0,
+            noise_seed: Optional[int] = None,
+            initialize: bool = True,
+    ) -> None:
         super().__init__(
             optimizer,
             'lr',

@@ -47,11 +46,10 @@ def __init__(
             self.optimizer,
             patience=patience_t,
             factor=decay_rate,
-            verbose=verbose,
             threshold=threshold,
             cooldown=cooldown_t,
             mode=mode,
-            min_lr=lr_min
+            min_lr=lr_min,
         )

         self.warmup_t = warmup_t

@@ -63,19 +61,19 @@ def __init__(
             self.warmup_steps = [1 for _ in self.base_values]
         self.restore_lr = None

-    def state_dict(self):
+    def state_dict(self) -> Dict[str, Any]:
         return {
             'best': self.lr_scheduler.best,
             'last_epoch': self.lr_scheduler.last_epoch,
         }

-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         self.lr_scheduler.best = state_dict['best']
         if 'last_epoch' in state_dict:
             self.lr_scheduler.last_epoch = state_dict['last_epoch']

     # override the base class step fn completely
-    def step(self, epoch, metric=None):
+    def step(self, epoch: int, metric: Optional[float] = None) -> None:
         if epoch <= self.warmup_t:
             lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps]
             super().update_groups(lrs)

@@ -88,15 +86,15 @@ def step(self, epoch, metric=None):

         # step the base scheduler if metric given
         if metric is not None:
-            self.lr_scheduler.step(metric, epoch)
+            self.lr_scheduler.step(metric)

         if self._is_apply_noise(epoch):
             self._apply_noise(epoch)

     def step_update(self, num_updates: int, metric: Optional[float] = None):
         return None

-    def _apply_noise(self, epoch):
+    def _apply_noise(self, epoch: int) -> None:
         noise = self._calculate_noise(epoch)

         # apply the noise on top of previous LR, cache the old value so we can restore for normal
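
Beyond typing, this file carries two behavioral tweaks: the verbose kwarg (deprecated in recent PyTorch) is no longer passed to torch.optim.lr_scheduler.ReduceLROnPlateau, and the inner step() call no longer forwards the deprecated epoch argument. A hedged usage sketch of driving the updated scheduler, with an arbitrary model and metric values not taken from this commit:

import torch

from timm.scheduler.plateau_lr import PlateauLRScheduler

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# decay_rate/patience_t map to ReduceLROnPlateau's factor/patience.
scheduler = PlateauLRScheduler(optimizer, decay_rate=0.5, patience_t=2, mode='min')

for epoch in range(1, 10):
    val_loss = 1.0  # a plateaued metric; the LR halves once patience is exhausted
    # The metric is forwarded internally as ReduceLROnPlateau.step(metric).
    scheduler.step(epoch, metric=val_loss)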

timm/scheduler/poly_lr.py

Lines changed: 11 additions & 11 deletions

@@ -6,7 +6,7 @@
 """
 import math
 import logging
-from typing import List
+from typing import List, Tuple, Union

 import torch

@@ -31,16 +31,16 @@ def __init__(
             cycle_mul: float = 1.,
             cycle_decay: float = 1.,
             cycle_limit: int = 1,
-            warmup_t=0,
-            warmup_lr_init=0,
-            warmup_prefix=False,
-            t_in_epochs=True,
-            noise_range_t=None,
-            noise_pct=0.67,
-            noise_std=1.0,
-            noise_seed=42,
-            k_decay=1.0,
-            initialize=True,
+            warmup_t: int = 0,
+            warmup_lr_init: float = 0.,
+            warmup_prefix: bool = False,
+            t_in_epochs: bool = True,
+            noise_range_t: Union[List[int], Tuple[int, int], int, None] = None,
+            noise_pct: float = 0.67,
+            noise_std: float = 1.0,
+            noise_seed: int = 42,
+            k_decay: float = 1.0,
+            initialize: bool = True,
     ) -> None:
         super().__init__(
             optimizer,

timm/scheduler/scheduler.py

Lines changed: 10 additions & 10 deletions

@@ -1,6 +1,6 @@
 import abc
 from abc import ABC
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch

@@ -29,11 +29,11 @@ def __init__(
             optimizer: torch.optim.Optimizer,
             param_group_field: str,
             t_in_epochs: bool = True,
-            noise_range_t=None,
-            noise_type='normal',
-            noise_pct=0.67,
-            noise_std=1.0,
-            noise_seed=None,
+            noise_range_t: Union[List[int], Tuple[int, int], int, None] = None,
+            noise_type: str = 'normal',
+            noise_pct: float = 0.67,
+            noise_std: float = 1.0,
+            noise_seed: Optional[int] = None,
             initialize: bool = True,
     ) -> None:
         self.optimizer = optimizer

@@ -81,14 +81,14 @@ def step(self, epoch: int, metric: Optional[float] = None) -> None:
             values = self._add_noise(values, epoch)
             self.update_groups(values)

-    def step_update(self, num_updates: int, metric: Optional[float] = None):
+    def step_update(self, num_updates: int, metric: Optional[float] = None) -> None:
         self.metric = metric
         values = self._get_values(num_updates, on_epoch=False)
         if values is not None:
             values = self._add_noise(values, num_updates)
             self.update_groups(values)

-    def update_groups(self, values):
+    def update_groups(self, values: Union[float, List[float]]) -> None:
         if not isinstance(values, (list, tuple)):
             values = [values] * len(self.optimizer.param_groups)
         for param_group, value in zip(self.optimizer.param_groups, values):

@@ -97,13 +97,13 @@ def update_groups(self, values):
             else:
                 param_group[self.param_group_field] = value

-    def _add_noise(self, lrs, t):
+    def _add_noise(self, lrs: List[float], t: int) -> List[float]:
         if self._is_apply_noise(t):
             noise = self._calculate_noise(t)
             lrs = [v + v * noise for v in lrs]
         return lrs

-    def _is_apply_noise(self, t) -> bool:
+    def _is_apply_noise(self, t: int) -> bool:
         """Return True if scheduler in noise range."""
         apply_noise = False
         if self.noise_range_t is not None:
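
The Union[List[int], Tuple[int, int], int, None] annotation added throughout encodes how _is_apply_noise (its body is truncated in the hunk above) interprets noise_range_t in timm: a single int applies noise from that step onward, while a two-element list or tuple bounds a half-open [start, end) range, and None disables noise. An illustrative sketch with arbitrary values:

import torch

from timm.scheduler.step_lr import StepLRScheduler

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)

# Apply LR noise only for epochs 20 <= t < 40; noise_range_t=50 would instead
# apply it from epoch 50 onward.
scheduler = StepLRScheduler(
    optimizer,
    decay_t=30,
    decay_rate=0.1,
    noise_range_t=(20, 40),
    noise_pct=0.1,
)

for epoch in range(100):
    scheduler.step(epoch)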

timm/scheduler/scheduler_factory.py

Lines changed: 1 addition & 1 deletion

@@ -69,7 +69,7 @@ def create_scheduler_v2(
         cooldown_epochs: int = 0,
         patience_epochs: int = 10,
         decay_rate: float = 0.1,
-        min_lr: float = 0,
+        min_lr: float = 0.,
         warmup_lr: float = 1e-5,
         warmup_epochs: int = 0,
         warmup_prefix: bool = False,

timm/scheduler/step_lr.py

Lines changed: 10 additions & 10 deletions

@@ -6,7 +6,7 @@
 """
 import math
 import torch
-from typing import List
+from typing import List, Tuple, Union


 from .scheduler import Scheduler

@@ -21,15 +21,15 @@ def __init__(
             optimizer: torch.optim.Optimizer,
             decay_t: float,
             decay_rate: float = 1.,
-            warmup_t=0,
-            warmup_lr_init=0,
-            warmup_prefix=True,
-            t_in_epochs=True,
-            noise_range_t=None,
-            noise_pct=0.67,
-            noise_std=1.0,
-            noise_seed=42,
-            initialize=True,
+            warmup_t: int = 0,
+            warmup_lr_init: float = 0.,
+            warmup_prefix: bool = True,
+            t_in_epochs: bool = True,
+            noise_range_t: Union[List[int], Tuple[int, int], int, None] = None,
+            noise_pct: float = 0.67,
+            noise_std: float = 1.0,
+            noise_seed: int = 42,
+            initialize: bool = True,
     ) -> None:
         super().__init__(
             optimizer,

timm/scheduler/tanh_lr.py

Lines changed: 10 additions & 10 deletions

@@ -8,7 +8,7 @@
 import math
 import numpy as np
 import torch
-from typing import List
+from typing import List, Tuple, Union

 from .scheduler import Scheduler

@@ -32,15 +32,15 @@ def __init__(
             cycle_mul: float = 1.,
             cycle_decay: float = 1.,
             cycle_limit: int = 1,
-            warmup_t=0,
-            warmup_lr_init=0,
-            warmup_prefix=False,
-            t_in_epochs=True,
-            noise_range_t=None,
-            noise_pct=0.67,
-            noise_std=1.0,
-            noise_seed=42,
-            initialize=True,
+            warmup_t: int = 0,
+            warmup_lr_init: float = 0.,
+            warmup_prefix: bool = False,
+            t_in_epochs: bool = True,
+            noise_range_t: Union[List[int], Tuple[int, int], int, None] = None,
+            noise_pct: float = 0.67,
+            noise_std: float = 1.0,
+            noise_seed: int = 42,
+            initialize: bool = True,
     ) -> None:
         super().__init__(
             optimizer,
