Skip to content

Commit c47bdfb

Browse files
Added intitial test and some minor code modifications
1 parent 5599409 commit c47bdfb

File tree

5 files changed

+585
-3
lines changed

5 files changed

+585
-3
lines changed

segmentation_models_pytorch/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .decoders.pan import PAN
1515
from .decoders.upernet import UPerNet
1616
from .decoders.segformer import Segformer
17+
from .decoders.dpt import DPT
1718
from .base.hub_mixin import from_pretrained
1819

1920
from .__version__ import __version__
@@ -34,6 +35,7 @@
3435
PAN,
3536
UPerNet,
3637
Segformer,
38+
DPT
3739
]
3840
MODEL_ARCHITECTURES_MAPPING = {a.__name__.lower(): a for a in _MODEL_ARCHITECTURES}
3941

@@ -84,6 +86,7 @@ def create_model(
8486
"PAN",
8587
"UPerNet",
8688
"Segformer",
89+
"DPT",
8790
"from_pretrained",
8891
"create_model",
8992
"__version__",

segmentation_models_pytorch/encoders/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **
9292
in_channels=in_channels,
9393
depth=depth,
9494
pretrained=weights is not None,
95+
output_stride = output_stride,
9596
**kwargs,
9697
)
9798
return encoder

segmentation_models_pytorch/encoders/timm_vit.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,8 @@ class TimmViTEncoder(nn.Module):
1111
ViT style models
1212
1313
Features:
14-
- Supports configurable depth and output stride.
15-
- Ensures consistent multi-level feature extraction across diverse models.
16-
- Compatible with convolutional and transformer-like backbones.
14+
- Supports configurable depth.
15+
- Ensures consistent multi-level feature extraction across all ViT models.
1716
"""
1817

1918
_is_torch_scriptable = True
@@ -50,6 +49,12 @@ def __init__(
5049
super().__init__()
5150
self.name = name
5251

52+
output_stride = kwargs.pop("output_stride",None)
53+
if output_stride is not None:
54+
raise ValueError(
55+
"Dilated mode not supported, set output stride to None"
56+
)
57+
5358
# Default model configuration for feature extraction
5459
common_kwargs = dict(
5560
in_chans=in_channels,
@@ -82,6 +87,9 @@ def __init__(
8287
int((model_num_blocks / 4) * index) - 1 for index in range(1, depth + 1)
8388
]
8489

90+
if isinstance(output_indices,int):
91+
output_indices = list(output_indices)
92+
8593
common_kwargs["out_indices"] = self.out_indices = output_indices
8694
feature_info_obj = timm.models.FeatureInfo(
8795
feature_info=feature_info, out_indices=output_indices
+296
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
from tests.encoders import base
2+
import timm
3+
import torch
4+
import segmentation_models_pytorch as smp
5+
import pytest
6+
7+
from tests.utils import (
8+
default_device,
9+
check_run_test_on_diff_or_main,
10+
requires_torch_greater_or_equal,
11+
)
12+
13+
timm_vit_encoders = ["tu-vit_tiny_patch16_224",
14+
"tu-vit_small_patch32_224",
15+
"tu-vit_base_patch32_384",
16+
"tu-vit_base_patch32_siglip_256",
17+
]
18+
19+
class TestTimmViTEncoders(base.BaseEncoderTester):
20+
encoder_names = timm_vit_encoders
21+
tiny_encoder_patch_size = 224
22+
23+
files_for_diff = ["encoders/dpt.py"]
24+
25+
num_output_features = 4
26+
default_depth = 4
27+
output_strides = None
28+
supports_dilated = False
29+
30+
depth_to_test = [2,3,4]
31+
32+
default_encoder_kwargs = {"use_vit_encoder" : True}
33+
34+
def _get_model_expected_input_shape(self,encoder_name : str) -> int:
35+
patch_size_str = encoder_name[ -3 : ]
36+
return int(patch_size_str)
37+
38+
def get_tiny_encoder(self):
39+
return smp.encoders.get_encoder(self.encoder_names[0], encoder_weights=None,output_stride = None,**self.default_encoder_kwargs)
40+
41+
def test_forward_backward(self):
42+
for encoder_name in self.encoder_names:
43+
patch_size = self._get_model_expected_input_shape(encoder_name)
44+
sample = self._get_sample(height = patch_size, width = patch_size).to(default_device)
45+
with self.subTest(encoder_name=encoder_name):
46+
# init encoder
47+
encoder = smp.encoders.get_encoder(
48+
encoder_name, in_channels=3, encoder_weights=None,depth = self.default_depth,output_stride = None,**self.default_encoder_kwargs,
49+
50+
).to(default_device)
51+
52+
# forward
53+
features = encoder.forward(sample)
54+
self.assertEqual(
55+
len(features[0]),
56+
self.num_output_features,
57+
f"Encoder `{encoder_name}` should have {self.num_output_features} output feature maps, but has {len(features)}",
58+
)
59+
60+
# backward
61+
features[0][-1].mean().backward()
62+
63+
def test_in_channels(self):
64+
cases = [
65+
(encoder_name, in_channels)
66+
for encoder_name in self.encoder_names
67+
for in_channels in self.in_channels_to_test
68+
]
69+
70+
for encoder_name, in_channels in cases:
71+
patch_size = self._get_model_expected_input_shape(encoder_name)
72+
sample = self._get_sample(height = patch_size, width = patch_size,num_channels=in_channels).to(default_device)
73+
74+
with self.subTest(encoder_name=encoder_name, in_channels=in_channels):
75+
encoder = smp.encoders.get_encoder(
76+
encoder_name, in_channels=in_channels, encoder_weights=None,depth =4,output_stride = None,**self.default_encoder_kwargs
77+
).to(default_device)
78+
encoder.eval()
79+
80+
# forward
81+
with torch.inference_mode():
82+
encoder.forward(sample)
83+
84+
def test_depth(self):
85+
cases = [
86+
(encoder_name, depth)
87+
for encoder_name in self.encoder_names
88+
for depth in self.depth_to_test
89+
]
90+
91+
for encoder_name, depth in cases:
92+
patch_size = self._get_model_expected_input_shape(encoder_name)
93+
sample = self._get_sample(height = patch_size, width = patch_size).to(default_device)
94+
with self.subTest(encoder_name=encoder_name, depth=depth):
95+
encoder = smp.encoders.get_encoder(
96+
encoder_name,
97+
in_channels=self.default_num_channels,
98+
encoder_weights=None,
99+
depth=depth,
100+
output_stride = None,
101+
**self.default_encoder_kwargs
102+
).to(default_device)
103+
encoder.eval()
104+
105+
# forward
106+
with torch.inference_mode():
107+
features = encoder.forward(sample)
108+
109+
# check number of features
110+
self.assertEqual(
111+
len(features[0]),
112+
depth,
113+
f"Encoder `{encoder_name}` should have {depth} output feature maps, but has {len(features[0])}",
114+
)
115+
116+
# check feature strides
117+
height_strides, width_strides = self.get_features_output_strides(
118+
sample, features[0]
119+
)
120+
121+
timm_encoder_name = encoder_name[3 : ]
122+
encoder_out_indices = encoder.out_indices
123+
timm_model_feature_info = timm.create_model(model_name = timm_encoder_name).feature_info
124+
feature_info_obj = timm.models.FeatureInfo(feature_info = timm_model_feature_info,out_indices = encoder_out_indices)
125+
self.output_strides = feature_info_obj.reduction()
126+
127+
self.assertEqual(
128+
height_strides,
129+
self.output_strides[: depth],
130+
f"Encoder `{encoder_name}` should have output strides {self.output_strides[: depth]}, but has {height_strides}",
131+
)
132+
self.assertEqual(
133+
width_strides,
134+
self.output_strides[: depth],
135+
f"Encoder `{encoder_name}` should have output strides {self.output_strides[: depth]}, but has {width_strides}",
136+
)
137+
138+
# check encoder output stride property
139+
self.assertEqual(
140+
encoder.output_stride,
141+
self.output_strides[depth - 1],
142+
f"Encoder `{encoder_name}` last feature map should have output stride {self.output_strides[depth - 1]}, but has {encoder.output_stride}",
143+
)
144+
145+
# check out channels also have proper length
146+
self.assertEqual(
147+
len(encoder.out_channels),
148+
depth,
149+
f"Encoder `{encoder_name}` should have {depth} out_channels, but has {len(encoder.out_channels)}",
150+
)
151+
152+
def test_invalid_depth(self):
153+
with self.assertRaises(ValueError):
154+
smp.encoders.get_encoder(self.encoder_names[0], depth=5,output_stride = None)
155+
with self.assertRaises(ValueError):
156+
smp.encoders.get_encoder(self.encoder_names[0], depth=0,output_stride = None)
157+
158+
def test_dilated(self):
159+
160+
161+
cases = [
162+
(encoder_name, stride)
163+
for encoder_name in self.encoder_names
164+
for stride in self.strides_to_test
165+
]
166+
167+
# special case for encoders that do not support dilated model
168+
# just check proper error is raised
169+
if not self.supports_dilated:
170+
with self.assertRaises(ValueError, msg="Dilated mode not supported, set output stride to None"):
171+
encoder_name, stride = cases[0]
172+
patch_size = self._get_model_expected_input_shape(encoder_name)
173+
sample = self._get_sample(height = patch_size, width = patch_size).to(default_device)
174+
encoder = smp.encoders.get_encoder(
175+
encoder_name,
176+
in_channels=self.default_num_channels,
177+
encoder_weights=None,
178+
output_stride=stride,
179+
depth = self.default_depth,
180+
**self.default_encoder_kwargs,
181+
).to(default_device)
182+
return
183+
184+
for encoder_name, stride in cases:
185+
with self.subTest(encoder_name=encoder_name, stride=stride):
186+
encoder = smp.encoders.get_encoder(
187+
encoder_name,
188+
in_channels=self.default_num_channels,
189+
encoder_weights=None,
190+
output_stride=stride,
191+
depth = self.default_depth,
192+
**self.default_encoder_kwargs,
193+
).to(default_device)
194+
encoder.eval()
195+
196+
# forward
197+
with torch.inference_mode():
198+
features = encoder.forward(sample)
199+
200+
height_strides, width_strides = self.get_features_output_strides(
201+
sample, features[0]
202+
)
203+
expected_height_strides = [min(stride, s) for s in height_strides]
204+
expected_width_strides = [min(stride, s) for s in width_strides]
205+
206+
self.assertEqual(
207+
height_strides,
208+
expected_height_strides,
209+
f"Encoder `{encoder_name}` should have height output strides {expected_height_strides}, but has {height_strides}",
210+
)
211+
self.assertEqual(
212+
width_strides,
213+
expected_width_strides,
214+
f"Encoder `{encoder_name}` should have width output strides {expected_width_strides}, but has {width_strides}",
215+
)
216+
217+
@pytest.mark.compile
218+
def test_compile(self):
219+
if not check_run_test_on_diff_or_main(self.files_for_diff):
220+
self.skipTest("No diff and not on `main`.")
221+
222+
encoder = self.get_tiny_encoder()
223+
encoder = encoder.eval().to(default_device)
224+
225+
sample = self._get_sample(height = self.tiny_encoder_patch_size, width = self.tiny_encoder_patch_size).to(default_device)
226+
227+
torch.compiler.reset()
228+
compiled_encoder = torch.compile(
229+
encoder, fullgraph=True, dynamic=True, backend="eager"
230+
)
231+
232+
if encoder._is_torch_compilable:
233+
compiled_encoder(sample)
234+
else:
235+
with self.assertRaises(Exception):
236+
compiled_encoder(sample)
237+
238+
@pytest.mark.torch_export
239+
@requires_torch_greater_or_equal("2.4.0")
240+
def test_torch_export(self):
241+
if not check_run_test_on_diff_or_main(self.files_for_diff):
242+
self.skipTest("No diff and not on `main`.")
243+
244+
sample = self._get_sample(height = self.tiny_encoder_patch_size, width = self.tiny_encoder_patch_size).to(default_device)
245+
246+
encoder = self.get_tiny_encoder()
247+
encoder = encoder.eval().to(default_device)
248+
249+
if not encoder._is_torch_exportable:
250+
with self.assertRaises(Exception):
251+
exported_encoder = torch.export.export(
252+
encoder,
253+
args=(sample,),
254+
strict=True,
255+
)
256+
return
257+
258+
exported_encoder = torch.export.export(
259+
encoder,
260+
args=(sample,),
261+
strict=True,
262+
)
263+
264+
with torch.inference_mode():
265+
eager_output = encoder(sample)
266+
exported_output = exported_encoder.module().forward(sample)
267+
268+
for eager_feature, exported_feature in zip(eager_output, exported_output):
269+
torch.testing.assert_close(eager_feature, exported_feature)
270+
271+
@pytest.mark.torch_script
272+
def test_torch_script(self):
273+
sample = self._get_sample(height = self.tiny_encoder_patch_size, width = self.tiny_encoder_patch_size).to(default_device)
274+
275+
encoder = self.get_tiny_encoder()
276+
encoder = encoder.eval().to(default_device)
277+
278+
if not encoder._is_torch_scriptable:
279+
with self.assertRaises(RuntimeError, msg="not torch scriptable"):
280+
scripted_encoder = torch.jit.script(encoder)
281+
return
282+
283+
scripted_encoder = torch.jit.script(encoder)
284+
285+
with torch.inference_mode():
286+
eager_output = encoder(sample)
287+
scripted_output = scripted_encoder(sample)
288+
289+
for eager_feature, scripted_feature in zip(eager_output, scripted_output):
290+
torch.testing.assert_close(eager_feature, scripted_feature)
291+
292+
293+
294+
295+
296+

0 commit comments

Comments
 (0)