Skip to content

Commit 38cb944

Browse files
committed
Tests
1 parent 9518964 commit 38cb944

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

tests/models/base.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ class BaseModelTester(unittest.TestCase):
3333
default_height = 64
3434
default_width = 64
3535

36+
compile_dynamic = True
37+
3638
@property
3739
def model_type(self):
3840
if self.test_model_type is None:
@@ -232,16 +234,16 @@ def test_compile(self):
232234
model = model.eval().to(default_device)
233235

234236
if not model._is_torch_compilable:
235-
with self.assertRaises(RuntimeError):
237+
with self.assertRaises((RuntimeError)):
236238
torch.compiler.reset()
237239
compiled_model = torch.compile(
238-
model, fullgraph=True, dynamic=True, backend="eager"
240+
model, fullgraph=True, dynamic=self.compile_dynamic, backend="eager"
239241
)
240242
return
241243

242244
torch.compiler.reset()
243245
compiled_model = torch.compile(
244-
model, fullgraph=True, dynamic=True, backend="eager"
246+
model, fullgraph=True, dynamic=self.compile_dynamic, backend="eager"
245247
)
246248
with torch.inference_mode():
247249
compiled_model(sample)

tests/models/test_dpt.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,20 @@ class TestDPTModel(base.BaseModelTester):
2020
# should be overriden
2121
test_model_type = "dpt"
2222

23+
compile_dynamic = False
24+
2325
@property
2426
def hub_checkpoint(self):
25-
return "smp-hub/dpt-large-ade20k"
27+
return "smp-test-models/dpt-tu-test_vit"
2628

2729
@slow_test
2830
@requires_torch_greater_or_equal("2.0.1")
2931
@pytest.mark.logits_match
30-
def test_preserve_forward_output(self):
31-
model = smp.from_pretrained(self.hub_checkpoint).eval().to(default_device)
32+
def test_load_pretrained(self):
33+
hub_checkpoint = "smp-hub/dpt-large-ade20k"
34+
35+
model = smp.from_pretrained(hub_checkpoint)
36+
model = model.eval().to(default_device)
3237

3338
input_tensor = torch.ones((1, 3, 384, 384))
3439
input_tensor = input_tensor.to(default_device)

0 commit comments

Comments
 (0)