@@ -5,7 +5,7 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import torch
-from typing import List, Optional
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 from .scheduler import Scheduler
 
@@ -15,23 +15,22 @@ class PlateauLRScheduler(Scheduler):
 
     def __init__(
             self,
-            optimizer,
-            decay_rate=0.1,
-            patience_t=10,
-            verbose=True,
-            threshold=1e-4,
-            cooldown_t=0,
-            warmup_t=0,
-            warmup_lr_init=0,
-            lr_min=0,
-            mode='max',
-            noise_range_t=None,
-            noise_type='normal',
-            noise_pct=0.67,
-            noise_std=1.0,
-            noise_seed=None,
-            initialize=True,
-    ):
+            optimizer: torch.optim.Optimizer,
+            decay_rate: float = 0.1,
+            patience_t: int = 10,
+            threshold: float = 1e-4,
+            cooldown_t: int = 0,
+            warmup_t: int = 0,
+            warmup_lr_init: float = 0.,
+            lr_min: float = 0.,
+            mode: str = 'max',
+            noise_range_t: Union[List[int], Tuple[int, int], int, None] = None,
+            noise_type: str = 'normal',
+            noise_pct: float = 0.67,
+            noise_std: float = 1.0,
+            noise_seed: Optional[int] = None,
+            initialize: bool = True,
+    ) -> None:
         super().__init__(
             optimizer,
             'lr',
@@ -47,11 +46,10 @@ def __init__(
             self.optimizer,
             patience=patience_t,
             factor=decay_rate,
-            verbose=verbose,
             threshold=threshold,
             cooldown=cooldown_t,
             mode=mode,
-            min_lr=lr_min
+            min_lr=lr_min,
         )
 
         self.warmup_t = warmup_t
@@ -63,19 +61,19 @@ def __init__(
             self.warmup_steps = [1 for _ in self.base_values]
         self.restore_lr = None
 
-    def state_dict(self):
+    def state_dict(self) -> Dict[str, Any]:
         return {
             'best': self.lr_scheduler.best,
             'last_epoch': self.lr_scheduler.last_epoch,
         }
 
-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         self.lr_scheduler.best = state_dict['best']
         if 'last_epoch' in state_dict:
             self.lr_scheduler.last_epoch = state_dict['last_epoch']
 
     # override the base class step fn completely
-    def step(self, epoch, metric=None):
+    def step(self, epoch: int, metric: Optional[float] = None) -> None:
         if epoch <= self.warmup_t:
             lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps]
             super().update_groups(lrs)
@@ -88,15 +86,15 @@ def step(self, epoch, metric=None):
 
             # step the base scheduler if metric given
             if metric is not None:
-                self.lr_scheduler.step(metric, epoch)
+                self.lr_scheduler.step(metric)
 
             if self._is_apply_noise(epoch):
                 self._apply_noise(epoch)
 
     def step_update(self, num_updates: int, metric: Optional[float] = None):
         return None
 
-    def _apply_noise(self, epoch):
+    def _apply_noise(self, epoch: int) -> None:
         noise = self._calculate_noise(epoch)
 
         # apply the noise on top of previous LR, cache the old value so we can restore for normal
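For reference, a minimal usage sketch (not part of the diff) of how the scheduler is driven after this change: `verbose` is no longer accepted by the constructor, and the wrapped ReduceLROnPlateau is now stepped as `step(metric)` without the deprecated epoch argument. The toy model, optimizer, and constant validation metric below are placeholders, not from the source.

import torch
from timm.scheduler.plateau_lr import PlateauLRScheduler

model = torch.nn.Linear(10, 2)                      # toy model, for illustration only
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

scheduler = PlateauLRScheduler(
    optimizer,
    decay_rate=0.1,       # multiply LR by 0.1 when the metric plateaus
    patience_t=2,         # epochs without improvement before decaying
    mode='max',           # metric is higher-is-better (e.g. accuracy)
    warmup_t=3,           # linear warmup over the first 3 epochs
    warmup_lr_init=1e-5,
)

for epoch in range(10):
    val_metric = 0.5      # stand-in for a real validation metric
    # warmup is handled inside step() itself; after warmup_t epochs it
    # defers to the wrapped ReduceLROnPlateau via step(metric)
    scheduler.step(epoch + 1, val_metric)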