Skip to content

Commit d121fec

Browse files
committed
Add depth validation
1 parent 5bbb1db commit d121fec

16 files changed

+84
-1
lines changed

segmentation_models_pytorch/encoders/densenet.py

+6
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,13 @@
3232

3333
class DenseNetEncoder(DenseNet, EncoderMixin):
3434
def __init__(self, out_channels, depth=5, output_stride=32, **kwargs):
35+
if depth > 5 or depth < 1:
36+
raise ValueError(
37+
f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}"
38+
)
39+
3540
super().__init__(**kwargs)
41+
3642
self._depth = depth
3743
self._in_channels = 3
3844
self._out_channels = out_channels

segmentation_models_pytorch/encoders/dpn.py

+5
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ def __init__(
4444
output_stride: int = 32,
4545
**kwargs,
4646
):
47+
if depth > 5 or depth < 1:
48+
raise ValueError(
49+
f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}"
50+
)
51+
4752
super().__init__(**kwargs)
4853
self._stage_idxs = stage_idxs
4954
self._depth = depth

segmentation_models_pytorch/encoders/efficientnet.py

+5
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ def __init__(
4343
depth: int = 5,
4444
output_stride: int = 32,
4545
):
46+
if depth > 5 or depth < 1:
47+
raise ValueError(
48+
f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}"
49+
)
50+
4651
blocks_args, global_params = get_model_params(model_name, override_params=None)
4752
super().__init__(blocks_args, global_params)
4853

segmentation_models_pytorch/encoders/inceptionresnetv2.py

+5
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ def __init__(
3939
output_stride: int = 32,
4040
**kwargs,
4141
):
42+
if depth > 5 or depth < 1:
43+
raise ValueError(
44+
f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}"
45+
)
46+
4247
super().__init__(**kwargs)
4348

4449
self._depth = depth

segmentation_models_pytorch/encoders/inceptionv4.py

+4
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ def __init__(
4040
output_stride: int = 32,
4141
**kwargs,
4242
):
43+
if depth > 5 or depth < 1:
44+
raise ValueError(
45+
f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}"
46+
)
4347
super().__init__(**kwargs)
4448

4549
self._depth = depth

segmentation_models_pytorch/encoders/mix_transformer.py

+4
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,10 @@ class MixVisionTransformerEncoder(MixVisionTransformer, EncoderMixin):
529529
def __init__(
530530
self, out_channels: List[int], depth: int = 5, output_stride: int = 32, **kwargs
531531
):
532+
if depth > 5 or depth < 1:
533+
raise ValueError(
534+
f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}"
535+
)
532536
super().__init__(**kwargs)
533537

534538
self._depth = depth

segmentation_models_pytorch/encoders/mobilenet.py

+4
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ class MobileNetV2Encoder(torchvision.models.MobileNetV2, EncoderMixin):
3434
def __init__(
3535
self, out_channels: List[int], depth: int = 5, output_stride: int = 32, **kwargs
3636
):
37+
if depth > 5 or depth < 1:
38+
raise ValueError(
39+
f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}"
40+
)
3741
super().__init__(**kwargs)
3842

3943
self._depth = depth

segmentation_models_pytorch/encoders/mobileone.py

+5
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,11 @@ def __init__(
319319
:param use_se: Whether to use SE-ReLU activations.
320320
:param num_conv_branches: Number of linear conv branches.
321321
"""
322+
if depth > 5 or depth < 1:
323+
raise ValueError(
324+
f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}"
325+
)
326+
322327
super().__init__()
323328

324329
assert len(width_multipliers) == 4

segmentation_models_pytorch/encoders/resnet.py

+4
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ class ResNetEncoder(ResNet, EncoderMixin):
3838
def __init__(
3939
self, out_channels: List[int], depth: int = 5, output_stride: int = 32, **kwargs
4040
):
41+
if depth > 5 or depth < 1:
42+
raise ValueError(
43+
f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}"
44+
)
4145
super().__init__(**kwargs)
4246

4347
self._depth = depth

segmentation_models_pytorch/encoders/senet.py

+4
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ def __init__(
4343
output_stride: int = 32,
4444
**kwargs,
4545
):
46+
if depth > 5 or depth < 1:
47+
raise ValueError(
48+
f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}"
49+
)
4650
super().__init__(**kwargs)
4751

4852
self._depth = depth

segmentation_models_pytorch/encoders/timm_efficientnet.py

+4
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,10 @@ def __init__(
105105
output_stride: int = 32,
106106
**kwargs,
107107
):
108+
if depth > 5 or depth < 1:
109+
raise ValueError(
110+
f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}"
111+
)
108112
super().__init__(**kwargs)
109113

110114
self._stage_idxs = stage_idxs

segmentation_models_pytorch/encoders/timm_sknet.py

+4
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ def __init__(
1414
output_stride: int = 32,
1515
**kwargs,
1616
):
17+
if depth > 5 or depth < 1:
18+
raise ValueError(
19+
f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}"
20+
)
1721
super().__init__(**kwargs)
1822

1923
self._depth = depth

segmentation_models_pytorch/encoders/timm_universal.py

+7
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,13 @@ def __init__(
6868
output_stride (int): Desired output stride (default: 32).
6969
**kwargs: Additional arguments passed to `timm.create_model`.
7070
"""
71+
# At the moment we do not support models with more than 5 stages,
72+
# but can be reconfigured in the future.
73+
if depth > 5 or depth < 1:
74+
raise ValueError(
75+
f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}"
76+
)
77+
7178
super().__init__()
7279
self.name = name
7380

segmentation_models_pytorch/encoders/vgg.py

+4
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ def __init__(
5353
output_stride: int = 32,
5454
**kwargs,
5555
):
56+
if depth > 5 or depth < 1:
57+
raise ValueError(
58+
f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}"
59+
)
5660
super().__init__(make_layers(config, batch_norm=batch_norm), **kwargs)
5761

5862
self._depth = depth

segmentation_models_pytorch/encoders/xception.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,22 @@
1+
from typing import List
12
from pretrainedmodels.models.xception import Xception
23

34
from ._base import EncoderMixin
45

56

67
class XceptionEncoder(Xception, EncoderMixin):
7-
def __init__(self, out_channels, *args, depth=5, output_stride=32, **kwargs):
8+
def __init__(
9+
self,
10+
out_channels: List[int],
11+
*args,
12+
depth: int = 5,
13+
output_stride: int = 32,
14+
**kwargs,
15+
):
16+
if depth > 5 or depth < 1:
17+
raise ValueError(
18+
f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}"
19+
)
820
super().__init__(*args, **kwargs)
921

1022
self._depth = depth

tests/encoders/base.py

+6
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,12 @@ def test_depth(self):
149149
f"Encoder `{encoder_name}` should have {depth + 1} out_channels, but has {len(encoder.out_channels)}",
150150
)
151151

152+
def test_invalid_depth(self):
153+
with self.assertRaises(ValueError):
154+
smp.encoders.get_encoder(self.encoder_names[0], depth=6)
155+
with self.assertRaises(ValueError):
156+
smp.encoders.get_encoder(self.encoder_names[0], depth=0)
157+
152158
def test_dilated(self):
153159
sample = self._get_sample().to(default_device)
154160

0 commit comments

Comments
 (0)