@@ -1,10 +1,13 @@
 import copy
 
+import numpy as np
 import torch
+import torch.nn.functional as F
 from torch import nn
-from avalanche.training.plugins import SupervisedPlugin
 from torch.nn import BCELoss
-import numpy as np
+
+from avalanche.training.plugins import SupervisedPlugin
+from avalanche.training.regularization import cross_entropy_with_oh_targets
 
 
 class ICaRLLossPlugin(SupervisedPlugin):
@@ -161,4 +164,62 @@ def forward(self, features, labels=None, mask=None):
         return loss
 
 
-__all__ = ["ICaRLLossPlugin", "SCRLoss"]
+class MaskedCrossEntropy(SupervisedPlugin):
+    """
+    Masked Cross Entropy
+
+    This criterion can be used, for instance, in Class-Incremental
+    Learning problems when no exemplars are used
+    (e.g., LwF in Class-Incremental Learning would need to use mask="new").
+    """
+
+    def __init__(self, classes=None, mask="seen", reduction="mean"):
+        """
+        :param classes: initial value for the set of current classes
+        :param mask: "all" uses every logit; "seen" uses all the
+            classes seen so far; "old" uses only the old classes;
+            "new" uses only the new classes
+        :param reduction: "mean" or "none", average or per-sample loss
+        """
+        super().__init__()
+        assert mask in ["seen", "new", "old", "all"]
+        if classes is not None:
+            self.current_classes = set(classes)
+        else:
+            self.current_classes = set()
+
+        self.old_classes = set()
+        self.reduction = reduction
+        self.mask = mask
+
+    def __call__(self, logits, targets):
+        oh_targets = F.one_hot(targets, num_classes=logits.shape[1])
+
+        oh_targets = oh_targets[:, self.current_mask(logits.shape[1])]
+        logits = logits[:, self.current_mask(logits.shape[1])]
+
+        return cross_entropy_with_oh_targets(
+            logits,
+            oh_targets.float(),
+            reduction=self.reduction,
+        )
+
+    def current_mask(self, logit_shape):
+        if self.mask == "seen":
+            return list(self.current_classes.union(self.old_classes))
+        elif self.mask == "new":
+            return list(self.current_classes)
+        elif self.mask == "old":
+            return list(self.old_classes)
+        elif self.mask == "all":
+            return list(range(int(logit_shape)))
+
+    def adaptation(self, new_classes):
+        self.old_classes = self.old_classes.union(self.current_classes)
+        self.current_classes = set(new_classes)
+
+    def before_training_exp(self, strategy, **kwargs):
+        self.adaptation(strategy.experience.classes_in_this_experience)
+
+
+__all__ = ["ICaRLLossPlugin", "SCRLoss", "MaskedCrossEntropy"]
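
A minimal usage sketch of the new criterion, outside the diff. The import path is an assumption based on the file touched here, and adaptation() is called directly where a strategy would normally trigger it through before_training_exp():

import torch

from avalanche.training.losses import MaskedCrossEntropy  # path assumed

criterion = MaskedCrossEntropy(mask="new")

# Experience 1 introduces classes 0-1.
criterion.adaptation([0, 1])
logits = torch.randn(8, 10)           # head with 10 output units in total
targets = torch.randint(0, 2, (8,))   # labels drawn from the new classes
loss = criterion(logits, targets)     # cross entropy over columns 0-1 only

# Experience 2: classes 0-1 become "old", classes 2-3 become "new".
criterion.adaptation([2, 3])
assert sorted(criterion.current_mask(10)) == [2, 3]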
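To use it inside a training loop, the criterion can be handed to a strategy. A sketch assuming Avalanche's Naive strategy, and assuming the strategy does not register criterion callbacks on its own (hence the explicit plugins entry that keeps the mask in sync across experiences):

from torch.nn import Linear
from torch.optim import SGD

from avalanche.training.supervised import Naive

model = Linear(32, 10)
criterion = MaskedCrossEntropy(mask="seen")
strategy = Naive(
    model,
    SGD(model.parameters(), lr=0.01),
    criterion=criterion,
    plugins=[criterion],  # assumption: not added automatically by the strategy
)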