-
Notifications
You must be signed in to change notification settings - Fork 29
/
loss_func.py
69 lines (56 loc) · 2.74 KB
/
loss_func.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
import torch.nn as nn
import torch.nn.functional as F
class CELoss(nn.Module):
    """Cross-entropy objective computed on the model's ``'predicts'`` output."""

    def __init__(self):
        super(CELoss, self).__init__()
        # Criterion with PyTorch's default settings.
        self.xent_loss = nn.CrossEntropyLoss()

    def forward(self, outputs, targets):
        """Return the cross-entropy loss for one batch.

        Args:
            outputs: Mapping that provides the class scores under the key
                ``'predicts'`` (presumably shaped ``(batch, num_classes)``;
                confirm against the model that produces it).
            targets: Ground-truth labels in any form accepted by
                ``nn.CrossEntropyLoss``.

        Returns:
            The value produced by ``nn.CrossEntropyLoss`` for the batch.
        """
        class_scores = outputs['predicts']
        return self.xent_loss(class_scores, targets)
class SupConLoss(nn.Module):
    """Cross-entropy blended with a supervised contrastive term.

    The returned loss is ``(1 - alpha) * cross_entropy + alpha * contrastive``.

    Args:
        alpha: Weight of the contrastive term; ``1 - alpha`` weights the
            cross-entropy term.
        temp: Temperature that divides the pairwise dot products.
    """

    def __init__(self, alpha, temp):
        super().__init__()
        self.xent_loss = nn.CrossEntropyLoss()
        self.alpha = alpha
        self.temp = temp

    def nt_xent_loss(self, anchor, target, labels):
        """Supervised contrastive loss between ``anchor`` and ``target`` rows.

        For anchor row ``i`` the contrastive candidates are all target rows
        ``j != i``; the positives are the candidates whose label equals
        ``labels[i]``. Each anchor contributes the mean log-softmax probability
        of its positives (0 if it has none), and the result is the negated
        mean over all anchors.

        The self pair ``(i, i)`` is excluded from the softmax denominator as
        well as from the positives. Merely zeroing its logit would leave an
        ``exp(0)`` term in every denominator, so e.g. a batch of two
        same-label samples could never reach a loss of 0.

        Args:
            anchor: 2-D tensor ``(batch, dim)``.
            target: 2-D tensor ``(batch, dim)`` with the same batch size.
            labels: Class ids, one per row (assumed 1-D of length ``batch``).

        Returns:
            Scalar tensor. It is 0 (still attached to the graph) when the
            batch holds fewer than two samples, since no pair can be formed.
        """
        similarity = torch.einsum('bd,cd->bc', anchor, target) / self.temp
        batch_size = similarity.size(0)
        if batch_size < 2:
            # No contrastive pair exists; keep the result differentiable.
            return similarity.sum() * 0.0

        with torch.no_grad():
            self_pairs = torch.eye(
                batch_size, dtype=torch.bool, device=similarity.device
            )
            same_label = labels.unsqueeze(1).eq(labels.unsqueeze(0))
            positives = same_label & ~self_pairs
            # Anchors without positives divide by 1 and contribute 0.
            positive_count = positives.sum(dim=1).clamp(min=1)

        # logsumexp is numerically stable on its own; -inf removes the self
        # pair from the denominator. Every row keeps at least one finite
        # entry because batch_size >= 2.
        log_denominator = torch.logsumexp(
            similarity.masked_fill(self_pairs, float('-inf')),
            dim=1,
            keepdim=True,
        )
        log_prob = similarity - log_denominator

        mean_pos_log_prob = (log_prob * positives).sum(dim=1) / positive_count
        return -mean_pos_log_prob.mean()

    def forward(self, outputs, targets):
        """Combine cross-entropy and the contrastive term for one batch.

        Args:
            outputs: Mapping with ``'predicts'`` (class scores fed to
                cross-entropy) and ``'cls_feats'`` (features that are
                L2-normalised along the last dim and contrasted with
                themselves).
            targets: Class labels used by both terms.

        Returns:
            ``(1 - alpha) * cross_entropy + alpha * contrastive``.
        """
        normed_cls_feats = F.normalize(outputs['cls_feats'], dim=-1)
        ce_loss = (1 - self.alpha) * self.xent_loss(outputs['predicts'], targets)
        cl_loss = self.alpha * self.nt_xent_loss(
            normed_cls_feats, normed_cls_feats, targets
        )
        return ce_loss + cl_loss
class DualLoss(SupConLoss):
    """Cross-entropy plus two contrastive terms pairing sample and label features.

    Each sample's ``'cls_feats'`` row is contrasted against the label feature
    selected by its own target, once in each direction; both directions
    share the contrastive weight ``alpha`` equally.
    """

    def __init__(self, alpha, temp):
        super(DualLoss, self).__init__(alpha, temp)

    def forward(self, outputs, targets):
        """Return the combined loss for one batch.

        Args:
            outputs: Mapping with ``'predicts'`` (class scores for the
                cross-entropy term), ``'cls_feats'`` and ``'label_feats'``.
                ``'label_feats'`` is indexed along dim 1 by ``targets``, so
                it is presumably ``(batch, num_classes, dim)``; confirm
                against the model.
            targets: Class labels, one per sample.

        Returns:
            ``(1 - alpha) * CE + 0.5 * alpha * (label->cls + cls->label)``.
        """
        cls_feats = F.normalize(outputs['cls_feats'], dim=-1)
        label_feats = F.normalize(outputs['label_feats'], dim=-1)

        # For every sample keep only the label feature at index targets[i]
        # along dim 1, then drop that singleton dim.
        feat_dim = label_feats.size(-1)
        gather_index = targets.reshape(-1, 1, 1).expand(-1, 1, feat_dim)
        pos_label_feats = torch.gather(
            label_feats, dim=1, index=gather_index
        ).squeeze(1)

        contrastive_weight = 0.5 * self.alpha
        ce_loss = (1 - self.alpha) * self.xent_loss(outputs['predicts'], targets)
        label_to_cls = contrastive_weight * self.nt_xent_loss(
            pos_label_feats, cls_feats, targets
        )
        cls_to_label = contrastive_weight * self.nt_xent_loss(
            cls_feats, pos_label_feats, targets
        )
        return ce_loss + label_to_cls + cls_to_label