Skip to content

Commit 8803c85

Browse files
authored
Added NewClassesCrossEntropy criterion and automatic criterion plugin (#1514)
* Added NewClassesCrossEntropy criterion and automatic criterion plugin adding * change to maskedcrossentropy with 3 modes * added seen and all options, switch to torch functional kl div * added stable softmax
1 parent fca0ca0 commit 8803c85

File tree

4 files changed

+109
-7
lines changed

4 files changed

+109
-7
lines changed

avalanche/training/losses.py

+64-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import copy
22

3+
import numpy as np
34
import torch
5+
import torch.nn.functional as F
46
from torch import nn
5-
from avalanche.training.plugins import SupervisedPlugin
67
from torch.nn import BCELoss
7-
import numpy as np
8+
9+
from avalanche.training.plugins import SupervisedPlugin
10+
from avalanche.training.regularization import cross_entropy_with_oh_targets
811

912

1013
class ICaRLLossPlugin(SupervisedPlugin):
@@ -161,4 +164,62 @@ def forward(self, features, labels=None, mask=None):
161164
return loss
162165

163166

164-
__all__ = ["ICaRLLossPlugin", "SCRLoss"]
167+
class MaskedCrossEntropy(SupervisedPlugin):
168+
"""
169+
Masked Cross Entropy
170+
171+
This criterion can be used for instance in Class Incremental
172+
Learning Problems when no examplars are used
173+
(i.e LwF in Class Incremental Learning would need to use mask="new").
174+
"""
175+
176+
def __init__(self, classes=None, mask="seen", reduction="mean"):
177+
"""
178+
param: classes: Initial value for current classes
179+
param: mask: "all" normal cross entropy, uses all the classes seen so far
180+
"old" cross entropy only on the old classes
181+
"new" cross entropy only on the new classes
182+
param: reduction: "mean" or "none", average or per-sample loss
183+
"""
184+
super().__init__()
185+
assert mask in ["seen", "new", "old", "all"]
186+
if classes is not None:
187+
self.current_classes = set(classes)
188+
else:
189+
self.current_classes = set()
190+
191+
self.old_classes = set()
192+
self.reduction = reduction
193+
self.mask = mask
194+
195+
def __call__(self, logits, targets):
196+
oh_targets = F.one_hot(targets, num_classes=logits.shape[1])
197+
198+
oh_targets = oh_targets[:, self.current_mask(logits.shape[1])]
199+
logits = logits[:, self.current_mask(logits.shape[1])]
200+
201+
return cross_entropy_with_oh_targets(
202+
logits,
203+
oh_targets.float(),
204+
reduction=self.reduction,
205+
)
206+
207+
def current_mask(self, logit_shape):
208+
if self.mask == "seen":
209+
return list(self.current_classes.union(self.old_classes))
210+
elif self.mask == "new":
211+
return list(self.current_classes)
212+
elif self.mask == "old":
213+
return list(self.old_classes)
214+
elif self.mask == "all":
215+
return list(range(int(logit_shape)))
216+
217+
def adaptation(self, new_classes):
218+
self.old_classes = self.old_classes.union(self.current_classes)
219+
self.current_classes = set(new_classes)
220+
221+
def before_training_exp(self, strategy, **kwargs):
222+
self.adaptation(strategy.experience.classes_in_this_experience)
223+
224+
225+
__all__ = ["ICaRLLossPlugin", "SCRLoss", "MaskedCrossEntropy"]

avalanche/training/regularization.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,25 @@
99
from avalanche.models import MultiTaskModule, avalanche_forward
1010

1111

12-
def cross_entropy_with_oh_targets(outputs, targets, eps=1e-5):
12+
def stable_softmax(x):
13+
z = x - torch.max(x, dim=1, keepdim=True)[0]
14+
numerator = torch.exp(z)
15+
denominator = torch.sum(numerator, dim=1, keepdim=True)
16+
softmax = numerator / denominator
17+
return softmax
18+
19+
20+
def cross_entropy_with_oh_targets(outputs, targets, reduction="mean"):
1321
"""Calculates cross-entropy with temperature scaling,
1422
targets can also be soft targets but they must sum to 1"""
15-
outputs = torch.nn.functional.softmax(outputs, dim=1)
23+
outputs = stable_softmax(outputs)
1624
ce = -(targets * outputs.log()).sum(1)
17-
ce = ce.mean()
25+
if reduction == "mean":
26+
ce = ce.mean()
27+
elif reduction == "none":
28+
return ce
29+
else:
30+
raise NotImplementedError("reduction must be mean or none")
1831
return ce
1932

2033

avalanche/training/templates/base_sgd.py

+3
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ def __init__(
9696
self._criterion = criterion
9797
""" Criterion. """
9898

99+
if criterion not in self.plugins and isinstance(criterion, BasePlugin):
100+
self.plugins.append(criterion)
101+
99102
self.train_epochs: int = train_epochs
100103
""" Number of training epochs. """
101104

tests/training/test_losses.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import unittest
22

33
import torch
4-
from avalanche.training.losses import ICaRLLossPlugin
4+
import torch.nn as nn
5+
6+
from avalanche.training.losses import ICaRLLossPlugin, MaskedCrossEntropy
57

68

79
class TestICaRLLossPlugin(unittest.TestCase):
@@ -34,5 +36,28 @@ def test_loss(self):
3436
assert loss3 == loss1
3537

3638

39+
class TestMaskedCrossEntropy(unittest.TestCase):
40+
def test_loss(self):
41+
cross_entropy = nn.CrossEntropyLoss()
42+
43+
criterion = MaskedCrossEntropy(mask="new")
44+
criterion.adaptation([1, 2, 3, 4])
45+
criterion.adaptation([5, 6, 7])
46+
47+
mb_y = torch.tensor([5, 5, 6, 7, 6])
48+
49+
new_pred = torch.rand(5, 8)
50+
new_pred_new = new_pred[:, criterion.current_mask(new_pred.shape[1])]
51+
52+
loss1 = criterion(new_pred, mb_y)
53+
loss2 = cross_entropy(new_pred_new, mb_y - 5)
54+
55+
criterion.mask = "seen"
56+
loss3 = criterion(new_pred, mb_y)
57+
58+
self.assertAlmostEqual(float(loss1), float(loss2), places=5)
59+
self.assertNotAlmostEqual(float(loss1), float(loss3), places=5)
60+
61+
3762
if __name__ == "__main__":
3863
unittest.main()

0 commit comments

Comments
 (0)