
Commit

added seen and all options, switch to torch functional kl div
AlbinSou committed Oct 11, 2023
1 parent e5f6ed5 commit a1025cc
Showing 3 changed files with 17 additions and 14 deletions.
24 changes: 13 additions & 11 deletions avalanche/training/losses.py
@@ -173,7 +173,7 @@ class MaskedCrossEntropy(SupervisedPlugin):
     (i.e LwF in Class Incremental Learning would need to use mask="new").
     """

-    def __init__(self, classes=None, mask="all", reduction="mean"):
+    def __init__(self, classes=None, mask="seen", reduction="batchmean"):
         """
         param: classes: Initial value for current classes
         param: mask: "all" normal cross entropy, uses all the classes seen so far
@@ -182,7 +182,7 @@ def __init__(self, classes=None, mask="all", reduction="mean"):
         param: reduction: "mean" or "none", average or per-sample loss
         """
         super().__init__()
-        assert mask in ["all", "new", "old"]
+        assert mask in ["seen", "new", "old", "all"]
         if classes is not None:
             self.current_classes = set(classes)
         else:
@@ -195,23 +195,25 @@ def __init__(self, classes=None, mask="all", reduction="mean"):
     def __call__(self, logits, targets):
         oh_targets = F.one_hot(targets, num_classes=logits.shape[1])

-        oh_targets = oh_targets[:, self.current_mask]
-        logits = logits[:, self.current_mask]
+        oh_targets = oh_targets[:, self.current_mask(logits.shape[1])]
+        logits = logits[:, self.current_mask(logits.shape[1])]

-        return cross_entropy_with_oh_targets(
-            logits,
+        return F.kl_div(
+            torch.log_softmax(logits, dim=1),
             oh_targets.float(),
             reduction=self.reduction,
+            log_target=False,
         )

-    @property
-    def current_mask(self):
-        if self.mask == "all":
+    def current_mask(self, logit_shape):
+        if self.mask == "seen":
             return list(self.current_classes.union(self.old_classes))
-        if self.mask == "new":
+        elif self.mask == "new":
             return list(self.current_classes)
-        if self.mask == "old":
+        elif self.mask == "old":
             return list(self.old_classes)
+        elif self.mask == "all":
+            return list(range(int(logit_shape)))

     def adaptation(self, new_classes):
         self.old_classes = self.old_classes.union(self.current_classes)
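For hard (one-hot) targets, the new F.kl_div path is numerically equivalent to the old cross-entropy: KL(one_hot || softmax) drops the zero-entropy target term and leaves the negative log-likelihood, and reduction="batchmean" averages over the batch. A minimal sketch of that equivalence (tensor names and shapes here are illustrative, not from the repository):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)
targets = torch.tensor([1, 3, 0, 7])
oh_targets = F.one_hot(targets, num_classes=10).float()

# kl_div computes target * (log(target) - input) summed over classes;
# for one-hot targets this is just -log_softmax(logits)[target],
# and "batchmean" divides the summed loss by the batch size.
kl = F.kl_div(torch.log_softmax(logits, dim=1), oh_targets,
              reduction="batchmean", log_target=False)
ce = F.cross_entropy(logits, targets, reduction="mean")
assert torch.allclose(kl, ce, atol=1e-6)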
2 changes: 1 addition & 1 deletion avalanche/training/regularization.py
@@ -9,7 +9,7 @@
 from avalanche.models import MultiTaskModule, avalanche_forward


-def cross_entropy_with_oh_targets(outputs, targets, eps=1e-5, reduction="mean"):
+def cross_entropy_with_oh_targets(outputs, targets, reduction="mean"):
     """Calculates cross-entropy with temperature scaling,
     targets can also be soft targets but they must sum to 1"""
     outputs = torch.nn.functional.softmax(outputs, dim=1)
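The rest of the helper's body is collapsed in this view. As a hedged sketch only, a soft-target cross-entropy with this signature typically computes the following (the function and variable names below are hypothetical, and with eps dropped the log is taken on the raw softmax output):

import torch

def soft_cross_entropy(outputs, targets, reduction="mean"):
    # Cross-entropy against targets that sum to 1 along the class dim;
    # hypothetical re-implementation, not the repository's exact body.
    probs = torch.nn.functional.softmax(outputs, dim=1)
    per_sample = -(targets * torch.log(probs)).sum(dim=1)
    return per_sample.mean() if reduction == "mean" else per_sample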
5 changes: 3 additions & 2 deletions tests/training/test_losses.py
@@ -1,4 +1,5 @@
 import unittest
+
 import torch
 import torch.nn as nn

@@ -46,12 +47,12 @@ def test_loss(self):
         mb_y = torch.tensor([5, 5, 6, 7, 6])

         new_pred = torch.rand(5, 8)
-        new_pred_new = new_pred[:, criterion.current_mask]
+        new_pred_new = new_pred[:, criterion.current_mask(new_pred.shape[1])]

         loss1 = criterion(new_pred, mb_y)
         loss2 = cross_entropy(new_pred_new, mb_y - 5)

-        criterion.mask = "all"
+        criterion.mask = "seen"
         loss3 = criterion(new_pred, mb_y)

         self.assertAlmostEqual(float(loss1), float(loss2), places=5)
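As a usage sketch of the new mask options (class indices are illustrative; this assumes, as the diff suggests, that adaptation moves the current classes into old_classes and installs the new ones as current):

from avalanche.training.losses import MaskedCrossEntropy

crit = MaskedCrossEntropy(classes=[0, 1, 2], mask="new")
crit.adaptation([5, 6, 7])   # old = {0, 1, 2}, current = {5, 6, 7}

sorted(crit.current_mask(8))  # "new"  -> [5, 6, 7]
crit.mask = "old"
sorted(crit.current_mask(8))  # "old"  -> [0, 1, 2]
crit.mask = "seen"
sorted(crit.current_mask(8))  # "seen" -> [0, 1, 2, 5, 6, 7]
crit.mask = "all"
crit.current_mask(8)          # "all"  -> [0, 1, 2, 3, 4, 5, 6, 7]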
