1
1
from tests .encoders import base
2
2
import timm
3
3
import torch
4
- import segmentation_models_pytorch as smp
5
4
import pytest
5
+ from segmentation_models_pytorch .encoders import TimmViTEncoder
6
+ from segmentation_models_pytorch .encoders .timm_vit import sample_block_indices_uniformly
6
7
7
8
from tests .utils import (
8
9
default_device ,
11
12
requires_timm_greater_or_equal ,
12
13
)
13
14
14
- timm_vit_encoders = [
15
- "tu-vit_tiny_patch16_224" ,
16
- "tu-vit_small_patch32_224" ,
17
- "tu-vit_base_patch32_384" ,
18
- "tu-vit_base_patch16_gap_224" ,
19
- "tu-vit_medium_patch16_reg4_gap_256" ,
20
- "tu-vit_so150m2_patch16_reg1_gap_256" ,
21
- "tu-vit_medium_patch16_gap_240" ,
22
- ]
15
+ timm_vit_encoders = ["vit_tiny_patch16_224" ]
23
16
24
17
25
18
class TestTimmViTEncoders (base .BaseEncoderTester ):
26
19
encoder_names = timm_vit_encoders
27
20
tiny_encoder_patch_size = 224
21
+ default_height = 224
22
+ default_width = 224
28
23
29
24
files_for_diff = ["encoders/dpt.py" ]
30
25
@@ -35,14 +30,10 @@ class TestTimmViTEncoders(base.BaseEncoderTester):
35
30
36
31
depth_to_test = [2 , 3 , 4 ]
37
32
38
- default_encoder_kwargs = {"use_vit_encoder" : True }
39
-
40
- def _get_model_expected_input_shape (self , encoder_name : str ) -> int :
41
- patch_size_str = encoder_name [- 3 :]
42
- return int (patch_size_str )
33
+ default_encoder_kwargs = {"pretrained" : False }
43
34
44
35
def get_tiny_encoder (self ):
45
- return smp . encoders . get_encoder (
36
+ return TimmViTEncoder (
46
37
self .encoder_names [0 ],
47
38
encoder_weights = None ,
48
39
output_stride = None ,
@@ -55,13 +46,10 @@ def get_tiny_encoder(self):
55
46
@requires_timm_greater_or_equal ("1.0.15" )
56
47
def test_forward_backward (self ):
57
48
for encoder_name in self .encoder_names :
58
- patch_size = self ._get_model_expected_input_shape (encoder_name )
59
- sample = self ._get_sample (height = patch_size , width = patch_size ).to (
60
- default_device
61
- )
49
+ sample = self ._get_sample ().to (default_device )
62
50
with self .subTest (encoder_name = encoder_name ):
63
51
# init encoder
64
- encoder = smp . encoders . get_encoder (
52
+ encoder = TimmViTEncoder (
65
53
encoder_name ,
66
54
in_channels = 3 ,
67
55
encoder_weights = None ,
@@ -90,13 +78,10 @@ def test_in_channels(self):
90
78
]
91
79
92
80
for encoder_name , in_channels in cases :
93
- patch_size = self ._get_model_expected_input_shape (encoder_name )
94
- sample = self ._get_sample (
95
- height = patch_size , width = patch_size , num_channels = in_channels
96
- ).to (default_device )
81
+ sample = self ._get_sample (num_channels = in_channels ).to (default_device )
97
82
98
83
with self .subTest (encoder_name = encoder_name , in_channels = in_channels ):
99
- encoder = smp . encoders . get_encoder (
84
+ encoder = TimmViTEncoder (
100
85
encoder_name ,
101
86
in_channels = in_channels ,
102
87
encoder_weights = None ,
@@ -119,12 +104,9 @@ def test_depth(self):
119
104
]
120
105
121
106
for encoder_name , depth in cases :
122
- patch_size = self ._get_model_expected_input_shape (encoder_name )
123
- sample = self ._get_sample (height = patch_size , width = patch_size ).to (
124
- default_device
125
- )
107
+ sample = self ._get_sample ().to (default_device )
126
108
with self .subTest (encoder_name = encoder_name , depth = depth ):
127
- encoder = smp . encoders . get_encoder (
109
+ encoder = TimmViTEncoder (
128
110
encoder_name ,
129
111
in_channels = self .default_num_channels ,
130
112
encoder_weights = None ,
@@ -150,10 +132,9 @@ def test_depth(self):
150
132
sample , features
151
133
)
152
134
153
- timm_encoder_name = encoder_name [3 :]
154
- encoder_out_indices = encoder .out_indices
135
+ encoder_out_indices = sample_block_indices_uniformly (depth , 12 )
155
136
timm_model_feature_info = timm .create_model (
156
- model_name = timm_encoder_name
137
+ model_name = encoder_name
157
138
).feature_info
158
139
feature_info_obj = timm .models .FeatureInfo (
159
140
feature_info = timm_model_feature_info ,
@@ -189,35 +170,56 @@ def test_depth(self):
189
170
@requires_timm_greater_or_equal ("1.0.15" )
190
171
def test_invalid_depth (self ):
191
172
with self .assertRaises (ValueError ):
192
- smp .encoders .get_encoder (self .encoder_names [0 ], depth = 5 , output_stride = None )
173
+ TimmViTEncoder (
174
+ self .encoder_names [0 ],
175
+ depth = 5 ,
176
+ output_stride = None ,
177
+ ** self .default_encoder_kwargs ,
178
+ )
193
179
with self .assertRaises (ValueError ):
194
- smp .encoders .get_encoder (self .encoder_names [0 ], depth = 0 , output_stride = None )
180
+ TimmViTEncoder (
181
+ self .encoder_names [0 ],
182
+ depth = 0 ,
183
+ output_stride = None ,
184
+ ** self .default_encoder_kwargs ,
185
+ )
195
186
196
187
@requires_timm_greater_or_equal ("1.0.15" )
197
188
def test_invalid_out_indices (self ):
198
189
with self .assertRaises (ValueError ):
199
- smp .encoders .get_encoder (
200
- self .encoder_names [0 ], output_stride = None , out_indices = - 1
190
+ TimmViTEncoder (
191
+ self .encoder_names [0 ],
192
+ output_stride = None ,
193
+ output_indices = - 25 ,
194
+ ** self .default_encoder_kwargs ,
201
195
)
202
196
203
197
with self .assertRaises (ValueError ):
204
- smp .encoders .get_encoder (
205
- self .encoder_names [0 ], output_stride = None , out_indices = [1 , 2 , 25 ]
198
+ TimmViTEncoder (
199
+ self .encoder_names [0 ],
200
+ output_stride = None ,
201
+ output_indices = [1 , 2 , 25 ],
202
+ ** self .default_encoder_kwargs ,
206
203
)
207
204
208
205
@requires_timm_greater_or_equal ("1.0.15" )
209
206
def test_invalid_out_indices_length (self ):
210
207
with self .assertRaises (ValueError ):
211
- smp .encoders .get_encoder (
212
- self .encoder_names [0 ], output_stride = None , out_indices = 2 , depth = 2
208
+ TimmViTEncoder (
209
+ self .encoder_names [0 ],
210
+ output_stride = None ,
211
+ output_indices = 2 ,
212
+ depth = 2 ,
213
+ ** self .default_encoder_kwargs ,
213
214
)
214
215
215
216
with self .assertRaises (ValueError ):
216
- smp . encoders . get_encoder (
217
+ TimmViTEncoder (
217
218
self .encoder_names [0 ],
218
219
output_stride = None ,
219
- out_indices = [0 , 1 , 2 , 3 , 4 ],
220
+ output_indices = [0 , 1 , 2 , 3 , 4 ],
220
221
depth = 4 ,
222
+ ** self .default_encoder_kwargs ,
221
223
)
222
224
223
225
@requires_timm_greater_or_equal ("1.0.15" )
@@ -235,23 +237,19 @@ def test_dilated(self):
235
237
ValueError , msg = "Dilated mode not supported, set output stride to None"
236
238
):
237
239
encoder_name , stride = cases [0 ]
238
- patch_size = self ._get_model_expected_input_shape (encoder_name )
239
- sample = self ._get_sample (height = patch_size , width = patch_size ).to (
240
- default_device
241
- )
242
- encoder = smp .encoders .get_encoder (
240
+ sample = self ._get_sample ().to (default_device )
241
+ encoder = TimmViTEncoder (
243
242
encoder_name ,
244
243
in_channels = self .default_num_channels ,
245
244
encoder_weights = None ,
246
245
output_stride = stride ,
247
246
depth = self .default_depth ,
248
- ** self .default_encoder_kwargs ,
249
247
).to (default_device )
250
248
return
251
249
252
250
for encoder_name , stride in cases :
253
251
with self .subTest (encoder_name = encoder_name , stride = stride ):
254
- encoder = smp . encoders . get_encoder (
252
+ encoder = TimmViTEncoder (
255
253
encoder_name ,
256
254
in_channels = self .default_num_channels ,
257
255
encoder_weights = None ,
0 commit comments