Skip to content

Commit ce6585f

Browse files
authored
Merge pull request #556 from rwightman/byoanet-self_attn
ByoaNet - Self Attn Networks - Bottleneck Transformers, Lambda ResNet, HaloNet
2 parents cd3dc49 + b3d7580 commit ce6585f

File tree

15 files changed

+1149
-167
lines changed

15 files changed

+1149
-167
lines changed

tests/test_models.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import fnmatch
66

77
import timm
8-
from timm import list_models, create_model, set_scriptable
8+
from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \
9+
get_model_default_value
910

1011
if hasattr(torch._C, '_jit_set_profiling_executor'):
1112
# legacy executor is too slow to compile large models for unit tests
@@ -60,9 +61,15 @@ def test_model_backward(model_name, batch_size):
6061
model.eval()
6162

6263
input_size = model.default_cfg['input_size']
63-
if any([x > MAX_BWD_SIZE for x in input_size]):
64-
# cap backward test at 128 * 128 to keep resource usage down
65-
input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size])
64+
if not is_model_default_key(model_name, 'fixed_input_size'):
65+
min_input_size = get_model_default_value(model_name, 'min_input_size')
66+
if min_input_size is not None:
67+
input_size = min_input_size
68+
else:
69+
if any([x > MAX_BWD_SIZE for x in input_size]):
70+
# cap backward test at 128 * 128 to keep resource usage down
71+
input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size])
72+
6673
inputs = torch.randn((batch_size, *input_size))
6774
outputs = model(inputs)
6875
outputs.mean().backward()
@@ -155,7 +162,14 @@ def test_model_forward_torchscript(model_name, batch_size):
155162
with set_scriptable(True):
156163
model = create_model(model_name, pretrained=False)
157164
model.eval()
158-
input_size = (3, 128, 128) # jit compile is already a bit slow and we've tested normal res already...
165+
166+
if has_model_default_key(model_name, 'fixed_input_size'):
167+
input_size = get_model_default_value(model_name, 'input_size')
168+
elif has_model_default_key(model_name, 'min_input_size'):
169+
input_size = get_model_default_value(model_name, 'min_input_size')
170+
else:
171+
input_size = (3, 128, 128) # jit compile is already a bit slow and we've tested normal res already...
172+
159173
model = torch.jit.script(model)
160174
outputs = model(torch.randn((batch_size, *input_size)))
161175

@@ -180,7 +194,14 @@ def test_model_forward_features(model_name, batch_size):
180194
model.eval()
181195
expected_channels = model.feature_info.channels()
182196
assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6
183-
input_size = (3, 96, 96) # jit compile is already a bit slow and we've tested normal res already...
197+
198+
if has_model_default_key(model_name, 'fixed_input_size'):
199+
input_size = get_model_default_value(model_name, 'input_size')
200+
elif has_model_default_key(model_name, 'min_input_size'):
201+
input_size = get_model_default_value(model_name, 'min_input_size')
202+
else:
203+
input_size = (3, 96, 96) # jit compile is already a bit slow and we've tested normal res already...
204+
184205
outputs = model(torch.randn((batch_size, *input_size)))
185206
assert len(expected_channels) == len(outputs)
186207
for e, o in zip(expected_channels, outputs):

timm/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .version import __version__
22
from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
3-
is_scriptable, is_exportable, set_scriptable, set_exportable
3+
is_scriptable, is_exportable, set_scriptable, set_exportable, has_model_default_key, is_model_default_key, \
4+
get_model_default_value, is_model_pretrained

timm/models/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .byoanet import *
12
from .byobnet import *
23
from .cspnet import *
34
from .densenet import *
@@ -39,5 +40,5 @@
3940
from .layers import TestTimePoolHead, apply_test_time_pool
4041
from .layers import convert_splitbn_model
4142
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
42-
from .registry import *
43-
43+
from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
44+
has_model_default_key, is_model_default_key, get_model_default_value, is_model_pretrained

0 commit comments

Comments
 (0)