import unittest

+from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader

from avalanche.core import Agent
-from avalanche.models import SimpleMLP, as_multitask
+from avalanche.models import SimpleMLP, as_multitask, IncrementalClassifier, MTSimpleMLP
+from avalanche.models.dynamic_modules import avalanche_model_adaptation
from avalanche.models.dynamic_optimizers import DynamicOptimizer
from avalanche.training import MaskedCrossEntropy
from tests.unit_tests_utils import get_fast_benchmark
@@ -17,7 +19,7 @@ def test_dynamic_optimizer(self):
        agent.loss = MaskedCrossEntropy()
        agent.model = as_multitask(SimpleMLP(input_size=6), "classifier")
        opt = SGD(agent.model.parameters(), lr=0.001)
-        agent.opt = DynamicOptimizer(opt)
+        agent.opt = DynamicOptimizer(opt, agent.model, verbose=False)

        for exp in bm.train_stream:
            agent.model.train()
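The constructor change above is the substance of this patch: the wrapper now receives the model it tracks. A minimal sketch of the resulting call pattern, assuming only the names visible in this diff (DynamicOptimizer, pre_adapt) and the fast benchmark used by these tests:

from torch.optim import SGD
from avalanche.core import Agent
from avalanche.models import SimpleMLP, as_multitask
from avalanche.models.dynamic_optimizers import DynamicOptimizer
from tests.unit_tests_utils import get_fast_benchmark

agent = Agent()
agent.model = as_multitask(SimpleMLP(input_size=6), "classifier")
# The wrapped model is passed at construction time so the wrapper can
# re-register parameters that dynamic modules create later on.
agent.opt = DynamicOptimizer(
    SGD(agent.model.parameters(), lr=0.001), agent.model, verbose=False
)
for exp in get_fast_benchmark(use_task_labels=True).train_stream:
    agent.opt.pre_adapt(agent, exp)  # sync param groups with the adapted model
    # ... forward pass, loss.backward(), agent.opt.step() ...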
@@ -32,3 +34,97 @@ def test_dynamic_optimizer(self):
                l.backward()
                agent.opt.step()
            agent.post_adapt(exp)
+
+    def init_scenario(self, multi_task=False):
+        if multi_task:
+            model = MTSimpleMLP(input_size=6, hidden_size=10)
+        else:
+            model = SimpleMLP(input_size=6, hidden_size=10)
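+            # Assumed semantics: IncrementalClassifier(10, 1) takes 10 input
+            # features and starts with a single output unit that grows as new
+            # classes are observed.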
+            model.classifier = IncrementalClassifier(10, 1)
+        criterion = CrossEntropyLoss()
+        benchmark = get_fast_benchmark(use_task_labels=multi_task)
+        return model, criterion, benchmark
+
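+    # Both helpers use hash() for the membership check: Tensor.__hash__ is
+    # identity-based, whereas `param == curr_p` would compare elementwise.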
+    def _is_param_in_optimizer(self, param, optimizer):
+        for group in optimizer.param_groups:
+            for curr_p in group["params"]:
+                if hash(curr_p) == hash(param):
+                    return True
+        return False
+
+    def _is_param_in_optimizer_group(self, param, optimizer):
+        for group_idx, group in enumerate(optimizer.param_groups):
+            for curr_p in group["params"]:
+                if hash(curr_p) == hash(param):
+                    return group_idx
+        return None
+
+    def test_optimizer_groups_clf_til(self):
+        """
+        Tests the automatic assignment of new
+        MultiHead parameters to the optimizer.
+        """
+        model, criterion, benchmark = self.init_scenario(multi_task=True)
+
+        g1 = []
+        g2 = []
+        for n, p in model.named_parameters():
+            if "classifier" in n:
+                g1.append(p)
+            else:
+                g2.append(p)
+
+        agent = Agent()
+        agent.model = model
+        optimizer = SGD([{"params": g1, "lr": 0.1}, {"params": g2, "lr": 0.05}])
+        agent.optimizer = DynamicOptimizer(optimizer, model=model, verbose=False)
+
+        for experience in benchmark.train_stream:
+            avalanche_model_adaptation(model, experience)
+            agent.optimizer.pre_adapt(agent, experience)
+
+            for n, p in model.named_parameters():
+                assert self._is_param_in_optimizer(p, agent.optimizer.optim)
+                if "classifier" in n:
+                    self.assertEqual(
+                        self._is_param_in_optimizer_group(p, agent.optimizer.optim), 0
+                    )
+                else:
+                    self.assertEqual(
+                        self._is_param_in_optimizer_group(p, agent.optimizer.optim), 1
+                    )
+
+    def test_optimizer_groups_clf_cil(self):
+        """
+        Tests the automatic assignment of new
+        IncrementalClassifier parameters to the optimizer.
+        """
+        model, criterion, benchmark = self.init_scenario(multi_task=False)
+
+        g1 = []
+        g2 = []
+        for n, p in model.named_parameters():
+            if "classifier" in n:
+                g1.append(p)
+            else:
+                g2.append(p)
+
+        agent = Agent()
+        agent.model = model
+        optimizer = SGD([{"params": g1, "lr": 0.1}, {"params": g2, "lr": 0.05}])
+        agent.optimizer = DynamicOptimizer(optimizer, model)
+
+        for experience in benchmark.train_stream:
+            avalanche_model_adaptation(model, experience)
+            agent.optimizer.pre_adapt(agent, experience)
+
+            for n, p in model.named_parameters():
+                assert self._is_param_in_optimizer(p, agent.optimizer.optim)
+                if "classifier" in n:
+                    self.assertEqual(
+                        self._is_param_in_optimizer_group(p, agent.optimizer.optim), 0
+                    )
+                else:
+                    self.assertEqual(
+                        self._is_param_in_optimizer_group(p, agent.optimizer.optim), 1
+                    )
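For reference, the invariant both tests assert can be shown without Avalanche at all: a parameter appended to an existing param group inherits that group's settings. A self-contained, hypothetical illustration (the tensors and sizes are made up):

import torch
from torch.optim import SGD

w_clf = torch.nn.Parameter(torch.zeros(2, 6))    # stands in for classifier weights
w_feat = torch.nn.Parameter(torch.zeros(10, 6))  # stands in for feature weights
opt = SGD([{"params": [w_clf], "lr": 0.1}, {"params": [w_feat], "lr": 0.05}])

# A head grown by IncrementalClassifier/MultiHead must land in group 0 so it
# keeps the classifier learning rate; this is what DynamicOptimizer automates.
w_new = torch.nn.Parameter(torch.zeros(3, 6))
opt.param_groups[0]["params"].append(w_new)
assert any(p is w_new for p in opt.param_groups[0]["params"])
assert opt.param_groups[0]["lr"] == 0.1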