Skip to content

Commit 5edc0ee

Browse files
author
Rustem Galiullin
committed
use vit_depth to control sam vit depth
1 parent e968719 commit 5edc0ee

File tree

4 files changed

+46
-48
lines changed

4 files changed

+46
-48
lines changed

segmentation_models_pytorch/decoders/unet/model.py

+1-11
Original file line numberDiff line numberDiff line change
@@ -65,30 +65,20 @@ def __init__(
6565
classes: int = 1,
6666
activation: Optional[Union[str, callable]] = None,
6767
aux_params: Optional[dict] = None,
68-
encoder_kwargs: Optional[dict] = None,
6968
):
7069
super().__init__()
7170

72-
# if sam encoder, make sure num_hidden_skips is set
73-
if encoder_name.startswith("sam-"):
74-
encoder_kwargs = encoder_kwargs if encoder_kwargs is not None else {}
75-
encoder_kwargs.update({"num_hidden_skips": len(decoder_channels)})
76-
n_decoder_blocks = len(decoder_channels)
77-
else:
78-
n_decoder_blocks = encoder_depth
79-
8071
self.encoder = get_encoder(
8172
encoder_name,
8273
in_channels=in_channels,
8374
depth=encoder_depth,
8475
weights=encoder_weights,
85-
**encoder_kwargs if encoder_kwargs is not None else {},
8676
)
8777

8878
self.decoder = UnetDecoder(
8979
encoder_channels=self.encoder.out_channels,
9080
decoder_channels=decoder_channels,
91-
n_blocks=n_decoder_blocks,
81+
n_blocks=encoder_depth,
9282
use_batchnorm=decoder_use_batchnorm,
9383
center=True if encoder_name.startswith("vgg") else False,
9484
attention_type=decoder_attention_type,

segmentation_models_pytorch/encoders/__init__.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -97,13 +97,8 @@ def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **
9797
raise KeyError("Wrong encoder name `{}`, supported encoders: {}".format(name, list(encoders.keys())))
9898

9999
params = encoders[name]["params"]
100-
if name.startswith("sam-"):
101-
params.update(**kwargs)
102-
params.update(dict(name=name[4:]))
103-
if depth is not None:
104-
params.update(depth=depth)
105-
else:
106-
params.update(depth=depth)
100+
params.update(depth=depth)
101+
params.update(kwargs)
107102
encoder = Encoder(**params)
108103

109104
if weights is not None:

segmentation_models_pytorch/encoders/sam.py

+27-19
Original file line numberDiff line numberDiff line change
@@ -9,33 +9,41 @@
99

1010

1111
class SamVitEncoder(EncoderMixin, ImageEncoderViT):
12-
def __init__(self, name: str, **kwargs):
13-
patch_size = kwargs.get("patch_size", 16)
14-
n_skips = kwargs.pop("num_hidden_skips", int(self._get_scale_factor(patch_size)))
12+
def __init__(self, **kwargs):
13+
self._vit_depth = kwargs.pop("vit_depth")
14+
self._encoder_depth = kwargs.get("depth", 5)
15+
kwargs.update({"depth": self._vit_depth})
1516
super().__init__(**kwargs)
16-
self._name = name
17-
self._depth = kwargs["depth"]
1817
self._out_chans = kwargs.get("out_chans", 256)
19-
self._num_skips = n_skips
20-
self._validate_output(patch_size)
18+
self._patch_size = kwargs.get("patch_size", 16)
19+
self._validate()
2120

22-
@staticmethod
23-
def _get_scale_factor(patch_size: int) -> float:
21+
@property
22+
def output_stride(self):
23+
return 32
24+
25+
def _get_scale_factor(self) -> float:
2426
"""Input image will be downscaled by this factor"""
25-
return math.log(patch_size, 2)
27+
return int(math.log(self._patch_size, 2))
2628

27-
def _validate_output(self, patch_size: int):
28-
scale_factor = self._get_scale_factor(patch_size)
29-
if scale_factor != self._num_skips:
29+
def _validate(self):
30+
# check vit depth
31+
if self._vit_depth not in [12, 24, 32]:
32+
raise ValueError(f"vit_depth must be one of [12, 24, 32], got {self._vit_depth}")
33+
# check output
34+
scale_factor = self._get_scale_factor()
35+
if scale_factor != self._encoder_depth:
3036
raise ValueError(
31-
f"With {patch_size=} and {self._num_skips} skip connection layers, "
32-
"spatial dimensions of model output will not match input spatial dimensions"
37+
f"With patch_size={self._patch_size} and depth={self._encoder_depth}, "
38+
"spatial dimensions of model output will not match input spatial dimensions. "
39+
"It is recommended to set encoder depth=4 with default vit patch_size=16."
3340
)
3441

3542
@property
3643
def out_channels(self):
3744
# Fill up with leading zeros to be used in Unet
38-
return [0] * self._num_skips + [self._out_chans]
45+
scale_factor = self._get_scale_factor()
46+
return [0] * scale_factor + [self._out_chans]
3947

4048
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
4149
# Return a list of tensors to match other encoders
@@ -66,7 +74,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True) ->
6674
},
6775
"params": dict(
6876
embed_dim=1280,
69-
depth=32,
77+
vit_depth=32,
7078
num_heads=16,
7179
global_attn_indexes=[7, 15, 23, 31],
7280
),
@@ -78,7 +86,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True) ->
7886
},
7987
"params": dict(
8088
embed_dim=1024,
81-
depth=24,
89+
vit_depth=24,
8290
num_heads=16,
8391
global_attn_indexes=[5, 11, 17, 23],
8492
),
@@ -90,7 +98,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True) ->
9098
},
9199
"params": dict(
92100
embed_dim=768,
93-
depth=12,
101+
vit_depth=12,
94102
num_heads=12,
95103
global_attn_indexes=[2, 5, 8, 11],
96104
),

tests/test_sam.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,10 @@
88

99
@pytest.mark.parametrize("encoder_name", ["sam-vit_b", "sam-vit_l"])
1010
@pytest.mark.parametrize("img_size", [64, 128])
11-
@pytest.mark.parametrize("patch_size", [8, 16])
12-
@pytest.mark.parametrize("depth", [6, 24, None])
13-
def test_sam_encoder(encoder_name, img_size, patch_size, depth):
14-
encoder = get_encoder(encoder_name, img_size=img_size, patch_size=patch_size, depth=depth)
15-
assert encoder._name == encoder_name[4:]
11+
@pytest.mark.parametrize("patch_size,depth", [(8, 3), (16, 4)])
12+
@pytest.mark.parametrize("vit_depth", [12, 24])
13+
def test_sam_encoder(encoder_name, img_size, patch_size, depth, vit_depth):
14+
encoder = get_encoder(encoder_name, img_size=img_size, patch_size=patch_size, depth=depth, vit_depth=vit_depth)
1615
assert encoder.output_stride == 32
1716

1817
sample = torch.ones(1, 3, img_size, img_size)
@@ -23,6 +22,13 @@ def test_sam_encoder(encoder_name, img_size, patch_size, depth):
2322
assert out[-1].size() == torch.Size([1, 256, expected_patches, expected_patches])
2423

2524

25+
def test_sam_encoder_validation_error():
26+
with pytest.raises(ValueError):
27+
get_encoder("sam-vit_b", img_size=64, patch_size=16, depth=5, vit_depth=12)
28+
get_encoder("sam-vit_b", img_size=64, patch_size=16, depth=4, vit_depth=None)
29+
get_encoder("sam-vit_b", img_size=64, patch_size=16, depth=4, vit_depth=6)
30+
31+
2632
@pytest.mark.skip(reason="Decoder has been removed, keeping this for future integration")
2733
@pytest.mark.parametrize("decoder_multiclass_output", [True, False])
2834
@pytest.mark.parametrize("n_classes", [1, 3])
@@ -43,14 +49,13 @@ def test_sam(decoder_multiclass_output, n_classes):
4349

4450

4551
@pytest.mark.parametrize("model_class", [smp.Unet])
46-
@pytest.mark.parametrize("decoder_channels,patch_size", [([64, 32, 16, 8], 16), ([64, 32, 16], 8)])
47-
def test_sam_as_encoder_only(model_class, decoder_channels, patch_size):
48-
img_size = 64
52+
@pytest.mark.parametrize("decoder_channels,encoder_depth", [([64, 32, 16, 8], 4), ([64, 32, 16, 8], 4)])
53+
def test_sam_encoder_arch(model_class, decoder_channels, encoder_depth):
54+
img_size = 1024
4955
model = model_class(
5056
"sam-vit_b",
5157
encoder_weights=None,
52-
encoder_depth=3,
53-
encoder_kwargs=dict(img_size=img_size, out_chans=decoder_channels[0], patch_size=patch_size),
58+
encoder_depth=encoder_depth,
5459
decoder_channels=decoder_channels,
5560
)
5661
smp = torch.ones(1, 3, img_size, img_size)
@@ -65,5 +70,5 @@ def test_sam_weights():
6570
@pytest.mark.skip(reason="Run this test manually as it needs to download weights")
6671
def test_sam_encoder_weights():
6772
smp.create_model(
68-
"unet", encoder_name="sam-vit_b", encoder_weights="sa-1b", encoder_depth=12, decoder_channels=[64, 32, 16, 8]
73+
"unet", encoder_name="sam-vit_b", encoder_depth=4, encoder_weights="sa-1b", decoder_channels=[64, 32, 16, 8]
6974
)

0 commit comments

Comments
 (0)