@@ -5,7 +5,8 @@
 import fnmatch
 
 import timm
-from timm import list_models, create_model, set_scriptable
+from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \
+    get_model_default_value
 
 if hasattr(torch._C, '_jit_set_profiling_executor'):
     # legacy executor is too slow to compile large models for unit tests
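
The three new imports are timm registry helpers that expose per-model default_cfg metadata; their exact semantics live in timm's model registry, not in this diff. A short sketch, outside the diff, of how the tests below query them, assuming has_model_default_key checks for the key's presence, is_model_default_key checks that its value is truthy, and get_model_default_value returns the value or None ('vit_base_patch16_224' is only an illustrative model name):

from timm import has_model_default_key, is_model_default_key, get_model_default_value

name = 'vit_base_patch16_224'  # example only; any registered model name works
print(has_model_default_key(name, 'fixed_input_size'))  # key present in the model's default_cfg?
print(is_model_default_key(name, 'fixed_input_size'))   # ...and set to a truthy value?
print(get_model_default_value(name, 'min_input_size'))  # configured value, or None if unset
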
@@ -60,9 +61,15 @@ def test_model_backward(model_name, batch_size):
     model.eval()
 
     input_size = model.default_cfg['input_size']
-    if any([x > MAX_BWD_SIZE for x in input_size]):
-        # cap backward test at 128 * 128 to keep resource usage down
-        input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size])
+    if not is_model_default_key(model_name, 'fixed_input_size'):
+        min_input_size = get_model_default_value(model_name, 'min_input_size')
+        if min_input_size is not None:
+            input_size = min_input_size
+        else:
+            if any([x > MAX_BWD_SIZE for x in input_size]):
+                # cap backward test at 128 * 128 to keep resource usage down
+                input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size])
+
     inputs = torch.randn((batch_size, *input_size))
     outputs = model(inputs)
     outputs.mean().backward()
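
This hunk changes how the backward test picks its input resolution: models that declare fixed_input_size keep their native default_cfg size, models with a declared min_input_size are tested at that minimum, and everything else is still capped at MAX_BWD_SIZE. The same logic pulled out as a standalone sketch, outside the diff (MAX_BWD_SIZE is defined elsewhere in this test module; 128 is assumed here, and the helper name is hypothetical):

from timm import is_model_default_key, get_model_default_value

def _backward_test_input_size(model_name, default_input_size, max_bwd_size=128):
    # Models flagged fixed_input_size must run at their native resolution.
    if is_model_default_key(model_name, 'fixed_input_size'):
        return default_input_size
    # Prefer a declared minimum input size when the model has one...
    min_input_size = get_model_default_value(model_name, 'min_input_size')
    if min_input_size is not None:
        return min_input_size
    # ...otherwise cap each dimension to keep resource usage down.
    return tuple(min(x, max_bwd_size) for x in default_input_size)
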
@@ -155,7 +162,14 @@ def test_model_forward_torchscript(model_name, batch_size):
     with set_scriptable(True):
         model = create_model(model_name, pretrained=False)
     model.eval()
-    input_size = (3, 128, 128)  # jit compile is already a bit slow and we've tested normal res already...
+
+    if has_model_default_key(model_name, 'fixed_input_size'):
+        input_size = get_model_default_value(model_name, 'input_size')
+    elif has_model_default_key(model_name, 'min_input_size'):
+        input_size = get_model_default_value(model_name, 'min_input_size')
+    else:
+        input_size = (3, 128, 128)  # jit compile is already a bit slow and we've tested normal res already...
+
     model = torch.jit.script(model)
     outputs = model(torch.randn((batch_size, *input_size)))
@@ -180,7 +194,14 @@ def test_model_forward_features(model_name, batch_size):
     model.eval()
     expected_channels = model.feature_info.channels()
     assert len(expected_channels) >= 4  # all models here should have at least 4 feature levels by default, some 5 or 6
-    input_size = (3, 96, 96)  # jit compile is already a bit slow and we've tested normal res already...
+
+    if has_model_default_key(model_name, 'fixed_input_size'):
+        input_size = get_model_default_value(model_name, 'input_size')
+    elif has_model_default_key(model_name, 'min_input_size'):
+        input_size = get_model_default_value(model_name, 'min_input_size')
+    else:
+        input_size = (3, 96, 96)  # jit compile is already a bit slow and we've tested normal res already...
+
     outputs = model(torch.randn((batch_size, *input_size)))
     assert len(expected_channels) == len(outputs)
     for e, o in zip(expected_channels, outputs):
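
The torchscript and feature-extraction hunks above rely on at least some registered models actually declaring these keys. A quick, hedged check of which models would take the fixed_input_size or min_input_size branches, using the same timm imports as this diff:

from timm import list_models, has_model_default_key

fixed = [m for m in list_models() if has_model_default_key(m, 'fixed_input_size')]
with_min = [m for m in list_models() if has_model_default_key(m, 'min_input_size')]
print(f'{len(fixed)} models declare fixed_input_size, {len(with_min)} declare min_input_size')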