
Commit b2c10fe

Merge pull request #2168 from huggingface/more_vit_better_getter_redux
A few more features_intermediate() models, AttentionExtract helper, related minor cleanup.
2 parents f8979d4 + 1d3ab17 commit b2c10fe

34 files changed: +479 / -288 lines
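
The commit summary also introduces the AttentionExtract helper, which is not among the files excerpted below. A rough usage sketch for orientation only; the import path, constructor defaults, and dict-of-tensors return value shown here are assumptions about the helper, not text taken from this diff:

import torch
import timm
from timm.utils import AttentionExtract  # import path assumed

model = timm.create_model('vit_base_patch16_224', pretrained=False).eval()
extractor = AttentionExtract(model)  # assumed: wraps the model with fx- or hook-based capture of attention nodes
with torch.no_grad():
    attn_maps = extractor(torch.randn(1, 3, 224, 224))  # assumed: dict of node/module name -> attention tensor
for name, t in attn_maps.items():
    print(name, tuple(t.shape))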

tests/test_models.py

Lines changed: 2 additions & 1 deletion
@@ -51,7 +51,8 @@
 FEAT_INTER_FILTERS = [
     'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
     'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
-    'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet'
+    'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
+    'regnet', 'byobnet', 'byoanet', 'mlp_mixer'
 ]
 
 # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
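
The families appended to FEAT_INTER_FILTERS above ('regnet', 'byobnet', 'byoanet', 'mlp_mixer') are the ones this PR wires up for intermediate-feature extraction, so the test suite now exercises them through that API. A minimal sketch of the kind of call being covered, assuming the forward_intermediates() interface of timm 1.0-era releases and an arbitrary model/input size:

import torch
import timm

model = timm.create_model('regnety_002', pretrained=False).eval()  # regnet: one of the newly covered families
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    final, intermediates = model.forward_intermediates(x)  # assumed return order: (final features, list of intermediates)
print(tuple(final.shape))
for feat in intermediates:
    print(tuple(feat.shape))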

timm/layers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@
 from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
 from .selective_kernel import SelectiveKernel
 from .separable_conv import SeparableConv2d, SeparableConvNormAct
-from .space_to_depth import SpaceToDepthModule, SpaceToDepth, DepthToSpace
+from .space_to_depth import SpaceToDepth, DepthToSpace
 from .split_attn import SplitAttn
 from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
 from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame

timm/layers/activations_jit.py

Lines changed: 0 additions & 90 deletions
This file was deleted.

timm/layers/activations_me.py

Lines changed: 41 additions & 51 deletions
@@ -3,8 +3,8 @@
 A collection of activations fn and modules with a common interface so that they can
 easily be swapped. All have an `inplace` arg even if not used.
 
-These activations are not compatible with jit scripting or ONNX export of the model, please use either
-the JIT or basic versions of the activations.
+These activations are not compatible with jit scripting or ONNX export of the model, please use
+basic versions of the activations.
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
@@ -14,19 +14,17 @@
 from torch.nn import functional as F
 
 
-@torch.jit.script
-def swish_jit_fwd(x):
+def swish_fwd(x):
     return x.mul(torch.sigmoid(x))
 
 
-@torch.jit.script
-def swish_jit_bwd(x, grad_output):
+def swish_bwd(x, grad_output):
     x_sigmoid = torch.sigmoid(x)
     return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
 
 
-class SwishJitAutoFn(torch.autograd.Function):
-    """ torch.jit.script optimised Swish w/ memory-efficient checkpoint
+class SwishAutoFn(torch.autograd.Function):
+    """ optimised Swish w/ memory-efficient checkpoint
     Inspired by conversation btw Jeremy Howard & Adam Pazske
     https://twitter.com/jeremyphoward/status/1188251041835315200
     """
@@ -37,123 +35,117 @@ def symbolic(g, x):
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return swish_jit_fwd(x)
+        return swish_fwd(x)
 
     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return swish_jit_bwd(x, grad_output)
+        return swish_bwd(x, grad_output)
 
 
 def swish_me(x, inplace=False):
-    return SwishJitAutoFn.apply(x)
+    return SwishAutoFn.apply(x)
 
 
 class SwishMe(nn.Module):
     def __init__(self, inplace: bool = False):
         super(SwishMe, self).__init__()
 
     def forward(self, x):
-        return SwishJitAutoFn.apply(x)
+        return SwishAutoFn.apply(x)
 
 
-@torch.jit.script
-def mish_jit_fwd(x):
+def mish_fwd(x):
     return x.mul(torch.tanh(F.softplus(x)))
 
 
-@torch.jit.script
-def mish_jit_bwd(x, grad_output):
+def mish_bwd(x, grad_output):
     x_sigmoid = torch.sigmoid(x)
     x_tanh_sp = F.softplus(x).tanh()
     return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
 
 
-class MishJitAutoFn(torch.autograd.Function):
+class MishAutoFn(torch.autograd.Function):
     """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
-    A memory efficient, jit scripted variant of Mish
+    A memory efficient variant of Mish
     """
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return mish_jit_fwd(x)
+        return mish_fwd(x)
 
     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return mish_jit_bwd(x, grad_output)
+        return mish_bwd(x, grad_output)
 
 
 def mish_me(x, inplace=False):
-    return MishJitAutoFn.apply(x)
+    return MishAutoFn.apply(x)
 
 
 class MishMe(nn.Module):
     def __init__(self, inplace: bool = False):
         super(MishMe, self).__init__()
 
     def forward(self, x):
-        return MishJitAutoFn.apply(x)
+        return MishAutoFn.apply(x)
 
 
-@torch.jit.script
-def hard_sigmoid_jit_fwd(x, inplace: bool = False):
+def hard_sigmoid_fwd(x, inplace: bool = False):
     return (x + 3).clamp(min=0, max=6).div(6.)
 
 
-@torch.jit.script
-def hard_sigmoid_jit_bwd(x, grad_output):
+def hard_sigmoid_bwd(x, grad_output):
     m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6.
     return grad_output * m
 
 
-class HardSigmoidJitAutoFn(torch.autograd.Function):
+class HardSigmoidAutoFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return hard_sigmoid_jit_fwd(x)
+        return hard_sigmoid_fwd(x)
 
     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return hard_sigmoid_jit_bwd(x, grad_output)
+        return hard_sigmoid_bwd(x, grad_output)
 
 
 def hard_sigmoid_me(x, inplace: bool = False):
-    return HardSigmoidJitAutoFn.apply(x)
+    return HardSigmoidAutoFn.apply(x)
 
 
 class HardSigmoidMe(nn.Module):
     def __init__(self, inplace: bool = False):
         super(HardSigmoidMe, self).__init__()
 
     def forward(self, x):
-        return HardSigmoidJitAutoFn.apply(x)
+        return HardSigmoidAutoFn.apply(x)
 
 
-@torch.jit.script
-def hard_swish_jit_fwd(x):
+def hard_swish_fwd(x):
     return x * (x + 3).clamp(min=0, max=6).div(6.)
 
 
-@torch.jit.script
-def hard_swish_jit_bwd(x, grad_output):
+def hard_swish_bwd(x, grad_output):
     m = torch.ones_like(x) * (x >= 3.)
     m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)
     return grad_output * m
 
 
-class HardSwishJitAutoFn(torch.autograd.Function):
-    """A memory efficient, jit-scripted HardSwish activation"""
+class HardSwishAutoFn(torch.autograd.Function):
+    """A memory efficient HardSwish activation"""
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return hard_swish_jit_fwd(x)
+        return hard_swish_fwd(x)
 
     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return hard_swish_jit_bwd(x, grad_output)
+        return hard_swish_bwd(x, grad_output)
 
     @staticmethod
     def symbolic(g, self):
@@ -164,55 +156,53 @@ def symbolic(g, self):
 
 
 def hard_swish_me(x, inplace=False):
-    return HardSwishJitAutoFn.apply(x)
+    return HardSwishAutoFn.apply(x)
 
 
 class HardSwishMe(nn.Module):
     def __init__(self, inplace: bool = False):
         super(HardSwishMe, self).__init__()
 
     def forward(self, x):
-        return HardSwishJitAutoFn.apply(x)
+        return HardSwishAutoFn.apply(x)
 
 
-@torch.jit.script
-def hard_mish_jit_fwd(x):
+def hard_mish_fwd(x):
     return 0.5 * x * (x + 2).clamp(min=0, max=2)
 
 
-@torch.jit.script
-def hard_mish_jit_bwd(x, grad_output):
+def hard_mish_bwd(x, grad_output):
     m = torch.ones_like(x) * (x >= -2.)
     m = torch.where((x >= -2.) & (x <= 0.), x + 1., m)
     return grad_output * m
 
 
-class HardMishJitAutoFn(torch.autograd.Function):
-    """ A memory efficient, jit scripted variant of Hard Mish
+class HardMishAutoFn(torch.autograd.Function):
+    """ A memory efficient variant of Hard Mish
     Experimental, based on notes by Mish author Diganta Misra at
     https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
     """
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return hard_mish_jit_fwd(x)
+        return hard_mish_fwd(x)
 
     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return hard_mish_jit_bwd(x, grad_output)
+        return hard_mish_bwd(x, grad_output)
 
 
 def hard_mish_me(x, inplace: bool = False):
-    return HardMishJitAutoFn.apply(x)
+    return HardMishAutoFn.apply(x)
 
 
 class HardMishMe(nn.Module):
     def __init__(self, inplace: bool = False):
         super(HardMishMe, self).__init__()
 
     def forward(self, x):
-        return HardMishJitAutoFn.apply(x)
+        return HardMishAutoFn.apply(x)
 
 
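The *_me activations keep their memory-efficient autograd.Function pattern after the rename: forward() saves only the input tensor and backward() recomputes the cheap nonlinearity instead of storing intermediates. A small self-contained sanity check of that pattern against PyTorch's built-in equivalents, using the public functions from the file above (the shapes and use of double precision are arbitrary choices):

import torch
import torch.nn.functional as F
from timm.layers.activations_me import swish_me, hard_swish_me

x = torch.randn(8, 16, dtype=torch.double, requires_grad=True)
g = torch.randn(8, 16, dtype=torch.double)

# swish_me computes x * sigmoid(x), i.e. SiLU; its hand-written backward applies
# sigmoid(x) * (1 + x * (1 - sigmoid(x))), the analytic derivative of x * sigmoid(x).
y_me, y_ref = swish_me(x), F.silu(x)
print(torch.allclose(y_me, y_ref))
print(torch.allclose(torch.autograd.grad(y_me, x, g)[0],
                     torch.autograd.grad(y_ref, x, g)[0]))

# hard_swish_me computes x * clamp(x + 3, 0, 6) / 6; its backward is the piecewise slope
# 0 for x < -3, x / 3 + 0.5 on [-3, 3], and 1 for x > 3, matching hard_swish_bwd above.
y_me, y_ref = hard_swish_me(x), F.hardswish(x)
print(torch.allclose(y_me, y_ref))
print(torch.allclose(torch.autograd.grad(y_me, x, g)[0],
                     torch.autograd.grad(y_ref, x, g)[0]))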