1
+ from tests .encoders import base
2
+ import timm
3
+ import torch
4
+ import segmentation_models_pytorch as smp
5
+ import pytest
6
+
7
+ from tests .utils import (
8
+ default_device ,
9
+ check_run_test_on_diff_or_main ,
10
+ requires_torch_greater_or_equal ,
11
+ )
12
+
13
# timm ViT models exercised by the encoder tests in this module.
#
# Naming contract relied on by the tests:
#   * the leading ``tu-`` prefix is sliced off (``encoder_name[3:]``) before the
#     name is handed to ``timm.create_model``;
#   * the numeric suffix is parsed by ``_get_model_expected_input_shape`` and
#     used as the square input side length, so every entry must end with the
#     resolution the model expects.
timm_vit_encoders = [
    "tu-vit_tiny_patch16_224",
    "tu-vit_small_patch32_224",
    "tu-vit_base_patch32_384",
    "tu-vit_base_patch32_siglip_256",
]
18
+
19
class TestTimmViTEncoders(base.BaseEncoderTester):
    """Test-suite for timm ViT encoders built through ``smp.encoders.get_encoder``.

    Every encoder is created with ``use_vit_encoder=True`` and
    ``output_stride=None``. The encoder output is indexed with ``[0]`` to obtain
    the list of feature maps that the assertions inspect.

    Helpers and attributes such as ``_get_sample``,
    ``get_features_output_strides``, ``in_channels_to_test``,
    ``strides_to_test`` and ``default_num_channels`` are not defined in this
    class; they are presumably inherited from ``base.BaseEncoderTester``
    (NOTE(review): confirm against ``tests/encoders/base.py``).
    """

    encoder_names = timm_vit_encoders
    # Square input side length used with the "tiny" encoder
    # (``encoder_names[0]``) in the compile / export / script tests.
    tiny_encoder_patch_size = 224

    # Passed to ``check_run_test_on_diff_or_main`` to decide whether the
    # compile / export tests run.
    files_for_diff = ["encoders/dpt.py"]

    num_output_features = 4
    default_depth = 4
    # Expected strides are derived per encoder from timm inside ``test_depth``.
    output_strides = None
    supports_dilated = False

    depth_to_test = [2, 3, 4]

    # Extra keyword arguments forwarded to every ``get_encoder`` call.
    default_encoder_kwargs = {"use_vit_encoder": True}

    def _get_model_expected_input_shape(self, encoder_name: str) -> int:
        """Return the square input side length encoded at the end of the name.

        The size is the token after the last underscore, e.g.
        ``"tu-vit_base_patch32_384"`` -> ``384``. Parsing the whole token
        (rather than a fixed number of trailing characters) keeps names with
        2- or 4-digit resolutions working.

        Raises:
            ValueError: if the trailing token is not an integer.
        """
        size_token = encoder_name.rsplit("_", 1)[-1]
        return int(size_token)

    def get_tiny_encoder(self):
        """Build the first encoder in ``encoder_names`` without pretrained weights."""
        return smp.encoders.get_encoder(
            self.encoder_names[0],
            encoder_weights=None,
            output_stride=None,
            **self.default_encoder_kwargs,
        )

    def test_forward_backward(self):
        """Each encoder yields ``num_output_features`` maps and supports backward."""
        for encoder_name in self.encoder_names:
            patch_size = self._get_model_expected_input_shape(encoder_name)
            sample = self._get_sample(height=patch_size, width=patch_size).to(
                default_device
            )
            with self.subTest(encoder_name=encoder_name):
                # init encoder
                encoder = smp.encoders.get_encoder(
                    encoder_name,
                    in_channels=3,
                    encoder_weights=None,
                    depth=self.default_depth,
                    output_stride=None,
                    **self.default_encoder_kwargs,
                ).to(default_device)

                # forward
                features = encoder.forward(sample)
                feature_maps = features[0]
                # The message reports the same quantity that is asserted on.
                self.assertEqual(
                    len(feature_maps),
                    self.num_output_features,
                    f"Encoder `{encoder_name}` should have {self.num_output_features} output feature maps, but has {len(feature_maps)}",
                )

                # backward
                feature_maps[-1].mean().backward()

    def test_in_channels(self):
        """Each encoder runs a forward pass for every tested input channel count."""
        cases = [
            (encoder_name, in_channels)
            for encoder_name in self.encoder_names
            for in_channels in self.in_channels_to_test
        ]

        for encoder_name, in_channels in cases:
            patch_size = self._get_model_expected_input_shape(encoder_name)
            sample = self._get_sample(
                height=patch_size, width=patch_size, num_channels=in_channels
            ).to(default_device)

            with self.subTest(encoder_name=encoder_name, in_channels=in_channels):
                encoder = smp.encoders.get_encoder(
                    encoder_name,
                    in_channels=in_channels,
                    encoder_weights=None,
                    depth=self.default_depth,
                    output_stride=None,
                    **self.default_encoder_kwargs,
                ).to(default_device)
                encoder.eval()

                # forward
                with torch.inference_mode():
                    encoder.forward(sample)

    def test_depth(self):
        """Feature count, strides and ``out_channels`` follow the requested depth.

        Expected strides are computed from timm: the model's ``feature_info``
        is combined with the encoder's ``out_indices`` in a
        ``timm.models.FeatureInfo`` and its ``reduction()`` is used as the
        reference.
        """
        cases = [
            (encoder_name, depth)
            for encoder_name in self.encoder_names
            for depth in self.depth_to_test
        ]

        for encoder_name, depth in cases:
            patch_size = self._get_model_expected_input_shape(encoder_name)
            sample = self._get_sample(height=patch_size, width=patch_size).to(
                default_device
            )
            with self.subTest(encoder_name=encoder_name, depth=depth):
                encoder = smp.encoders.get_encoder(
                    encoder_name,
                    in_channels=self.default_num_channels,
                    encoder_weights=None,
                    depth=depth,
                    output_stride=None,
                    **self.default_encoder_kwargs,
                ).to(default_device)
                encoder.eval()

                # forward
                with torch.inference_mode():
                    features = encoder.forward(sample)
                feature_maps = features[0]

                # check number of features
                self.assertEqual(
                    len(feature_maps),
                    depth,
                    f"Encoder `{encoder_name}` should have {depth} output feature maps, but has {len(feature_maps)}",
                )

                # check feature strides
                height_strides, width_strides = self.get_features_output_strides(
                    sample, feature_maps
                )

                # Drop the "tu-" prefix to get the plain timm model name.
                timm_encoder_name = encoder_name[3:]
                timm_model_feature_info = timm.create_model(
                    model_name=timm_encoder_name
                ).feature_info
                feature_info_obj = timm.models.FeatureInfo(
                    feature_info=timm_model_feature_info,
                    out_indices=encoder.out_indices,
                )
                # Kept in a local so the class-level ``output_strides`` is not
                # shadowed by a per-case value.
                expected_strides = feature_info_obj.reduction()

                self.assertEqual(
                    height_strides,
                    expected_strides[:depth],
                    f"Encoder `{encoder_name}` should have output strides {expected_strides[:depth]}, but has {height_strides}",
                )
                self.assertEqual(
                    width_strides,
                    expected_strides[:depth],
                    f"Encoder `{encoder_name}` should have output strides {expected_strides[:depth]}, but has {width_strides}",
                )

                # check encoder output stride property
                self.assertEqual(
                    encoder.output_stride,
                    expected_strides[depth - 1],
                    f"Encoder `{encoder_name}` last feature map should have output stride {expected_strides[depth - 1]}, but has {encoder.output_stride}",
                )

                # check out channels also have proper length
                self.assertEqual(
                    len(encoder.out_channels),
                    depth,
                    f"Encoder `{encoder_name}` should have {depth} out_channels, but has {len(encoder.out_channels)}",
                )

    def test_invalid_depth(self):
        """Depths outside the supported range are rejected with ``ValueError``.

        The encoder is requested with the same keyword arguments as the other
        tests (no pretrained weights, ``use_vit_encoder=True``) so that the ViT
        encoder's own depth validation is what gets exercised.
        NOTE(review): assumes the ViT encoder raises ``ValueError`` for depth 5
        and 0 - confirm against the encoder implementation.
        """
        for invalid_depth in (5, 0):
            with self.assertRaises(ValueError):
                smp.encoders.get_encoder(
                    self.encoder_names[0],
                    encoder_weights=None,
                    depth=invalid_depth,
                    output_stride=None,
                    **self.default_encoder_kwargs,
                )

    def test_dilated(self):
        """Dilated mode either raises ``ValueError`` or clamps feature strides.

        When ``supports_dilated`` is false, only the first
        (encoder, stride) case is used to check that requesting an output
        stride raises. Otherwise every case is run and each measured stride is
        compared with ``min(stride, measured)``.
        """
        cases = [
            (encoder_name, stride)
            for encoder_name in self.encoder_names
            for stride in self.strides_to_test
        ]

        # special case for encoders that do not support dilated model
        # just check proper error is raised
        if not self.supports_dilated:
            encoder_name, stride = cases[0]
            # Only the construction is inside the context manager so that an
            # unrelated ValueError from test setup cannot satisfy the assertion.
            with self.assertRaises(
                ValueError, msg="Dilated mode not supported, set output stride to None"
            ):
                smp.encoders.get_encoder(
                    encoder_name,
                    in_channels=self.default_num_channels,
                    encoder_weights=None,
                    output_stride=stride,
                    depth=self.default_depth,
                    **self.default_encoder_kwargs,
                ).to(default_device)
            return

        for encoder_name, stride in cases:
            # The sample must be sized per encoder; it is built here so the
            # forward pass below has a defined input.
            patch_size = self._get_model_expected_input_shape(encoder_name)
            sample = self._get_sample(height=patch_size, width=patch_size).to(
                default_device
            )
            with self.subTest(encoder_name=encoder_name, stride=stride):
                encoder = smp.encoders.get_encoder(
                    encoder_name,
                    in_channels=self.default_num_channels,
                    encoder_weights=None,
                    output_stride=stride,
                    depth=self.default_depth,
                    **self.default_encoder_kwargs,
                ).to(default_device)
                encoder.eval()

                # forward
                with torch.inference_mode():
                    features = encoder.forward(sample)

                height_strides, width_strides = self.get_features_output_strides(
                    sample, features[0]
                )
                expected_height_strides = [min(stride, s) for s in height_strides]
                expected_width_strides = [min(stride, s) for s in width_strides]

                self.assertEqual(
                    height_strides,
                    expected_height_strides,
                    f"Encoder `{encoder_name}` should have height output strides {expected_height_strides}, but has {height_strides}",
                )
                self.assertEqual(
                    width_strides,
                    expected_width_strides,
                    f"Encoder `{encoder_name}` should have width output strides {expected_width_strides}, but has {width_strides}",
                )

    @pytest.mark.compile
    def test_compile(self):
        """``torch.compile`` outcome matches the encoder's ``_is_torch_compilable`` flag.

        Skipped when ``check_run_test_on_diff_or_main(self.files_for_diff)`` is
        falsy.
        """
        if not check_run_test_on_diff_or_main(self.files_for_diff):
            self.skipTest("No diff and not on `main`.")

        encoder = self.get_tiny_encoder()
        encoder = encoder.eval().to(default_device)

        sample = self._get_sample(
            height=self.tiny_encoder_patch_size,
            width=self.tiny_encoder_patch_size,
        ).to(default_device)

        torch.compiler.reset()
        compiled_encoder = torch.compile(
            encoder, fullgraph=True, dynamic=True, backend="eager"
        )

        if encoder._is_torch_compilable:
            compiled_encoder(sample)
        else:
            with self.assertRaises(Exception):
                compiled_encoder(sample)

    @pytest.mark.torch_export
    @requires_torch_greater_or_equal("2.4.0")
    def test_torch_export(self):
        """``torch.export`` outcome matches ``_is_torch_exportable``.

        For exportable encoders the exported module's output is compared with
        the eager output. Skipped when
        ``check_run_test_on_diff_or_main(self.files_for_diff)`` is falsy.
        """
        if not check_run_test_on_diff_or_main(self.files_for_diff):
            self.skipTest("No diff and not on `main`.")

        sample = self._get_sample(
            height=self.tiny_encoder_patch_size,
            width=self.tiny_encoder_patch_size,
        ).to(default_device)

        encoder = self.get_tiny_encoder()
        encoder = encoder.eval().to(default_device)

        if not encoder._is_torch_exportable:
            with self.assertRaises(Exception):
                torch.export.export(
                    encoder,
                    args=(sample,),
                    strict=True,
                )
            return

        exported_encoder = torch.export.export(
            encoder,
            args=(sample,),
            strict=True,
        )

        with torch.inference_mode():
            eager_output = encoder(sample)
            exported_output = exported_encoder.module().forward(sample)

        for eager_feature, exported_feature in zip(eager_output, exported_output):
            torch.testing.assert_close(eager_feature, exported_feature)

    @pytest.mark.torch_script
    def test_torch_script(self):
        """``torch.jit.script`` outcome matches ``_is_torch_scriptable``.

        For scriptable encoders the scripted module's output is compared with
        the eager output.
        """
        sample = self._get_sample(
            height=self.tiny_encoder_patch_size,
            width=self.tiny_encoder_patch_size,
        ).to(default_device)

        encoder = self.get_tiny_encoder()
        encoder = encoder.eval().to(default_device)

        if not encoder._is_torch_scriptable:
            with self.assertRaises(RuntimeError, msg="not torch scriptable"):
                torch.jit.script(encoder)
            return

        scripted_encoder = torch.jit.script(encoder)

        with torch.inference_mode():
            eager_output = encoder(sample)
            scripted_output = scripted_encoder(sample)

        for eager_feature, scripted_feature in zip(eager_output, scripted_output):
            torch.testing.assert_close(eager_feature, scripted_feature)
291
+
292
+
293
+
294
+
295
+
296
+
0 commit comments