Skip to content

Commit 16f7aa9

Browse files
committed
Add default_cfg options for min_input_size / fixed_input_size, queries in model registry, and use for testing self-attn models
1 parent 4e4b863 commit 16f7aa9

File tree

5 files changed

+72
-16
lines changed

5 files changed

+72
-16
lines changed

tests/test_models.py

Lines changed: 27 additions & 6 deletions
Original file line number · Diff line number · Diff line change
@@ -5,7 +5,8 @@
55
import fnmatch
66

77
import timm
8-
from timm import list_models, create_model, set_scriptable
8+
from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \
9+
get_model_default_value
910

1011
if hasattr(torch._C, '_jit_set_profiling_executor'):
1112
# legacy executor is too slow to compile large models for unit tests
@@ -60,9 +61,15 @@ def test_model_backward(model_name, batch_size):
6061
model.eval()
6162

6263
input_size = model.default_cfg['input_size']
63-
if any([x > MAX_BWD_SIZE for x in input_size]):
64-
# cap backward test at 128 * 128 to keep resource usage down
65-
input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size])
64+
if not is_model_default_key(model_name, 'fixed_input_size'):
65+
min_input_size = get_model_default_value(model_name, 'min_input_size')
66+
if min_input_size is not None:
67+
input_size = min_input_size
68+
else:
69+
if any([x > MAX_BWD_SIZE for x in input_size]):
70+
# cap backward test at 128 * 128 to keep resource usage down
71+
input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size])
72+
6673
inputs = torch.randn((batch_size, *input_size))
6774
outputs = model(inputs)
6875
outputs.mean().backward()
@@ -155,7 +162,14 @@ def test_model_forward_torchscript(model_name, batch_size):
155162
with set_scriptable(True):
156163
model = create_model(model_name, pretrained=False)
157164
model.eval()
158-
input_size = (3, 128, 128) # jit compile is already a bit slow and we've tested normal res already...
165+
166+
if has_model_default_key(model_name, 'fixed_input_size'):
167+
input_size = get_model_default_value(model_name, 'input_size')
168+
elif has_model_default_key(model_name, 'min_input_size'):
169+
input_size = get_model_default_value(model_name, 'min_input_size')
170+
else:
171+
input_size = (3, 128, 128) # jit compile is already a bit slow and we've tested normal res already...
172+
159173
model = torch.jit.script(model)
160174
outputs = model(torch.randn((batch_size, *input_size)))
161175

@@ -180,7 +194,14 @@ def test_model_forward_features(model_name, batch_size):
180194
model.eval()
181195
expected_channels = model.feature_info.channels()
182196
assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6
183-
input_size = (3, 96, 96) # jit compile is already a bit slow and we've tested normal res already...
197+
198+
if has_model_default_key(model_name, 'fixed_input_size'):
199+
input_size = get_model_default_value(model_name, 'input_size')
200+
elif has_model_default_key(model_name, 'min_input_size'):
201+
input_size = get_model_default_value(model_name, 'min_input_size')
202+
else:
203+
input_size = (3, 96, 96) # jit compile is already a bit slow and we've tested normal res already...
204+
184205
outputs = model(torch.randn((batch_size, *input_size)))
185206
assert len(expected_channels) == len(outputs)
186207
for e, o in zip(expected_channels, outputs):

timm/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -1,3 +1,4 @@
11
from .version import __version__
22
from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
3-
is_scriptable, is_exportable, set_scriptable, set_exportable
3+
is_scriptable, is_exportable, set_scriptable, set_exportable, has_model_default_key, is_model_default_key, \
4+
get_model_default_value, is_model_pretrained

timm/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -40,4 +40,5 @@
4040
from .layers import TestTimePoolHead, apply_test_time_pool
4141
from .layers import convert_splitbn_model
4242
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
43-
from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules
43+
from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
44+
has_model_default_key, is_model_default_key, get_model_default_value, is_model_pretrained

timm/models/byoanet.py

Lines changed: 8 additions & 7 deletions
Original file line number · Diff line number · Diff line change
@@ -37,23 +37,24 @@ def _cfg(url='', **kwargs):
3737
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
3838
'crop_pct': 0.875, 'interpolation': 'bilinear',
3939
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
40-
'first_conv': 'stem.conv', 'classifier': 'head.fc',
40+
'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
41+
'fixed_input_size': False, 'min_input_size': (3, 224, 224),
4142
**kwargs
4243
}
4344

4445

4546
default_cfgs = {
4647
# GPU-Efficient (ResNet) weights
47-
'botnet50t_224': _cfg(url=''),
48-
'botnet50t_c4c5_224': _cfg(url=''),
48+
'botnet50t_224': _cfg(url='', fixed_input_size=True),
49+
'botnet50t_c4c5_224': _cfg(url='', fixed_input_size=True),
4950

50-
'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
51-
'halonet_h1_c4c5': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
51+
'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
52+
'halonet_h1_c4c5': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
5253
'halonet26t': _cfg(url=''),
5354
'halonet50t': _cfg(url=''),
5455

55-
'lambda_resnet26t': _cfg(url=''),
56-
'lambda_resnet50t': _cfg(url=''),
56+
'lambda_resnet26t': _cfg(url='', min_input_size=(3, 128, 128)),
57+
'lambda_resnet50t': _cfg(url='', min_input_size=(3, 128, 128)),
5758
}
5859

5960

timm/models/registry.py

Lines changed: 33 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -6,13 +6,16 @@
66
import re
77
import fnmatch
88
from collections import defaultdict
9+
from copy import deepcopy
910

10-
__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules']
11+
__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
12+
'is_model_default_key', 'has_model_default_key', 'get_model_default_value', 'is_model_pretrained']
1113

1214
_module_to_models = defaultdict(set) # dict of sets to check membership of model in module
1315
_model_to_module = {} # mapping of model names to module names
1416
_model_entrypoints = {} # mapping of model names to entrypoint fns
1517
_model_has_pretrained = set() # set of model names that have pretrained weight url present
18+
_model_default_cfgs = dict() # central repo for model default_cfgs
1619

1720

1821
def register_model(fn):
@@ -37,6 +40,7 @@ def register_model(fn):
3740
# this will catch all models that have entrypoint matching cfg key, but miss any aliasing
3841
# entrypoints or non-matching combos
3942
has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']
43+
_model_default_cfgs[model_name] = deepcopy(mod.default_cfgs[model_name])
4044
if has_pretrained:
4145
_model_has_pretrained.add(model_name)
4246
return fn
@@ -105,3 +109,31 @@ def is_model_in_modules(model_name, module_names):
105109
assert isinstance(module_names, (tuple, list, set))
106110
return any(model_name in _module_to_models[n] for n in module_names)
107111

112+
113+
def has_model_default_key(model_name, cfg_key):
    """ Query model default_cfgs for existence of a specific key.

    Args:
        model_name: name of a registered model (key into the central
            ``_model_default_cfgs`` registry populated by ``register_model``).
        cfg_key: default_cfg key to look for (e.g. 'min_input_size').

    Returns:
        True if the model is registered with a default_cfg and that cfg
        contains ``cfg_key`` (regardless of the key's value), else False.
    """
    # A single boolean expression replaces the redundant
    # `if cond: return True / return False` pattern.
    return model_name in _model_default_cfgs and cfg_key in _model_default_cfgs[model_name]
119+
120+
121+
def is_model_default_key(model_name, cfg_key):
    """ Return truthy value for specified model default_cfg key, False if does not exist.

    Unlike ``has_model_default_key`` this checks the *value*: a key that is
    present but falsy (False, 0, None, empty) still yields False.
    """
    # Missing model -> short-circuits to False; otherwise coerce the
    # stored value's truthiness to a strict bool.
    cfg = _model_default_cfgs.get(model_name)
    if cfg is None:
        return False
    return bool(cfg.get(cfg_key, False))
127+
128+
129+
def get_model_default_value(model_name, cfg_key):
    """ Get a specific model default_cfg value by key. None if it doesn't exist.

    Returns None both when the model is not registered and when the model's
    default_cfg lacks ``cfg_key``.
    """
    cfg = _model_default_cfgs.get(model_name)
    # Fall through to None for unregistered models and missing keys alike.
    return None if cfg is None else cfg.get(cfg_key, None)
136+
137+
138+
def is_model_pretrained(model_name):
    """ Check if a registered model has pretrained weights available.

    Membership in ``_model_has_pretrained`` is established at registration
    time when the model's default_cfg has an http 'url' entry.
    """
    return model_name in _model_has_pretrained

0 commit comments

Comments
 (0)