File tree 6 files changed +91
-1
lines changed
segmentation_models_pytorch/base
6 files changed +91
-1
lines changed Original file line number Diff line number Diff line change 90
90
- name : Test with PyTest
91
91
run : uv run pytest -v -rsx -n 2 -m "compile"
92
92
93
+ test_torch_export :
94
+ runs-on : ubuntu-latest
95
+ steps :
96
+ - uses : actions/checkout@v4
97
+ - name : Set up Python
98
+ uses : astral-sh/setup-uv@v5
99
+ with :
100
+ python-version : " 3.10"
101
+ - name : Install dependencies
102
+ run : uv pip install -r requirements/required.txt -r requirements/test.txt
103
+ - name : Show installed packages
104
+ run : uv pip list
105
+ - name : Test with PyTest
106
+ run : uv run pytest -v -rsx -n 2 -m "torch_export"
107
+
93
108
minimum :
94
109
runs-on : ubuntu-latest
95
110
steps :
Original file line number Diff line number Diff line change @@ -65,6 +65,7 @@ include = ['segmentation_models_pytorch*']
65
65
markers = [
66
66
" logits_match" ,
67
67
" compile" ,
68
+ " torch_export" ,
68
69
]
69
70
70
71
[tool .coverage .run ]
Original file line number Diff line number Diff line change 3
3
4
4
from . import initialization as init
5
5
from .hub_mixin import SMPHubMixin
6
+ from .utils import is_torch_compiling
6
7
7
8
T = TypeVar ("T" , bound = "SegmentationModel" )
8
9
@@ -50,7 +51,11 @@ def check_input_shape(self, x):
50
51
def forward (self , x ):
51
52
"""Sequentially pass `x` trough model`s encoder, decoder and heads"""
52
53
53
- if not torch .jit .is_tracing () and self .requires_divisible_input_shape :
54
+ if (
55
+ not torch .jit .is_tracing ()
56
+ and not is_torch_compiling ()
57
+ and self .requires_divisible_input_shape
58
+ ):
54
59
self .check_input_shape (x )
55
60
56
61
features = self .encoder (x )
Original file line number Diff line number Diff line change
1
+ import torch
2
+
3
+
4
+ def is_torch_compiling ():
5
+ try :
6
+ return torch .compiler .is_compiling ()
7
+ except Exception :
8
+ try :
9
+ import torch ._dynamo as dynamo # noqa: F401
10
+
11
+ return dynamo .is_compiling ()
12
+ except Exception :
13
+ return False
Original file line number Diff line number Diff line change @@ -231,3 +231,31 @@ def test_compile(self):
231
231
232
232
with torch .inference_mode ():
233
233
compiled_encoder (sample )
234
+
235
+ @pytest .mark .torch_export
236
+ def test_torch_export (self ):
237
+ if not check_run_test_on_diff_or_main (self .files_for_diff ):
238
+ self .skipTest ("No diff and not on `main`." )
239
+
240
+ sample = self ._get_sample (
241
+ batch_size = self .default_batch_size ,
242
+ num_channels = self .default_num_channels ,
243
+ height = self .default_height ,
244
+ width = self .default_width ,
245
+ ).to (default_device )
246
+
247
+ encoder = self .get_tiny_encoder ()
248
+ encoder = encoder .eval ().to (default_device )
249
+
250
+ exported_encoder = torch .export .export (
251
+ encoder ,
252
+ args = (sample ,),
253
+ strict = True ,
254
+ )
255
+
256
+ with torch .inference_mode ():
257
+ eager_output = encoder (sample )
258
+ exported_output = exported_encoder .module ().forward (sample )
259
+
260
+ for eager_feature , exported_feature in zip (eager_output , exported_output ):
261
+ torch .testing .assert_close (eager_feature , exported_feature )
Original file line number Diff line number Diff line change @@ -254,3 +254,31 @@ def test_compile(self):
254
254
255
255
with torch .inference_mode ():
256
256
compiled_model (sample )
257
+
258
+ @pytest .mark .torch_export
259
+ def test_torch_export (self ):
260
+ if not check_run_test_on_diff_or_main (self .files_for_diff ):
261
+ self .skipTest ("No diff and not on `main`." )
262
+
263
+ sample = self ._get_sample (
264
+ batch_size = self .default_batch_size ,
265
+ num_channels = self .default_num_channels ,
266
+ height = self .default_height ,
267
+ width = self .default_width ,
268
+ ).to (default_device )
269
+
270
+ model = self .get_default_model ()
271
+ model .eval ()
272
+
273
+ exported_model = torch .export .export (
274
+ model ,
275
+ args = (sample ,),
276
+ strict = True ,
277
+ )
278
+
279
+ with torch .inference_mode ():
280
+ eager_output = model (sample )
281
+ exported_output = exported_model .module ().forward (sample )
282
+
283
+ self .assertEqual (eager_output .shape , exported_output .shape )
284
+ torch .testing .assert_close (eager_output , exported_output )
You can’t perform that action at this time.
0 commit comments