Commit efbd292

Feature: CurricularFace (#1013)

* curricularface module
* docs & link to official implementation
* end string
* using `weights` instead of `kernel`
* tests for curricularface
* using `torch.mm` instead of `F.linear`

1 parent 9c467e3 commit efbd292
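For orientation, a summary of the paper's rule as implemented below (wording ours, not part of the commit message): the target logit gets the additive angular margin, hard negatives are modulated by an adaptively updated scalar t, and t is an exponential moving average of the mean target logit. In LaTeX, with r = 0.01 as hard-coded in this commit:

T(\cos\theta_{y_i}) = \cos(\theta_{y_i} + m)

N(t, \cos\theta_j) =
\begin{cases}
  \cos\theta_j & \text{if } \cos(\theta_{y_i} + m) \ge \cos\theta_j \\
  \cos\theta_j \, (t + \cos\theta_j) & \text{otherwise}
\end{cases}

t \leftarrow r \cdot \overline{\cos\theta_{y_i}} + (1 - r) \cdot t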

File tree

- catalyst/contrib/nn/modules/__init__.py
- catalyst/contrib/nn/modules/curricularface.py
- catalyst/contrib/nn/tests/test_modules.py
- docs/api/contrib.rst

4 files changed: +241 -1 lines changed

catalyst/contrib/nn/modules/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@
     Normalize,
 )
 from catalyst.contrib.nn.modules.cosface import CosFace, AdaCos
+from catalyst.contrib.nn.modules.curricularface import CurricularFace
 from catalyst.contrib.nn.modules.lama import (
     LamaPooling,
     TemporalLastPooling,
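With this re-export in place, the layer is importable from the package namespace, which is exactly how the updated tests below pull it in:

from catalyst.contrib.nn.modules import CurricularFace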
catalyst/contrib/nn/modules/curricularface.py
Lines changed: 127 additions & 0 deletions

@@ -0,0 +1,127 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class CurricularFace(nn.Module):
+    """Implementation of
+    `CurricularFace: Adaptive Curriculum Learning\
+    Loss for Deep Face Recognition`_.
+
+    .. _CurricularFace\: Adaptive Curriculum Learning\
+        Loss for Deep Face Recognition:
+        https://arxiv.org/abs/2004.00288
+
+    Official `pytorch implementation`_.
+
+    .. _pytorch implementation:
+        https://github.com/HuangYG123/CurricularFace
+
+    Args:
+        in_features: size of each input sample.
+        out_features: size of each output sample.
+        s: norm of input feature.
+            Default: ``64.0``.
+        m: margin.
+            Default: ``0.5``.
+
+    Shape:
+        - Input: :math:`(batch, H_{in})` where
+          :math:`H_{in} = in\_features`.
+        - Output: :math:`(batch, H_{out})` where
+          :math:`H_{out} = out\_features`.
+
+    Example:
+        >>> layer = CurricularFace(5, 10, s=1.31, m=0.5)
+        >>> loss_fn = nn.CrossEntropyLoss()
+        >>> embedding = torch.randn(3, 5, requires_grad=True)
+        >>> target = torch.empty(3, dtype=torch.long).random_(10)
+        >>> output = layer(embedding, target)
+        >>> loss = loss_fn(output, target)
+        >>> loss.backward()
+
+    """  # noqa: RST215
+
+    def __init__(  # noqa: D107
+        self,
+        in_features: int,
+        out_features: int,
+        s: float = 64.0,
+        m: float = 0.5,
+    ):
+        super(CurricularFace, self).__init__()
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.m = m
+        self.s = s
+
+        self.cos_m = math.cos(m)
+        self.sin_m = math.sin(m)
+        self.threshold = math.cos(math.pi - m)
+        self.mm = math.sin(math.pi - m) * m
+
+        self.weight = nn.Parameter(torch.Tensor(in_features, out_features))
+        self.register_buffer("t", torch.zeros(1))
+
+        nn.init.normal_(self.weight, std=0.01)
+
+    def __repr__(self) -> str:  # noqa: D105
+        rep = (
+            "CurricularFace("
+            f"in_features={self.in_features},"
+            f"out_features={self.out_features},"
+            f"m={self.m},s={self.s}"
+            ")"
+        )
+        return rep
+
+    def forward(
+        self, input: torch.Tensor, label: torch.LongTensor
+    ) -> torch.Tensor:
+        """
+        Args:
+            input: input features,
+                expected shapes ``BxF`` where ``B``
+                is batch dimension and ``F`` is an
+                input feature dimension.
+            label: target classes,
+                expected shapes ``B`` where
+                ``B`` is batch dimension.
+
+        Returns:
+            tensor (logits) with shapes ``BxC``
+            where ``C`` is a number of classes.
+        """
+        cos_theta = torch.mm(
+            F.normalize(input), F.normalize(self.weight, dim=0)
+        )
+        cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability
+
+        target_logit = cos_theta[torch.arange(0, input.size(0)), label].view(
+            -1, 1
+        )
+
+        sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2))
+        cos_theta_m = (
+            target_logit * self.cos_m - sin_theta * self.sin_m
+        )  # cos(target + margin)
+        mask = cos_theta > cos_theta_m
+        final_target_logit = torch.where(
+            target_logit > self.threshold, cos_theta_m, target_logit - self.mm
+        )
+
+        hard_example = cos_theta[mask]
+        with torch.no_grad():
+            self.t = target_logit.mean() * 0.01 + (1 - 0.01) * self.t
+
+        cos_theta[mask] = hard_example * (self.t + hard_example)
+        cos_theta.scatter_(1, label.view(-1, 1).long(), final_target_logit)
+        output = cos_theta * self.s
+
+        return output
+
+
+__all__ = ["CurricularFace"]
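A minimal usage sketch of the new layer: CurricularFace is a classification head over embeddings, trained with plain cross-entropy; note that the labels must be passed to forward. The toy encoder, sizes, and optimizer here are illustrative assumptions, not part of the commit:

import torch
import torch.nn as nn

from catalyst.contrib.nn.modules import CurricularFace

# hypothetical toy encoder producing 5-d embeddings for 10 classes;
# only CurricularFace itself comes from this commit
encoder = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 5))
head = CurricularFace(in_features=5, out_features=10, s=64.0, m=0.5)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    list(encoder.parameters()) + list(head.parameters()), lr=1e-2
)

x = torch.randn(8, 16)          # batch of raw inputs
y = torch.randint(0, 10, (8,))  # integer class labels

logits = head(encoder(x), y)    # labels are required in forward
loss = criterion(logits, y)     # plain cross-entropy over scaled logits

optimizer.zero_grad()
loss.backward()
optimizer.step()

At evaluation time such margin heads are typically dropped and cosine similarity between embeddings is used directly.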

catalyst/contrib/nn/tests/test_modules.py
Lines changed: 106 additions & 1 deletion

@@ -4,7 +4,12 @@
 import torch
 import torch.nn as nn
 
-from catalyst.contrib.nn.modules import ArcFace, CosFace, SoftMax
+from catalyst.contrib.nn.modules import (
+    ArcFace,
+    CosFace,
+    CurricularFace,
+    SoftMax,
+)
 
 
 def normalize(m: np.ndarray) -> np.ndarray:
@@ -209,3 +214,103 @@ def test_cosface_with_cross_entropy_loss():
         .numpy()
     )
     assert np.isclose(expected_loss.sum(), actual)
+
+
+def test_curricularface_with_cross_entropy_loss():
+    emb_size = 4
+    n_classes = 3
+    s = 3.0
+    m = 0.1
+
+    # fmt: off
+    features = np.array(
+        [
+            [1, 2, 3, 4],
+            [5, 6, 7, 8],
+        ],
+        dtype="f",
+    )
+    target = np.array([0, 2], dtype="l")
+    mask = np.array([[1, 0, 0], [0, 0, 1]], dtype="l")  # one_hot(target)
+
+    weight = np.array(
+        [
+            [0.1, 0.2, 0.3, 0.4],
+            [1.1, 3.2, 5.3, 0.4],
+            [0.1, 0.2, 6.3, 0.4],
+        ],
+        dtype="f",
+    )
+    # fmt: on
+
+    layer = CurricularFace(emb_size, n_classes, s, m)
+    layer.weight.data = torch.from_numpy(weight.T)
+    loss_fn = nn.CrossEntropyLoss(reduction="none")
+
+    normalized_features = normalize(features)  # 2x4
+    normalized_projection = normalize(weight)  # 3x4
+
+    cosine = normalized_features @ normalized_projection.T  # 2x4 * 4x3 = 2x3
+    logit = cosine[mask.astype(bool)].reshape(-1, 1)
+
+    sine = np.sqrt(1.0 - np.power(logit, 2))
+    cos_theta_m = logit * np.cos(m) - sine * np.sin(m)
+
+    final_logit = np.where(
+        logit > np.cos(np.pi - m), cos_theta_m, logit - np.sin(np.pi - m) * m,
+    )
+
+    cos_mask = cosine > cos_theta_m
+    hard = cosine[cos_mask]
+
+    t = np.mean(logit) * 0.01 + (1 - 0.01) * 0  # EMA step from initial t = 0
+
+    cosine[cos_mask] = hard * (t + hard)  # 2x3
+    for r, c in enumerate(target):
+        cosine[r, c] = final_logit[r, 0]
+    cosine = cosine * s  # 2x3
+
+    expected_loss = cross_entropy(cosine, mask, 1)
+    actual = (
+        loss_fn(
+            layer(torch.from_numpy(features), torch.LongTensor(target)),
+            torch.LongTensor(target),
+        )
+        .detach()
+        .numpy()
+    )
+
+    assert np.allclose(expected_loss, actual)
+
+    # reinitialize layer (t is changed)
+    layer = CurricularFace(emb_size, n_classes, s, m)
+    layer.weight.data = torch.from_numpy(weight.T)
+    loss_fn = nn.CrossEntropyLoss(reduction="mean")
+
+    expected_loss = cross_entropy(cosine, mask, 1)
+    actual = (
+        loss_fn(
+            layer(torch.from_numpy(features), torch.LongTensor(target)),
+            torch.LongTensor(target),
+        )
+        .detach()
+        .numpy()
+    )
+
+    assert np.isclose(expected_loss.mean(), actual)
+
+    # reinitialize layer (t is changed)
+    layer = CurricularFace(emb_size, n_classes, s, m)
+    layer.weight.data = torch.from_numpy(weight.T)
+    loss_fn = nn.CrossEntropyLoss(reduction="sum")
+
+    expected_loss = cross_entropy(cosine, mask, 1)
+    actual = (
+        loss_fn(
+            layer(torch.from_numpy(features), torch.LongTensor(target)),
+            torch.LongTensor(target),
+        )
+        .detach()
+        .numpy()
+    )
+    assert np.isclose(expected_loss.sum(), actual)
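The repeated reinitialization in this test is needed because forward mutates the registered buffer ``t`` (an EMA of the mean target logit), so a second pass through the same layer would no longer match the numpy reference computed with t = 0. A small sketch of that side effect (sizes chosen arbitrarily):

import torch

from catalyst.contrib.nn.modules import CurricularFace

layer = CurricularFace(4, 3, s=3.0, m=0.1)
print(layer.t)  # tensor([0.]) -- the buffer starts at zero

_ = layer(torch.randn(2, 4), torch.tensor([0, 2]))
print(layer.t)  # nonzero now: t <- 0.01 * mean(target logit) + 0.99 * t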

docs/api/contrib.rst
Lines changed: 7 additions & 0 deletions

@@ -183,6 +183,13 @@ CosFace and AdaCos
     :undoc-members:
     :show-inheritance:
 
+CurricularFace
+"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
+.. automodule:: catalyst.contrib.nn.modules.curricularface
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
 Last-Mean-Average-Attention (LAMA)-Pooling
 """"""""""""""""""""""""""""""""""""""""""
 .. automodule:: catalyst.contrib.nn.modules.lama