@@ -4,7 +4,11 @@
 import segmentation_models_pytorch as smp
 
 from functools import lru_cache
-from tests.utils import default_device, check_run_test_on_diff_or_main
+from tests.utils import (
+    default_device,
+    check_run_test_on_diff_or_main,
+    requires_torch_greater_or_equal,
+)
 
 
 class BaseEncoderTester(unittest.TestCase):
@@ -29,11 +33,19 @@ class BaseEncoderTester(unittest.TestCase):
     depth_to_test = [3, 4, 5]
     strides_to_test = [8, 16]  # 32 is a default one
 
+    # enable/disable tests
+    do_test_torch_compile = True
+    do_test_torch_export = True
+
     def get_tiny_encoder(self):
         return smp.encoders.get_encoder(self.encoder_names[0], encoder_weights=None)
 
     @lru_cache
-    def _get_sample(self, batch_size=1, num_channels=3, height=32, width=32):
+    def _get_sample(self, batch_size=None, num_channels=None, height=None, width=None):
+        batch_size = batch_size or self.default_batch_size
+        num_channels = num_channels or self.default_num_channels
+        height = height or self.default_height
+        width = width or self.default_width
         return torch.rand(batch_size, num_channels, height, width)
 
     def get_features_output_strides(self, sample, features):
@@ -43,12 +55,7 @@ def get_features_output_strides(self, sample, features):
         return height_strides, width_strides
 
     def test_forward_backward(self):
-        sample = self._get_sample(
-            batch_size=self.default_batch_size,
-            num_channels=self.default_num_channels,
-            height=self.default_height,
-            width=self.default_width,
-        ).to(default_device)
+        sample = self._get_sample().to(default_device)
 
         for encoder_name in self.encoder_names:
             with self.subTest(encoder_name=encoder_name):
                 # init encoder
@@ -75,12 +82,7 @@ def test_in_channels(self):
         ]
 
         for encoder_name, in_channels in cases:
-            sample = self._get_sample(
-                batch_size=self.default_batch_size,
-                num_channels=in_channels,
-                height=self.default_height,
-                width=self.default_width,
-            ).to(default_device)
+            sample = self._get_sample(num_channels=in_channels).to(default_device)
 
             with self.subTest(encoder_name=encoder_name, in_channels=in_channels):
                 encoder = smp.encoders.get_encoder(
@@ -93,12 +95,7 @@ def test_in_channels(self):
                 encoder.forward(sample)
 
     def test_depth(self):
-        sample = self._get_sample(
-            batch_size=self.default_batch_size,
-            num_channels=self.default_num_channels,
-            height=self.default_height,
-            width=self.default_width,
-        ).to(default_device)
+        sample = self._get_sample().to(default_device)
 
         cases = [
             (encoder_name, depth)
@@ -157,12 +154,7 @@ def test_depth(self):
             )
 
     def test_dilated(self):
-        sample = self._get_sample(
-            batch_size=self.default_batch_size,
-            num_channels=self.default_num_channels,
-            height=self.default_height,
-            width=self.default_width,
-        ).to(default_device)
+        sample = self._get_sample().to(default_device)
 
         cases = [
             (encoder_name, stride)
@@ -216,15 +208,15 @@ def test_dilated(self):
 
     @pytest.mark.compile
     def test_compile(self):
+        if not self.do_test_torch_compile:
+            self.skipTest(
+                f"torch_compile test is disabled for {self.encoder_names[0]}."
+            )
+
         if not check_run_test_on_diff_or_main(self.files_for_diff):
             self.skipTest("No diff and not on `main`.")
 
-        sample = self._get_sample(
-            batch_size=self.default_batch_size,
-            num_channels=self.default_num_channels,
-            height=self.default_height,
-            width=self.default_width,
-        ).to(default_device)
+        sample = self._get_sample().to(default_device)
 
         encoder = self.get_tiny_encoder().eval().to(default_device)
         compiled_encoder = torch.compile(encoder, fullgraph=True, dynamic=True)
@@ -233,16 +225,15 @@ def test_compile(self):
             compiled_encoder(sample)
 
     @pytest.mark.torch_export
+    @requires_torch_greater_or_equal("2.4.0")
     def test_torch_export(self):
+        if not self.do_test_torch_export:
+            self.skipTest(f"torch_export test is disabled for {self.encoder_names[0]}.")
+
         if not check_run_test_on_diff_or_main(self.files_for_diff):
             self.skipTest("No diff and not on `main`.")
 
-        sample = self._get_sample(
-            batch_size=self.default_batch_size,
-            num_channels=self.default_num_channels,
-            height=self.default_height,
-            width=self.default_width,
-        ).to(default_device)
+        sample = self._get_sample().to(default_device)
 
         encoder = self.get_tiny_encoder()
         encoder = encoder.eval().to(default_device)
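The new do_test_torch_compile / do_test_torch_export class attributes give encoder-specific testers a declarative way to opt out of the torch.compile and torch.export tests without overriding the test methods. A minimal sketch of a subclass using them; the module path tests.encoders.base, the class name, and the attribute values below are illustrative assumptions, not taken from this commit:

from tests.encoders.base import BaseEncoderTester  # module path assumed


class ExampleEncoderTester(BaseEncoderTester):
    encoder_names = ["resnet18"]  # hypothetical encoder under test
    files_for_diff = ["encoders/resnet.py"]  # hypothetical diff-trigger path

    # Skip the torch.compile test (e.g. if fullgraph=True hits graph breaks)
    # while keeping the torch.export test enabled.
    do_test_torch_compile = False
    do_test_torch_export = True

Because _get_sample is wrapped in @lru_cache and now resolves its defaults from self.default_* inside the body, each tester instance reuses one cached sample tensor for the no-argument call. Note that self._get_sample() and self._get_sample(batch_size=self.default_batch_size) hash to different cache keys and therefore produce two distinct random tensors.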
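requires_torch_greater_or_equal is imported from tests.utils, but its body is not part of this diff. A plausible implementation, assuming the usual unittest.skipUnless pattern and the packaging version parser:

import unittest

import torch
from packaging import version


def requires_torch_greater_or_equal(min_version: str):
    # Strip any local build tag (e.g. "2.4.0+cu121") before comparing.
    installed = version.parse(torch.__version__.split("+")[0])
    required = version.parse(min_version)
    return unittest.skipUnless(
        installed >= required,
        f"requires torch >= {min_version}, found {torch.__version__}",
    )

Since the decorator runs at class-definition time, test_torch_export is marked as skipped on older runtimes before torch.export is ever invoked.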