Update PackNet to reset optimizer between tasks
tachyonicClock committed Sep 20, 2023
1 parent 8475017 commit 74afbb6
Showing 2 changed files with 67 additions and 25 deletions.
avalanche/models/packnet.py (40 changes: 30 additions & 10 deletions)
@@ -445,7 +445,7 @@ def __init__(self, wrappee: nn.Module) -> None:
         :param wrappee: The module to wrap
         """
         super().__init__()
-        self.wrappee = PackNetModel.wrap(wrappee)
+        self.wrappee: nn.Module = PackNetModel.wrap(wrappee)
 
     def _pn_apply(self, func: t.Callable[["PackNetModel"], None]):
         """Apply a function to all child PackNetModules
@@ -498,25 +498,38 @@ class PackNetPlugin(BaseSGDPlugin):
     def __init__(
         self,
         post_prune_epochs: int,
-        prune_proportion: float = 0.5,
+        prune_proportion: t.Union[float, t.Callable[[int], float], t.List[float]] = 0.5,
     ):
         """The PackNetPlugin calls PackNet's pruning and freezing procedures at
         the appropriate times.
 
         :param post_prune_epochs: The number of epochs to finetune the model
             after pruning the parameters. Must be less than the number of
             training epochs.
-        :param prune_proportion: The proportion of parameters to prune each
-            durring each task. Must be between 0 and 1.
+        :param prune_proportion: The proportion of parameters to prune
+            during each task. Can be a float, a list of floats, or a function
+            that takes the task id and returns a float. Each value must be
+            between 0 and 1.
         """
         super().__init__()
         self.post_prune_epochs = post_prune_epochs
         self.total_epochs: Union[int, None] = None
-        self.prune_proportion = prune_proportion
-        assert 0 <= self.prune_proportion <= 1, (
-            f"`prune_proportion` must be between 0 and 1, got "
-            f"{self.prune_proportion}"
-        )
+
+        self.prune_proportion: t.Callable[[int], float]
+        if isinstance(prune_proportion, float):
+            assert 0 <= prune_proportion <= 1, (
+                f"`prune_proportion` must be between 0 and 1, got "
+                f"{prune_proportion}"
+            )
+            self.prune_proportion = lambda _: prune_proportion
+        elif isinstance(prune_proportion, list):
+            assert all(0 <= x <= 1 for x in prune_proportion), (
+                "all values in `prune_proportion` must be between 0 and 1,"
+                f" got {prune_proportion}"
+            )
+            self.prune_proportion = lambda i: prune_proportion[i]
+        else:
+            self.prune_proportion = prune_proportion
 
     def before_training(self, strategy: "BaseSGDTemplate", *args, **kwargs):
         assert isinstance(
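With this change, `prune_proportion` accepts three forms: a constant, a per-task list, or a schedule function of the task id. A minimal sketch of each, assuming `PackNetPlugin` is importable from `avalanche.models.packnet` as defined in this file; the numeric values are illustrative only:

    from avalanche.models.packnet import PackNetPlugin

    # A single float: prune the same proportion of free parameters on every task.
    plugin = PackNetPlugin(post_prune_epochs=1, prune_proportion=0.5)

    # A list: prune_proportion[task_id] is used for each task in turn.
    plugin = PackNetPlugin(post_prune_epochs=1, prune_proportion=[0.8, 0.6, 0.4])

    # A callable: receives the task id and returns the proportion for that task.
    plugin = PackNetPlugin(post_prune_epochs=1, prune_proportion=lambda t: 0.8 * 0.75**t)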
@@ -539,6 +552,13 @@ def before_training(self, strategy: "BaseSGDTemplate", *args, **kwargs):
             f"Strategy has only {self.total_epochs} training epochs."
         )
 
+    def before_training_exp(self, strategy: "BaseSGDTemplate", *args, **kwargs):
+        # Reset the optimizer to prevent momentum from affecting the pruned
+        # parameters
+        strategy.optimizer = strategy.optimizer.__class__(
+            strategy.model.parameters(), **strategy.optimizer.defaults
+        )
+
     def before_training_epoch(self, strategy: "BaseSGDTemplate", *args, **kwargs):
         """When the initial training phase is over, prune the model and
         transition to the post-pruning phase.
@@ -547,7 +567,7 @@ def before_training_epoch(self, strategy: "BaseSGDTemplate", *args, **kwargs):
         model = self._get_model(strategy)
 
         if epoch == (self.total_epochs - self.post_prune_epochs):
-            model.prune(self.prune_proportion)
+            model.prune(self.prune_proportion(strategy.clock.train_exp_counter))
 
     def after_training_exp(self, strategy: "Template", *args, **kwargs):
         """After each experience, commit the model so that the next experience
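The new `before_training_exp` hook exists because stateful optimizers (momentum in SGD, running moments in Adam) keep moving a weight even after its gradient has been masked to zero, which would let training bleed into parameters PackNet has pruned or frozen. A self-contained sketch of the effect; illustrative, not part of the commit:

    import torch
    from torch.optim import SGD

    w = torch.nn.Parameter(torch.ones(3))
    opt = SGD([w], lr=0.1, momentum=0.9)

    # One ordinary step builds up a momentum buffer.
    w.sum().backward()
    opt.step()

    # Simulate a pruned weight: its gradient is now exactly zero.
    w.grad = torch.zeros(3)
    opt.step()
    # w still moved, driven purely by stale momentum. Re-creating the
    # optimizer, as the hook above does, discards that state.

Note that rebuilding via `strategy.optimizer.__class__(..., **strategy.optimizer.defaults)` restores only the constructor defaults, so any per-parameter-group overrides would be dropped; for a single-group optimizer, as in the tests below, this is harmless.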
tests/models/test_packnet.py (52 changes: 37 additions & 15 deletions)
@@ -1,7 +1,7 @@
 import unittest
 from avalanche.models.packnet import PackNetModel, packnet_simple_mlp
 from avalanche.training.supervised.strategy_wrappers import PackNet
-from torch.optim import SGD
+from torch.optim import SGD, Adam
 import torch
 import os

@@ -10,15 +10,7 @@
 
 
 class TestPackNet(unittest.TestCase):
-    _EXPECTATIONS = {
-        "Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000": 0.75,
-        "Top1_Acc_Exp/eval_phase/test_stream/Task001/Exp001": 0.45,
-        "Top1_Acc_Exp/eval_phase/test_stream/Task002/Exp002": 0.33,
-        "Top1_Acc_Exp/eval_phase/test_stream/Task003/Exp003": 0.0886,
-        "Top1_Acc_Exp/eval_phase/test_stream/Task004/Exp004": 0.0,
-    }
-
-    def test_PackNetPlugin(self):
+    def _test_PackNetPlugin(self, expectations, optimizer_constructor, lr):
         torch.manual_seed(0)
         if "USE_GPU" in os.environ:
             use_gpu = os.environ["USE_GPU"].lower() in ["true"]
@@ -33,7 +25,7 @@ def test_PackNetPlugin(self):
             hidden_size=20,
         )
         model = construct_model()
-        optimizer = SGD(model.parameters(), lr=0.1)
+        optimizer = optimizer_constructor(model.parameters(), lr=lr)
         strategy = PackNet(
             model,
             prune_proportion=0.5,
@@ -51,11 +43,15 @@ def test_PackNetPlugin(self):
         # Train
         for i, experience in enumerate(scenario.train_stream):
             strategy.train(experience)
+            # Assert that the model achieves the expected accuracy for each task
+            # the model has trained on so far. PackNet should not degrade
+            # performance on previous tasks.
+            self.assert_eval(expectations, strategy.eval(scenario.test_stream), i)
             # Store the model output for each task
             task_ouputs.append(model.forward(x_test, t_test * i))
 
         # Check that the model achieves the expected accuracy
-        self.assert_eval(strategy.eval(scenario.test_stream))
+        self.assert_eval(expectations, strategy.eval(scenario.test_stream), 5)
 
         # Verify the model can be saved and loaded from a state dict
         new_model = construct_model()
@@ -67,13 +63,39 @@ def test_PackNetPlugin(self):
         strategy.model = new_model
 
         # Check that the loaded model achieves the expected accuracy
-        self.assert_eval(strategy.eval(scenario.test_stream))
+        self.assert_eval(expectations, strategy.eval(scenario.test_stream), 5)
 
         # Ensure that given the same inputs, the model produces the same outputs
         for i, task_out in enumerate(task_ouputs):
             out = model.forward(x_test, t_test * i)
             self.assertTrue(torch.allclose(out, task_out))

+    def test_PackNet_adam(self):
+        self._test_PackNetPlugin(
+            [
+                ("Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000", 0.500),
+                ("Top1_Acc_Exp/eval_phase/test_stream/Task001/Exp001", 0.410),
+                ("Top1_Acc_Exp/eval_phase/test_stream/Task002/Exp002", 0.134),
+                ("Top1_Acc_Exp/eval_phase/test_stream/Task003/Exp003", 0.240),
+                ("Top1_Acc_Exp/eval_phase/test_stream/Task004/Exp004", 0.163),
+            ],
+            optimizer_constructor=Adam,
+            lr=0.2,
+        )
+
+    def test_PackNet_sgd(self):
+        self._test_PackNetPlugin(
+            [
+                ("Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000", 0.75),
+                ("Top1_Acc_Exp/eval_phase/test_stream/Task001/Exp001", 0.49),
+                ("Top1_Acc_Exp/eval_phase/test_stream/Task002/Exp002", 0.099),
+                ("Top1_Acc_Exp/eval_phase/test_stream/Task003/Exp003", 0.14),
+                ("Top1_Acc_Exp/eval_phase/test_stream/Task004/Exp004", 0.0),
+            ],
+            optimizer_constructor=SGD,
+            lr=0.1,
+        )
+
     def test_unsupported_exception(self):
         """Expect an exception when trying to wrap an unsupported module"""
 
@@ -88,8 +110,8 @@ def __init__(self):
         with self.assertRaises(ValueError):
             PackNetModel(torch.nn.BatchNorm2d(10))
 
-    def assert_eval(self, last_eval):
-        for metric, value in self._EXPECTATIONS.items():
+    def assert_eval(self, expectations, last_eval, task_id):
+        for metric, value in expectations[:task_id]:
             self.assertAlmostEqual(last_eval[metric], value, places=2)
