Skip to content

Commit a806147

Browse files
committed
Add torch.export test
1 parent ff278c9 commit a806147

File tree

6 files changed

+91
-1
lines changed

6 files changed

+91
-1
lines changed

.github/workflows/tests.yml

+15
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,21 @@ jobs:
9090
- name: Test with PyTest
9191
run: uv run pytest -v -rsx -n 2 -m "compile"
9292

93+
test_torch_export:
94+
runs-on: ubuntu-latest
95+
steps:
96+
- uses: actions/checkout@v4
97+
- name: Set up Python
98+
uses: astral-sh/setup-uv@v5
99+
with:
100+
python-version: "3.10"
101+
- name: Install dependencies
102+
run: uv pip install -r requirements/required.txt -r requirements/test.txt
103+
- name: Show installed packages
104+
run: uv pip list
105+
- name: Test with PyTest
106+
run: uv run pytest -v -rsx -n 2 -m "torch_export"
107+
93108
minimum:
94109
runs-on: ubuntu-latest
95110
steps:

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ include = ['segmentation_models_pytorch*']
6565
markers = [
6666
"logits_match",
6767
"compile",
68+
"torch_export",
6869
]
6970

7071
[tool.coverage.run]

segmentation_models_pytorch/base/model.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from . import initialization as init
55
from .hub_mixin import SMPHubMixin
6+
from .utils import is_torch_compiling
67

78
T = TypeVar("T", bound="SegmentationModel")
89

@@ -50,7 +51,11 @@ def check_input_shape(self, x):
5051
def forward(self, x):
5152
"""Sequentially pass `x` trough model`s encoder, decoder and heads"""
5253

53-
if not torch.jit.is_tracing() and self.requires_divisible_input_shape:
54+
if (
55+
not torch.jit.is_tracing()
56+
and not is_torch_compiling()
57+
and self.requires_divisible_input_shape
58+
):
5459
self.check_input_shape(x)
5560

5661
features = self.encoder(x)
+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import torch
2+
3+
4+
def is_torch_compiling():
5+
try:
6+
return torch.compiler.is_compiling()
7+
except Exception:
8+
try:
9+
import torch._dynamo as dynamo # noqa: F401
10+
11+
return dynamo.is_compiling()
12+
except Exception:
13+
return False

tests/encoders/base.py

+28
Original file line numberDiff line numberDiff line change
@@ -231,3 +231,31 @@ def test_compile(self):
231231

232232
with torch.inference_mode():
233233
compiled_encoder(sample)
234+
235+
@pytest.mark.torch_export
236+
def test_torch_export(self):
237+
if not check_run_test_on_diff_or_main(self.files_for_diff):
238+
self.skipTest("No diff and not on `main`.")
239+
240+
sample = self._get_sample(
241+
batch_size=self.default_batch_size,
242+
num_channels=self.default_num_channels,
243+
height=self.default_height,
244+
width=self.default_width,
245+
).to(default_device)
246+
247+
encoder = self.get_tiny_encoder()
248+
encoder = encoder.eval().to(default_device)
249+
250+
exported_encoder = torch.export.export(
251+
encoder,
252+
args=(sample,),
253+
strict=True,
254+
)
255+
256+
with torch.inference_mode():
257+
eager_output = encoder(sample)
258+
exported_output = exported_encoder.module().forward(sample)
259+
260+
for eager_feature, exported_feature in zip(eager_output, exported_output):
261+
torch.testing.assert_close(eager_feature, exported_feature)

tests/models/base.py

+28
Original file line numberDiff line numberDiff line change
@@ -254,3 +254,31 @@ def test_compile(self):
254254

255255
with torch.inference_mode():
256256
compiled_model(sample)
257+
258+
@pytest.mark.torch_export
259+
def test_torch_export(self):
260+
if not check_run_test_on_diff_or_main(self.files_for_diff):
261+
self.skipTest("No diff and not on `main`.")
262+
263+
sample = self._get_sample(
264+
batch_size=self.default_batch_size,
265+
num_channels=self.default_num_channels,
266+
height=self.default_height,
267+
width=self.default_width,
268+
).to(default_device)
269+
270+
model = self.get_default_model()
271+
model.eval()
272+
273+
exported_model = torch.export.export(
274+
model,
275+
args=(sample,),
276+
strict=True,
277+
)
278+
279+
with torch.inference_mode():
280+
eager_output = model(sample)
281+
exported_output = exported_model.module().forward(sample)
282+
283+
self.assertEqual(eager_output.shape, exported_output.shape)
284+
torch.testing.assert_close(eager_output, exported_output)

0 commit comments

Comments
 (0)