
Commit fcb6258

Add missing leaky_relu layer factory defn, update Apex/Native loss scaler interfaces to support unscaled grad clipping. Bump ver to 0.2.2 for pending release.
1 parent: 186075e

File tree

timm/models/layers/create_act.py
timm/utils/cuda.py
timm/version.py

3 files changed: +10 -3 lines

Diff for: timm/models/layers/create_act.py

+1

@@ -46,6 +46,7 @@
     mish=Mish,
     relu=nn.ReLU,
     relu6=nn.ReLU6,
+    leaky_relu=nn.LeakyReLU,
     elu=nn.ELU,
     prelu=nn.PReLU,
     celu=nn.CELU,
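
For context, a minimal sketch of how the newly registered name can be resolved through the activation factory; get_act_layer is assumed to be the lookup helper exported by create_act.py (the helper itself is not shown in this diff):

import torch
from timm.models.layers.create_act import get_act_layer  # assumed lookup helper in this module

# Resolve the newly registered key to its layer class and instantiate it;
# with the default (non-JIT) config this should be nn.LeakyReLU, per the mapping added above.
act_cls = get_act_layer('leaky_relu')
act = act_cls(negative_slope=0.01, inplace=True)

x = torch.randn(2, 8)
print(act(x).shape)  # torch.Size([2, 8])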

Diff for: timm/utils/cuda.py

+8 -2

@@ -15,9 +15,11 @@
 class ApexScaler:
     state_dict_key = "amp"

-    def __call__(self, loss, optimizer):
+    def __call__(self, loss, optimizer, clip_grad=None, parameters=None):
         with amp.scale_loss(loss, optimizer) as scaled_loss:
             scaled_loss.backward()
+        if clip_grad:
+            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), clip_grad)
         optimizer.step()

     def state_dict(self):
@@ -35,8 +37,12 @@ class NativeScaler:
     def __init__(self):
         self._scaler = torch.cuda.amp.GradScaler()

-    def __call__(self, loss, optimizer):
+    def __call__(self, loss, optimizer, clip_grad=None, parameters=None):
         self._scaler.scale(loss).backward()
+        if clip_grad:
+            assert parameters is not None
+            self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
+            torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
         self._scaler.step(optimizer)
         self._scaler.update()
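
For context, a rough sketch of how the new clip_grad/parameters arguments would be used in a training step with NativeScaler; the tiny model, optimizer, and data below are illustrative placeholders, not part of this commit (ApexScaler accepts the same arguments but clips amp.master_params internally):

import torch
import torch.nn as nn
from timm.utils.cuda import NativeScaler  # class patched above; torch.cuda.amp requires a CUDA device

model = nn.Linear(8, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
loss_scaler = NativeScaler()

images = torch.randn(4, 8, device='cuda')
targets = torch.randint(0, 2, (4,), device='cuda')

optimizer.zero_grad()
with torch.cuda.amp.autocast():
    loss = criterion(model(images), targets)

# The scaler unscales the optimizer's gradients before clipping, so clip_grad
# is applied to the true (unscaled) gradient norm of the given parameters.
loss_scaler(loss, optimizer, clip_grad=5.0, parameters=model.parameters())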

Diff for: timm/version.py

+1 -1

@@ -1 +1 @@
-__version__ = '0.2.1'
+__version__ = '0.2.2'
