Skip to content

Commit 27fadaa

Browse files
author
talrid
committed
asymmetric_loss
1 parent 79e727e commit 27fadaa

File tree

2 files changed

+99
-1
lines changed

2 files changed

+99
-1
lines changed

Diff for: timm/loss/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
2-
from .jsd import JsdCrossEntropy
2+
from .jsd import JsdCrossEntropy
3+
from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel

Diff for: timm/loss/asymmetric_loss.py

+97
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
5+
class AsymmetricLossMultiLabel(nn.Module):
6+
def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False):
7+
super(AsymmetricLossMultiLabel, self).__init__()
8+
9+
self.gamma_neg = gamma_neg
10+
self.gamma_pos = gamma_pos
11+
self.clip = clip
12+
self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
13+
self.eps = eps
14+
15+
def forward(self, x, y):
16+
""""
17+
Parameters
18+
----------
19+
x: input logits
20+
y: targets (multi-label binarized vector)
21+
"""
22+
23+
# Calculating Probabilities
24+
x_sigmoid = torch.sigmoid(x)
25+
xs_pos = x_sigmoid
26+
xs_neg = 1 - x_sigmoid
27+
28+
# Asymmetric Clipping
29+
if self.clip is not None and self.clip > 0:
30+
xs_neg = (xs_neg + self.clip).clamp(max=1)
31+
32+
# Basic CE calculation
33+
los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
34+
los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
35+
loss = los_pos + los_neg
36+
37+
# Asymmetric Focusing
38+
if self.gamma_neg > 0 or self.gamma_pos > 0:
39+
if self.disable_torch_grad_focal_loss:
40+
torch._C.set_grad_enabled(False)
41+
pt0 = xs_pos * y
42+
pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p
43+
pt = pt0 + pt1
44+
one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
45+
one_sided_w = torch.pow(1 - pt, one_sided_gamma)
46+
if self.disable_torch_grad_focal_loss:
47+
torch._C.set_grad_enabled(True)
48+
loss *= one_sided_w
49+
50+
return -loss.sum()
51+
52+
53+
class AsymmetricLossSingleLabel(nn.Module):
54+
def __init__(self, gamma_pos=1, gamma_neg=4, eps: float = 0.1, reduction='mean'):
55+
super(AsymmetricLossSingleLabel, self).__init__()
56+
57+
self.eps = eps
58+
self.logsoftmax = nn.LogSoftmax(dim=-1)
59+
self.targets_classes = [] # prevent gpu repeated memory allocation
60+
self.gamma_pos = gamma_pos
61+
self.gamma_neg = gamma_neg
62+
self.reduction = reduction
63+
64+
def forward(self, inputs, target, reduction=None):
65+
""""
66+
Parameters
67+
----------
68+
x: input logits
69+
y: targets (1-hot vector)
70+
"""
71+
72+
num_classes = inputs.size()[-1]
73+
log_preds = self.logsoftmax(inputs)
74+
self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1)
75+
76+
# ASL weights
77+
targets = self.targets_classes
78+
anti_targets = 1 - targets
79+
xs_pos = torch.exp(log_preds)
80+
xs_neg = 1 - xs_pos
81+
xs_pos = xs_pos * targets
82+
xs_neg = xs_neg * anti_targets
83+
asymmetric_w = torch.pow(1 - xs_pos - xs_neg,
84+
self.gamma_pos * targets + self.gamma_neg * anti_targets)
85+
log_preds = log_preds * asymmetric_w
86+
87+
if self.eps > 0: # label smoothing
88+
self.targets_classes.mul_(1 - self.eps).add_(self.eps / num_classes)
89+
90+
# loss calculation
91+
loss = - self.targets_classes.mul(log_preds)
92+
93+
loss = loss.sum(dim=-1)
94+
if self.reduction == 'mean':
95+
loss = loss.mean()
96+
97+
return loss

0 commit comments

Comments
 (0)