Skip to content

Commit 4eb6ec3

Browse files
Tests update
1 parent 8d3ed4f commit 4eb6ec3

File tree

4 files changed

+60
-57
lines changed

4 files changed

+60
-57
lines changed

segmentation_models_pytorch/decoders/dpt/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ class DPT(SegmentationModel):
7171
7272
"""
7373

74-
_is_torch_scriptable = False
75-
_is_torch_compilable = False
74+
_is_torch_scriptable = True
75+
_is_torch_compilable = True
7676
requires_divisible_input_shape = True
7777

7878
@supports_config_loading

segmentation_models_pytorch/encoders/timm_vit.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,11 @@ def __init__(
9292
f"{self.__class__.__name__} depth should be in range [1, 4], got {depth}"
9393
)
9494

95+
# Pop and reject `output_stride`: dilated mode is not supported here, and the shared smp encoder tests expect a ValueError when it is set
96+
output_stride = kwargs.pop("output_stride", None)
97+
if output_stride is not None:
98+
raise ValueError("Dilated mode not supported, set output stride to None")
99+
95100
if isinstance(output_indices, (list, tuple)) and len(output_indices) != depth:
96101
raise ValueError(
97102
f"Length of output indices for feature extraction should be equal to the depth of the encoder "
@@ -185,7 +190,7 @@ def forward(
185190
if self.has_prefix_tokens:
186191
features, prefix_tokens = self._forward_with_prefix_tokens(x)
187192
else:
188-
features = self._forward_without_cls_token(x)
193+
features = self._forward_without_prefix_tokens(x)
189194
prefix_tokens = [None] * len(features)
190195

191196
return features, prefix_tokens

tests/encoders/test_timm_vit_encoders.py

Lines changed: 49 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from tests.encoders import base
22
import timm
33
import torch
4-
import segmentation_models_pytorch as smp
54
import pytest
5+
from segmentation_models_pytorch.encoders import TimmViTEncoder
6+
from segmentation_models_pytorch.encoders.timm_vit import sample_block_indices_uniformly
67

78
from tests.utils import (
89
default_device,
@@ -11,20 +12,14 @@
1112
requires_timm_greater_or_equal,
1213
)
1314

14-
timm_vit_encoders = [
15-
"tu-vit_tiny_patch16_224",
16-
"tu-vit_small_patch32_224",
17-
"tu-vit_base_patch32_384",
18-
"tu-vit_base_patch16_gap_224",
19-
"tu-vit_medium_patch16_reg4_gap_256",
20-
"tu-vit_so150m2_patch16_reg1_gap_256",
21-
"tu-vit_medium_patch16_gap_240",
22-
]
15+
timm_vit_encoders = ["vit_tiny_patch16_224"]
2316

2417

2518
class TestTimmViTEncoders(base.BaseEncoderTester):
2619
encoder_names = timm_vit_encoders
2720
tiny_encoder_patch_size = 224
21+
default_height = 224
22+
default_width = 224
2823

2924
files_for_diff = ["encoders/dpt.py"]
3025

@@ -35,14 +30,10 @@ class TestTimmViTEncoders(base.BaseEncoderTester):
3530

3631
depth_to_test = [2, 3, 4]
3732

38-
default_encoder_kwargs = {"use_vit_encoder": True}
39-
40-
def _get_model_expected_input_shape(self, encoder_name: str) -> int:
41-
patch_size_str = encoder_name[-3:]
42-
return int(patch_size_str)
33+
default_encoder_kwargs = {"pretrained": False}
4334

4435
def get_tiny_encoder(self):
45-
return smp.encoders.get_encoder(
36+
return TimmViTEncoder(
4637
self.encoder_names[0],
4738
encoder_weights=None,
4839
output_stride=None,
@@ -55,13 +46,10 @@ def get_tiny_encoder(self):
5546
@requires_timm_greater_or_equal("1.0.15")
5647
def test_forward_backward(self):
5748
for encoder_name in self.encoder_names:
58-
patch_size = self._get_model_expected_input_shape(encoder_name)
59-
sample = self._get_sample(height=patch_size, width=patch_size).to(
60-
default_device
61-
)
49+
sample = self._get_sample().to(default_device)
6250
with self.subTest(encoder_name=encoder_name):
6351
# init encoder
64-
encoder = smp.encoders.get_encoder(
52+
encoder = TimmViTEncoder(
6553
encoder_name,
6654
in_channels=3,
6755
encoder_weights=None,
@@ -90,13 +78,10 @@ def test_in_channels(self):
9078
]
9179

9280
for encoder_name, in_channels in cases:
93-
patch_size = self._get_model_expected_input_shape(encoder_name)
94-
sample = self._get_sample(
95-
height=patch_size, width=patch_size, num_channels=in_channels
96-
).to(default_device)
81+
sample = self._get_sample(num_channels=in_channels).to(default_device)
9782

9883
with self.subTest(encoder_name=encoder_name, in_channels=in_channels):
99-
encoder = smp.encoders.get_encoder(
84+
encoder = TimmViTEncoder(
10085
encoder_name,
10186
in_channels=in_channels,
10287
encoder_weights=None,
@@ -119,12 +104,9 @@ def test_depth(self):
119104
]
120105

121106
for encoder_name, depth in cases:
122-
patch_size = self._get_model_expected_input_shape(encoder_name)
123-
sample = self._get_sample(height=patch_size, width=patch_size).to(
124-
default_device
125-
)
107+
sample = self._get_sample().to(default_device)
126108
with self.subTest(encoder_name=encoder_name, depth=depth):
127-
encoder = smp.encoders.get_encoder(
109+
encoder = TimmViTEncoder(
128110
encoder_name,
129111
in_channels=self.default_num_channels,
130112
encoder_weights=None,
@@ -150,10 +132,9 @@ def test_depth(self):
150132
sample, features
151133
)
152134

153-
timm_encoder_name = encoder_name[3:]
154-
encoder_out_indices = encoder.out_indices
135+
encoder_out_indices = sample_block_indices_uniformly(depth, 12)
155136
timm_model_feature_info = timm.create_model(
156-
model_name=timm_encoder_name
137+
model_name=encoder_name
157138
).feature_info
158139
feature_info_obj = timm.models.FeatureInfo(
159140
feature_info=timm_model_feature_info,
@@ -189,35 +170,56 @@ def test_depth(self):
189170
@requires_timm_greater_or_equal("1.0.15")
190171
def test_invalid_depth(self):
191172
with self.assertRaises(ValueError):
192-
smp.encoders.get_encoder(self.encoder_names[0], depth=5, output_stride=None)
173+
TimmViTEncoder(
174+
self.encoder_names[0],
175+
depth=5,
176+
output_stride=None,
177+
**self.default_encoder_kwargs,
178+
)
193179
with self.assertRaises(ValueError):
194-
smp.encoders.get_encoder(self.encoder_names[0], depth=0, output_stride=None)
180+
TimmViTEncoder(
181+
self.encoder_names[0],
182+
depth=0,
183+
output_stride=None,
184+
**self.default_encoder_kwargs,
185+
)
195186

196187
@requires_timm_greater_or_equal("1.0.15")
197188
def test_invalid_out_indices(self):
198189
with self.assertRaises(ValueError):
199-
smp.encoders.get_encoder(
200-
self.encoder_names[0], output_stride=None, out_indices=-1
190+
TimmViTEncoder(
191+
self.encoder_names[0],
192+
output_stride=None,
193+
output_indices=-25,
194+
**self.default_encoder_kwargs,
201195
)
202196

203197
with self.assertRaises(ValueError):
204-
smp.encoders.get_encoder(
205-
self.encoder_names[0], output_stride=None, out_indices=[1, 2, 25]
198+
TimmViTEncoder(
199+
self.encoder_names[0],
200+
output_stride=None,
201+
output_indices=[1, 2, 25],
202+
**self.default_encoder_kwargs,
206203
)
207204

208205
@requires_timm_greater_or_equal("1.0.15")
209206
def test_invalid_out_indices_length(self):
210207
with self.assertRaises(ValueError):
211-
smp.encoders.get_encoder(
212-
self.encoder_names[0], output_stride=None, out_indices=2, depth=2
208+
TimmViTEncoder(
209+
self.encoder_names[0],
210+
output_stride=None,
211+
output_indices=2,
212+
depth=2,
213+
**self.default_encoder_kwargs,
213214
)
214215

215216
with self.assertRaises(ValueError):
216-
smp.encoders.get_encoder(
217+
TimmViTEncoder(
217218
self.encoder_names[0],
218219
output_stride=None,
219-
out_indices=[0, 1, 2, 3, 4],
220+
output_indices=[0, 1, 2, 3, 4],
220221
depth=4,
222+
**self.default_encoder_kwargs,
221223
)
222224

223225
@requires_timm_greater_or_equal("1.0.15")
@@ -235,23 +237,19 @@ def test_dilated(self):
235237
ValueError, msg="Dilated mode not supported, set output stride to None"
236238
):
237239
encoder_name, stride = cases[0]
238-
patch_size = self._get_model_expected_input_shape(encoder_name)
239-
sample = self._get_sample(height=patch_size, width=patch_size).to(
240-
default_device
241-
)
242-
encoder = smp.encoders.get_encoder(
240+
sample = self._get_sample().to(default_device)
241+
encoder = TimmViTEncoder(
243242
encoder_name,
244243
in_channels=self.default_num_channels,
245244
encoder_weights=None,
246245
output_stride=stride,
247246
depth=self.default_depth,
248-
**self.default_encoder_kwargs,
249247
).to(default_device)
250248
return
251249

252250
for encoder_name, stride in cases:
253251
with self.subTest(encoder_name=encoder_name, stride=stride):
254-
encoder = smp.encoders.get_encoder(
252+
encoder = TimmViTEncoder(
255253
encoder_name,
256254
in_channels=self.default_num_channels,
257255
encoder_weights=None,

tests/models/test_dpt.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111

1212

1313
class TestDPTModel(base.BaseModelTester):
14-
test_encoder_name = "tu-vit_large_patch16_384"
14+
test_encoder_name = "tu-vit_tiny_patch16_224"
1515
files_for_diff = [r"decoders/dpt/", r"base/"]
1616

17-
default_height = 384
18-
default_width = 384
17+
default_height = 224
18+
default_width = 224
1919

2020
# should be overridden
2121
test_model_type = "dpt"

0 commit comments

Comments
 (0)