2 files changed: +13 -6 lines changed

@@ -33,6 +33,8 @@ class BaseModelTester(unittest.TestCase):
     default_height = 64
     default_width = 64
 
+    compile_dynamic = True
+
     @property
     def model_type(self):
         if self.test_model_type is None:
@@ -232,16 +234,16 @@ def test_compile(self):
         model = model.eval().to(default_device)
 
         if not model._is_torch_compilable:
-            with self.assertRaises(RuntimeError):
+            with self.assertRaises((RuntimeError)):
                 torch.compiler.reset()
                 compiled_model = torch.compile(
-                    model, fullgraph=True, dynamic=True, backend="eager"
+                    model, fullgraph=True, dynamic=self.compile_dynamic, backend="eager"
                 )
             return
 
         torch.compiler.reset()
         compiled_model = torch.compile(
-            model, fullgraph=True, dynamic=True, backend="eager"
+            model, fullgraph=True, dynamic=self.compile_dynamic, backend="eager"
         )
         with torch.inference_mode():
             compiled_model(sample)
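For context, the new compile_dynamic class attribute lets an individual model tester opt out of dynamic-shape compilation without duplicating the shared test. A minimal sketch of the override pattern, where the _compile helper and the TestStaticShapeModel subclass are hypothetical stand-ins for the real test_compile flow:

import unittest

import torch


class BaseModelTester(unittest.TestCase):
    # Dynamic shapes remain the default for the shared torch.compile test.
    compile_dynamic = True

    def _compile(self, model):
        # Condensed form of what test_compile now does: the dynamic flag is
        # read from the tester class instead of being hard-coded to True.
        torch.compiler.reset()
        return torch.compile(
            model, fullgraph=True, dynamic=self.compile_dynamic, backend="eager"
        )


class TestStaticShapeModel(BaseModelTester):
    # Hypothetical tester for a model that only traces with static shapes;
    # overriding the class attribute is all that is needed.
    compile_dynamic = False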
@@ -20,15 +20,20 @@ class TestDPTModel(base.BaseModelTester):
     # should be overriden
     test_model_type = "dpt"
 
+    compile_dynamic = False
+
     @property
     def hub_checkpoint(self):
-        return "smp-hub/dpt-large-ade20k"
+        return "smp-test-models/dpt-tu-test_vit"
 
     @slow_test
     @requires_torch_greater_or_equal("2.0.1")
     @pytest.mark.logits_match
-    def test_preserve_forward_output(self):
-        model = smp.from_pretrained(self.hub_checkpoint).eval().to(default_device)
+    def test_load_pretrained(self):
+        hub_checkpoint = "smp-hub/dpt-large-ade20k"
+
+        model = smp.from_pretrained(hub_checkpoint)
+        model = model.eval().to(default_device)
 
         input_tensor = torch.ones((1, 3, 384, 384))
         input_tensor = input_tensor.to(default_device)
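A rough sketch of what the renamed test_load_pretrained exercises, assuming default_device resolves to a plain device string and that the test finishes with a forward pass; the full-size hub checkpoint is loaded explicitly here, since the class-level hub_checkpoint property now points at a small test model:

import torch
import segmentation_models_pytorch as smp

# Assumed stand-in for the test suite's default_device helper.
default_device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the full-size checkpoint directly, as the renamed test does.
model = smp.from_pretrained("smp-hub/dpt-large-ade20k")
model = model.eval().to(default_device)

input_tensor = torch.ones((1, 3, 384, 384), device=default_device)
with torch.inference_mode():
    output = model(input_tensor)

# Expected shape: (batch, num_classes, height, width) from the segmentation head.
print(output.shape)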