From 74afbb6e6f45aa7f4fe46de84252780d4a1ccb0d Mon Sep 17 00:00:00 2001
From: Anton Lee
Date: Wed, 20 Sep 2023 16:14:36 +1200
Subject: [PATCH] Update PackNet to reset optimizer between tasks

---
 avalanche/models/packnet.py  | 40 ++++++++++++++++++++-------
 tests/models/test_packnet.py | 52 +++++++++++++++++++++++++-----------
 2 files changed, 67 insertions(+), 25 deletions(-)

diff --git a/avalanche/models/packnet.py b/avalanche/models/packnet.py
index d05076951..70eb4d82e 100644
--- a/avalanche/models/packnet.py
+++ b/avalanche/models/packnet.py
@@ -445,7 +445,7 @@ def __init__(self, wrappee: nn.Module) -> None:
         :param wrappee: The module to wrap
         """
         super().__init__()
-        self.wrappee = PackNetModel.wrap(wrappee)
+        self.wrappee: nn.Module = PackNetModel.wrap(wrappee)
 
     def _pn_apply(self, func: t.Callable[["PackNetModel"], None]):
         """Apply a function to all child PackNetModules
@@ -498,7 +498,7 @@ class PackNetPlugin(BaseSGDPlugin):
     def __init__(
         self,
         post_prune_epochs: int,
-        prune_proportion: float = 0.5,
+        prune_proportion: t.Union[float, t.Callable[[int], float], t.List[float]] = 0.5,
     ):
         """The PackNetPlugin calls PackNet's pruning and freezing procedures at
         the appropriate times.
@@ -506,17 +506,30 @@ def __init__(
         :param post_prune_epochs: The number of epochs to finetune the model
             after pruning the parameters. Must be less than the number of
             training epochs.
-        :param prune_proportion: The proportion of parameters to prune each
-            durring each task. Must be between 0 and 1.
+        :param prune_proportion: The proportion of parameters to prune
+            during each task. Can be a float, a list of floats, or a function
+            that takes the task id and returns a float. Each value must be
+            between 0 and 1.
         """
         super().__init__()
         self.post_prune_epochs = post_prune_epochs
         self.total_epochs: Union[int, None] = None
-        self.prune_proportion = prune_proportion
-        assert 0 <= self.prune_proportion <= 1, (
-            f"`prune_proportion` must be between 0 and 1, got "
-            f"{self.prune_proportion}"
-        )
+
+        self.prune_proportion: t.Callable[[int], float]
+        if isinstance(prune_proportion, float):
+            assert 0 <= prune_proportion <= 1, (
+                f"`prune_proportion` must be between 0 and 1, got "
+                f"{prune_proportion}"
+            )
+            self.prune_proportion = lambda _: prune_proportion
+        elif isinstance(prune_proportion, list):
+            assert all(0 <= x <= 1 for x in prune_proportion), (
+                "all values in `prune_proportion` must be between 0 and 1,"
+                f" got {prune_proportion}"
+            )
+            self.prune_proportion = lambda i: prune_proportion[i]
+        else:
+            self.prune_proportion = prune_proportion
 
     def before_training(self, strategy: "BaseSGDTemplate", *args, **kwargs):
         assert isinstance(
@@ -539,6 +552,13 @@ def before_training(self, strategy: "BaseSGDTemplate", *args, **kwargs):
             f"Strategy has only {self.total_epochs} training epochs."
         )
 
+    def before_training_exp(self, strategy: "BaseSGDTemplate", *args, **kwargs):
+        # Reset the optimizer so that momentum and other optimizer state from
+        # the previous task cannot keep updating the pruned, frozen parameters
+        strategy.optimizer = strategy.optimizer.__class__(
+            strategy.model.parameters(), **strategy.optimizer.defaults
+        )
+
     def before_training_epoch(self, strategy: "BaseSGDTemplate", *args, **kwargs):
         """When the initial training phase is over, prune the model and
         transition to the post-pruning phase.
@@ -547,7 +567,7 @@ def before_training_epoch(self, strategy: "BaseSGDTemplate", *args, **kwargs):
         model = self._get_model(strategy)
 
         if epoch == (self.total_epochs - self.post_prune_epochs):
-            model.prune(self.prune_proportion)
+            model.prune(self.prune_proportion(strategy.clock.train_exp_counter))
 
     def after_training_exp(self, strategy: "Template", *args, **kwargs):
         """After each experience, commit the model so that the next experience

diff --git a/tests/models/test_packnet.py b/tests/models/test_packnet.py
index 357111be5..1b0caff40 100644
--- a/tests/models/test_packnet.py
+++ b/tests/models/test_packnet.py
@@ -1,7 +1,7 @@
 import unittest
 from avalanche.models.packnet import PackNetModel, packnet_simple_mlp
 from avalanche.training.supervised.strategy_wrappers import PackNet
-from torch.optim import SGD
+from torch.optim import SGD, Adam
 import torch
 import os
 
@@ -10,15 +10,7 @@
 
 
 class TestPackNet(unittest.TestCase):
-    _EXPECTATIONS = {
-        "Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000": 0.75,
-        "Top1_Acc_Exp/eval_phase/test_stream/Task001/Exp001": 0.45,
-        "Top1_Acc_Exp/eval_phase/test_stream/Task002/Exp002": 0.33,
-        "Top1_Acc_Exp/eval_phase/test_stream/Task003/Exp003": 0.0886,
-        "Top1_Acc_Exp/eval_phase/test_stream/Task004/Exp004": 0.0,
-    }
-
-    def test_PackNetPlugin(self):
+    def _test_PackNetPlugin(self, expectations, optimizer_constructor, lr):
         torch.manual_seed(0)
         if "USE_GPU" in os.environ:
             use_gpu = os.environ["USE_GPU"].lower() in ["true"]
@@ -33,7 +25,7 @@ def test_PackNetPlugin(self):
             hidden_size=20,
         )
         model = construct_model()
-        optimizer = SGD(model.parameters(), lr=0.1)
+        optimizer = optimizer_constructor(model.parameters(), lr=lr)
         strategy = PackNet(
             model,
             prune_proportion=0.5,
@@ -51,11 +43,15 @@ def test_PackNetPlugin(self):
         # Train
         for i, experience in enumerate(scenario.train_stream):
             strategy.train(experience)
+            # Assert that the model still achieves the expected accuracy on
+            # every previously trained task. PackNet should not degrade
+            # performance on previous tasks.
+            self.assert_eval(expectations, strategy.eval(scenario.test_stream), i)
             # Store the model output for each task
             task_ouputs.append(model.forward(x_test, t_test * i))
 
         # Check that the model achieves the expected accuracy
-        self.assert_eval(strategy.eval(scenario.test_stream))
+        self.assert_eval(expectations, strategy.eval(scenario.test_stream), 5)
 
         # Verify the model can be saved and loaded from a state dict
         new_model = construct_model()
@@ -67,13 +63,39 @@ def test_PackNetPlugin(self):
         strategy.model = new_model
 
         # Check that the loaded model achieves the expected accuracy
-        self.assert_eval(strategy.eval(scenario.test_stream))
+        self.assert_eval(expectations, strategy.eval(scenario.test_stream), 5)
 
         # Ensure that given the same inputs, the model produces the same outputs
         for i, task_out in enumerate(task_ouputs):
             out = model.forward(x_test, t_test * i)
             self.assertTrue(torch.allclose(out, task_out))
 
+    def test_PackNet_adam(self):
+        self._test_PackNetPlugin(
+            [
+                ("Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000", 0.500),
+                ("Top1_Acc_Exp/eval_phase/test_stream/Task001/Exp001", 0.410),
+                ("Top1_Acc_Exp/eval_phase/test_stream/Task002/Exp002", 0.134),
+                ("Top1_Acc_Exp/eval_phase/test_stream/Task003/Exp003", 0.240),
+                ("Top1_Acc_Exp/eval_phase/test_stream/Task004/Exp004", 0.163),
+            ],
+            optimizer_constructor=Adam,
+            lr=0.2,
+        )
+
+    def test_PackNet_sgd(self):
+        self._test_PackNetPlugin(
+            [
+                ("Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000", 0.75),
+                ("Top1_Acc_Exp/eval_phase/test_stream/Task001/Exp001", 0.49),
+                ("Top1_Acc_Exp/eval_phase/test_stream/Task002/Exp002", 0.099),
+                ("Top1_Acc_Exp/eval_phase/test_stream/Task003/Exp003", 0.14),
+                ("Top1_Acc_Exp/eval_phase/test_stream/Task004/Exp004", 0.0),
+            ],
+            optimizer_constructor=SGD,
+            lr=0.1,
+        )
+
     def test_unsupported_exception(self):
         """Expect an exception when trying to wrap an unsupported module"""
 
@@ -88,8 +110,8 @@ def __init__(self):
         with self.assertRaises(ValueError):
             PackNetModel(torch.nn.BatchNorm2d(10))
 
-    def assert_eval(self, last_eval):
-        for metric, value in self._EXPECTATIONS.items():
+    def assert_eval(self, expectations, last_eval, task_id):
+        for metric, value in expectations[:task_id]:
             self.assertAlmostEqual(last_eval[metric], value, places=2)
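
Usage note (illustrative, not part of the patch): after this change
`prune_proportion` accepts a float, a list with one proportion per task, or a
callable mapping a task id to a proportion in [0, 1]. A minimal sketch of the
scheduled form follows; it uses the `packnet_simple_mlp` and `PackNet` entry
points exercised by the tests above, but every keyword other than
`prune_proportion` and `post_prune_epochs` (the model sizes, `optimizer`,
`criterion`, `train_epochs`) is an assumed, typical Avalanche argument rather
than one confirmed by this patch:

    import torch
    from torch.optim import SGD

    from avalanche.models.packnet import packnet_simple_mlp
    from avalanche.training.supervised.strategy_wrappers import PackNet

    # Model factory used by the tests; the sizes here are assumed for
    # illustration (the tests above only show hidden_size=20).
    model = packnet_simple_mlp(
        num_classes=10,
        input_size=784,
        hidden_size=20,
    )

    strategy = PackNet(
        model,
        optimizer=SGD(model.parameters(), lr=0.1),
        criterion=torch.nn.CrossEntropyLoss(),
        # One proportion per task; a callable such as
        #   lambda task_id: 0.5 * 0.8**task_id
        # is equally accepted now.
        prune_proportion=[0.8, 0.6, 0.4, 0.2, 0.1],
        post_prune_epochs=1,  # must be less than train_epochs
        train_epochs=3,
    )

Because the plugin now rebuilds the optimizer at the start of each experience,
stateful optimizers such as Adam (exercised by test_PackNet_adam above) no
longer leak momentum updates into parameters frozen for earlier tasks.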