forked from dingo-actual/dropgrad
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_dropgrad_scheduler.py
57 lines (43 loc) · 1.62 KB
/
test_dropgrad_scheduler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import math

from dropgrad import (
    CosineAnnealingDropRateScheduler,
    LinearDropRateScheduler,
    StepDropRateScheduler,
)
def test_linear_drop_rate_scheduler():
    """LinearDropRateScheduler decays from the initial to the final rate.

    Checks that the rate starts at ``initial_drop_rate``, stays inside
    ``[final_drop_rate, initial_drop_rate]``, never increases while stepping,
    and equals ``final_drop_rate`` after ``num_steps`` calls to ``step()``.
    """
    initial_drop_rate = 0.5
    final_drop_rate = 0.1
    num_steps = 100
    scheduler = LinearDropRateScheduler(initial_drop_rate, final_drop_rate, num_steps)

    # Endpoint values come out of floating-point interpolation, so compare
    # with a tolerance instead of exact equality (e.g. 0.5 - 0.4 != 0.1).
    assert math.isclose(scheduler.get_drop_rate(), initial_drop_rate)

    previous_drop_rate = initial_drop_rate
    for _ in range(num_steps):
        drop_rate = scheduler.get_drop_rate()
        assert initial_drop_rate >= drop_rate >= final_drop_rate
        # A linear decay toward a smaller target must be non-increasing;
        # the bounds check alone would accept an oscillating schedule.
        assert drop_rate <= previous_drop_rate
        previous_drop_rate = drop_rate
        scheduler.step()

    assert math.isclose(scheduler.get_drop_rate(), final_drop_rate)
def test_cosine_annealing_drop_rate_scheduler():
    """CosineAnnealingDropRateScheduler anneals from the initial to the final rate.

    Checks that the rate starts at ``initial_drop_rate``, stays inside
    ``[final_drop_rate, initial_drop_rate]``, never increases while stepping,
    and equals ``final_drop_rate`` after ``num_steps`` calls to ``step()``.
    """
    initial_drop_rate = 0.5
    final_drop_rate = 0.1
    num_steps = 100
    scheduler = CosineAnnealingDropRateScheduler(
        initial_drop_rate, final_drop_rate, num_steps
    )

    # Cosine-based values are floating-point results; compare the endpoints
    # with a tolerance rather than exact equality.
    assert math.isclose(scheduler.get_drop_rate(), initial_drop_rate)

    previous_drop_rate = initial_drop_rate
    for _ in range(num_steps):
        drop_rate = scheduler.get_drop_rate()
        assert final_drop_rate <= drop_rate <= initial_drop_rate
        # Annealing over a single half-period toward a smaller target is
        # non-increasing; the bounds check alone would not catch oscillation.
        assert drop_rate <= previous_drop_rate
        previous_drop_rate = drop_rate
        scheduler.step()

    assert math.isclose(scheduler.get_drop_rate(), final_drop_rate)
def test_step_drop_rate_scheduler():
    """StepDropRateScheduler switches rates at the scheduled milestones.

    ``drop_rate_schedule`` maps a step count to the rate that applies from that
    step onward; before the first milestone the initial rate applies.
    """
    initial_drop_rate = 0.5
    drop_rate_schedule = {50: 0.3, 80: 0.1}
    num_steps = 100
    scheduler = StepDropRateScheduler(initial_drop_rate, drop_rate_schedule)

    def expected_drop_rate(step):
        # The latest milestone at or before ``step`` wins; if no milestone has
        # been reached yet, the initial rate is still in effect. Deriving this
        # from the schedule keeps the assertions in sync if the schedule changes.
        rate = initial_drop_rate
        for milestone, milestone_rate in sorted(drop_rate_schedule.items()):
            if step >= milestone:
                rate = milestone_rate
        return rate

    assert scheduler.get_drop_rate() == initial_drop_rate
    for step in range(num_steps):
        assert scheduler.get_drop_rate() == expected_drop_rate(step)
        scheduler.step()

    # Like the other scheduler tests, also check the state after the final
    # step(): the last milestone's rate must still be in effect.
    assert scheduler.get_drop_rate() == expected_drop_rate(num_steps)