Skip to content

Commit aa5b088

Browse files
committed
Disable export tests for dpn and inceptionv4
1 parent a806147 commit aa5b088

File tree

3 files changed

+41
-42
lines changed

3 files changed

+41
-42
lines changed

tests/encoders/base.py

+29-38
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
import segmentation_models_pytorch as smp
55

66
from functools import lru_cache
7-
from tests.utils import default_device, check_run_test_on_diff_or_main
7+
from tests.utils import (
8+
default_device,
9+
check_run_test_on_diff_or_main,
10+
requires_torch_greater_or_equal,
11+
)
812

913

1014
class BaseEncoderTester(unittest.TestCase):
@@ -29,11 +33,19 @@ class BaseEncoderTester(unittest.TestCase):
2933
depth_to_test = [3, 4, 5]
3034
strides_to_test = [8, 16] # 32 is a default one
3135

36+
# enable/disable tests
37+
do_test_torch_compile = True
38+
do_test_torch_export = True
39+
3240
def get_tiny_encoder(self):
3341
return smp.encoders.get_encoder(self.encoder_names[0], encoder_weights=None)
3442

3543
@lru_cache
36-
def _get_sample(self, batch_size=1, num_channels=3, height=32, width=32):
44+
def _get_sample(self, batch_size=None, num_channels=None, height=None, width=None):
45+
batch_size = batch_size or self.default_batch_size
46+
num_channels = num_channels or self.default_num_channels
47+
height = height or self.default_height
48+
width = width or self.default_width
3749
return torch.rand(batch_size, num_channels, height, width)
3850

3951
def get_features_output_strides(self, sample, features):
@@ -43,12 +55,7 @@ def get_features_output_strides(self, sample, features):
4355
return height_strides, width_strides
4456

4557
def test_forward_backward(self):
46-
sample = self._get_sample(
47-
batch_size=self.default_batch_size,
48-
num_channels=self.default_num_channels,
49-
height=self.default_height,
50-
width=self.default_width,
51-
).to(default_device)
58+
sample = self._get_sample().to(default_device)
5259
for encoder_name in self.encoder_names:
5360
with self.subTest(encoder_name=encoder_name):
5461
# init encoder
@@ -75,12 +82,7 @@ def test_in_channels(self):
7582
]
7683

7784
for encoder_name, in_channels in cases:
78-
sample = self._get_sample(
79-
batch_size=self.default_batch_size,
80-
num_channels=in_channels,
81-
height=self.default_height,
82-
width=self.default_width,
83-
).to(default_device)
85+
sample = self._get_sample(num_channels=in_channels).to(default_device)
8486

8587
with self.subTest(encoder_name=encoder_name, in_channels=in_channels):
8688
encoder = smp.encoders.get_encoder(
@@ -93,12 +95,7 @@ def test_in_channels(self):
9395
encoder.forward(sample)
9496

9597
def test_depth(self):
96-
sample = self._get_sample(
97-
batch_size=self.default_batch_size,
98-
num_channels=self.default_num_channels,
99-
height=self.default_height,
100-
width=self.default_width,
101-
).to(default_device)
98+
sample = self._get_sample().to(default_device)
10299

103100
cases = [
104101
(encoder_name, depth)
@@ -157,12 +154,7 @@ def test_depth(self):
157154
)
158155

159156
def test_dilated(self):
160-
sample = self._get_sample(
161-
batch_size=self.default_batch_size,
162-
num_channels=self.default_num_channels,
163-
height=self.default_height,
164-
width=self.default_width,
165-
).to(default_device)
157+
sample = self._get_sample().to(default_device)
166158

167159
cases = [
168160
(encoder_name, stride)
@@ -216,15 +208,15 @@ def test_dilated(self):
216208

217209
@pytest.mark.compile
218210
def test_compile(self):
211+
if not self.do_test_torch_compile:
212+
self.skipTest(
213+
f"torch_compile test is disabled for {self.encoder_names[0]}."
214+
)
215+
219216
if not check_run_test_on_diff_or_main(self.files_for_diff):
220217
self.skipTest("No diff and not on `main`.")
221218

222-
sample = self._get_sample(
223-
batch_size=self.default_batch_size,
224-
num_channels=self.default_num_channels,
225-
height=self.default_height,
226-
width=self.default_width,
227-
).to(default_device)
219+
sample = self._get_sample().to(default_device)
228220

229221
encoder = self.get_tiny_encoder().eval().to(default_device)
230222
compiled_encoder = torch.compile(encoder, fullgraph=True, dynamic=True)
@@ -233,16 +225,15 @@ def test_compile(self):
233225
compiled_encoder(sample)
234226

235227
@pytest.mark.torch_export
228+
@requires_torch_greater_or_equal("2.4.0")
236229
def test_torch_export(self):
230+
if not self.do_test_torch_export:
231+
self.skipTest(f"torch_export test is disabled for {self.encoder_names[0]}.")
232+
237233
if not check_run_test_on_diff_or_main(self.files_for_diff):
238234
self.skipTest("No diff and not on `main`.")
239235

240-
sample = self._get_sample(
241-
batch_size=self.default_batch_size,
242-
num_channels=self.default_num_channels,
243-
height=self.default_height,
244-
width=self.default_width,
245-
).to(default_device)
236+
sample = self._get_sample().to(default_device)
246237

247238
encoder = self.get_tiny_encoder()
248239
encoder = encoder.eval().to(default_device)

tests/encoders/test_pretrainedmodels_encoders.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ class TestDPNEncoder(base.BaseEncoderTester):
1212
)
1313
files_for_diff = ["encoders/dpn.py"]
1414

15+
# works with torch 2.4.0, but not with torch 2.5.1
16+
# dynamo error, probably on Sequential + OrderedDict
17+
do_test_torch_export = False
18+
1519
def get_tiny_encoder(self):
1620
params = {
1721
"stage_idxs": (2, 3, 4, 5),
@@ -29,17 +33,21 @@ def get_tiny_encoder(self):
2933

3034

3135
class TestInceptionResNetV2Encoder(base.BaseEncoderTester):
32-
supports_dilated = False
3336
encoder_names = (
3437
["inceptionresnetv2"] if not RUN_ALL_ENCODERS else ["inceptionresnetv2"]
3538
)
3639
files_for_diff = ["encoders/inceptionresnetv2.py"]
40+
supports_dilated = False
3741

3842

3943
class TestInceptionV4Encoder(base.BaseEncoderTester):
40-
supports_dilated = False
4144
encoder_names = ["inceptionv4"] if not RUN_ALL_ENCODERS else ["inceptionv4"]
4245
files_for_diff = ["encoders/inceptionv4.py"]
46+
supports_dilated = False
47+
48+
# works with torch 2.4.0, but not with torch 2.5.1
49+
# dynamo error, probably on Sequential + OrderedDict
50+
do_test_torch_export = False
4351

4452

4553
class TestSeNetEncoder(base.BaseEncoderTester):

tests/encoders/test_smp_encoders.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -63,5 +63,5 @@ class TestEfficientNetEncoder(base.BaseEncoderTester):
6363
)
6464
files_for_diff = ["encoders/efficientnet.py"]
6565

66-
def test_compile(self):
67-
self.skipTest("compile fullgraph is not supported for efficientnet encoders")
66+
# torch_compile is not supported for efficientnet encoders
67+
do_test_torch_compile = False

0 commit comments

Comments (0)