Commit c5bc356
Author: Rustem Galiullin
Message: implement skip connections for sam vit encoder
Parent: e5c4bc4

File tree: 2 files changed, +92 / -36 lines

segmentation_models_pytorch/encoders/sam.py (68 additions, 10 deletions)
@@ -4,6 +4,8 @@
 
 import torch
 from segment_anything.modeling import ImageEncoderViT
+from torch import nn
+from segment_anything.modeling.common import LayerNorm2d
 
 from segmentation_models_pytorch.encoders._base import EncoderMixin
 
@@ -16,15 +18,55 @@ def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self._out_chans = kwargs.get("out_chans", 256)
         self._patch_size = kwargs.get("patch_size", 16)
+        self._embed_dim = kwargs.get("embed_dim", 768)
         self._validate()
+        self.intermediate_necks = nn.ModuleList(
+            [self.init_neck(self._embed_dim, out_chan) for out_chan in self.out_channels[:-1]]
+        )
+
+    @staticmethod
+    def init_neck(embed_dim: int, out_chans: int) -> nn.Module:
+        # Use similar neck as in ImageEncoderViT
+        return nn.Sequential(
+            nn.Conv2d(
+                embed_dim,
+                out_chans,
+                kernel_size=1,
+                bias=False,
+            ),
+            LayerNorm2d(out_chans),
+            nn.Conv2d(
+                out_chans,
+                out_chans,
+                kernel_size=3,
+                padding=1,
+                bias=False,
+            ),
+            LayerNorm2d(out_chans),
+        )
+
+    @staticmethod
+    def neck_forward(neck: nn.Module, x: torch.Tensor, scale_factor: float = 1) -> torch.Tensor:
+        x = x.permute(0, 3, 1, 2)
+        if scale_factor != 1.0:
+            x = nn.functional.interpolate(x, scale_factor=scale_factor, mode="bilinear")
+        return neck(x)
+
+    def requires_grad_(self, requires_grad: bool = True):
+        # Keep the intermediate necks trainable
+        for param in self.parameters():
+            param.requires_grad_(requires_grad)
+        for param in self.intermediate_necks.parameters():
+            param.requires_grad_(True)
+        return self
 
     @property
     def output_stride(self):
         return 32
 
-    def _get_scale_factor(self) -> float:
-        """Input image will be downscale by this factor"""
-        return int(math.log(self._patch_size, 2))
+    @property
+    def out_channels(self):
+        return [self._out_chans // (2**i) for i in range(self._encoder_depth + 1)][::-1]
 
     def _validate(self):
         # check vit depth
@@ -39,15 +81,30 @@ def _validate(self):
                 "It is recommended to set encoder depth=4 with default vit patch_size=16."
             )
 
-    @property
-    def out_channels(self):
-        # Fill up with leading zeros to be used in Unet
-        scale_factor = self._get_scale_factor()
-        return [0] * scale_factor + [self._out_chans]
+    def _get_scale_factor(self) -> float:
+        """Input image will be downscale by this factor"""
+        return int(math.log(self._patch_size, 2))
 
     def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
-        # Return a list of tensors to match other encoders
-        return [x, super().forward(x)]
+        x = self.patch_embed(x)
+        if self.pos_embed is not None:
+            x = x + self.pos_embed
+
+        features = []
+        skip_steps = self._vit_depth // self._encoder_depth
+        scale_factor = self._get_scale_factor()
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if i % skip_steps == 0:
+                # Double spatial dimension and halve number of channels
+                neck = self.intermediate_necks[i // skip_steps]
+                features.append(self.neck_forward(neck, x, scale_factor=2**scale_factor))
+                scale_factor -= 1
+
+        x = self.neck(x.permute(0, 3, 1, 2))
+        features.append(x)
+
+        return features
 
     def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True) -> None:
         # Exclude mask_decoder and prompt encoder weights
@@ -58,6 +115,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True) ->
             if not k.startswith("mask_decoder") and not k.startswith("prompt_encoder")
         }
         missing, unused = super().load_state_dict(state_dict, strict=False)
+        missing = list(filter(lambda x: not x.startswith("intermediate_necks"), missing))
         if len(missing) + len(unused) > 0:
             n_loaded = len(state_dict) - len(missing) - len(unused)
             warnings.warn(
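
For context, a minimal usage sketch of the reworked encoder (not part of the commit). It assumes get_encoder is importable from segmentation_models_pytorch.encoders, as in the tests below, and uses the sam-vit_b defaults (out_chans=256, vit_depth=12). With patch_size=16 and depth=4, _get_scale_factor() is 4 and skip_steps is 3, so blocks 0, 3, 6 and 9 feed the four intermediate necks (upscaled 16x, 8x, 4x, 2x) and the standard SAM neck produces the coarsest map:

    # Sketch only; assumes get_encoder is importable as in tests/test_sam.py
    # and that no pretrained weights are needed.
    import torch

    from segmentation_models_pytorch.encoders import get_encoder

    encoder = get_encoder("sam-vit_b", img_size=64, patch_size=16, depth=4, vit_depth=12)

    # out_channels = [256 // 2**i for i in range(5)][::-1] -> [16, 32, 64, 128, 256]
    print(encoder.out_channels)

    with torch.no_grad():
        features = encoder(torch.ones(1, 3, 64, 64))

    # depth + 1 = 5 feature maps; from last to first the spatial size doubles
    # (4 -> 8 -> 16 -> 32 -> 64) while the channel count halves (256 -> ... -> 16).
    for f in features:
        print(tuple(f.shape))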

tests/test_sam.py (24 additions, 26 deletions)
@@ -13,13 +13,35 @@
 def test_sam_encoder(encoder_name, img_size, patch_size, depth, vit_depth):
     encoder = get_encoder(encoder_name, img_size=img_size, patch_size=patch_size, depth=depth, vit_depth=vit_depth)
     assert encoder.output_stride == 32
+    assert encoder.out_channels == [256 // (2**i) for i in range(depth + 1)][::-1]
 
     sample = torch.ones(1, 3, img_size, img_size)
     with torch.no_grad():
         out = encoder(sample)
 
-    expected_patches = img_size // patch_size
-    assert out[-1].size() == torch.Size([1, 256, expected_patches, expected_patches])
+    assert len(out) == depth + 1
+
+    expected_spatial_size = img_size // patch_size
+    expected_chans = 256
+    for i in range(1, len(out)):
+        assert out[-i].size() == torch.Size([1, expected_chans, expected_spatial_size, expected_spatial_size])
+        expected_spatial_size *= 2
+        expected_chans //= 2
+
+
+def test_sam_encoder_trainable():
+    encoder = get_encoder("sam-vit_b", depth=4)
+
+    encoder.requires_grad_(False)
+    for name, param in encoder.named_parameters():
+        if name.startswith("intermediate_necks"):
+            assert param.requires_grad
+        else:
+            assert not param.requires_grad
+
+    encoder.requires_grad_(True)
+    for param in encoder.parameters():
+        assert param.requires_grad
 
 
 def test_sam_encoder_validation_error():
@@ -29,25 +51,6 @@ def test_sam_encoder_validation_error():
         get_encoder("sam-vit_b", img_size=64, patch_size=16, depth=4, vit_depth=6)
 
 
-@pytest.mark.skip(reason="Decoder has been removed, keeping this for future integration")
-@pytest.mark.parametrize("decoder_multiclass_output", [True, False])
-@pytest.mark.parametrize("n_classes", [1, 3])
-def test_sam(decoder_multiclass_output, n_classes):
-    model = smp.SAM(
-        "sam-vit_b",
-        encoder_weights=None,
-        weights=None,
-        image_size=64,
-        decoder_multimask_output=decoder_multiclass_output,
-        classes=n_classes,
-    )
-    sample = get_sample(smp.SAM)
-    model.eval()
-
-    _test_forward(model, sample, test_shape=True)
-    _test_forward_backward(model, sample, test_shape=True)
-
-
 @pytest.mark.parametrize("model_class", [smp.Unet])
 @pytest.mark.parametrize("decoder_channels,encoder_depth", [([64, 32, 16, 8], 4), ([64, 32, 16, 8], 4)])
 def test_sam_encoder_arch(model_class, decoder_channels, encoder_depth):
@@ -62,11 +65,6 @@ def test_sam_encoder_arch(model_class, decoder_channels, encoder_depth):
     _test_forward_backward(model, smp, test_shape=True)
 
 
-@pytest.mark.skip(reason="Run this test manually as it needs to download weights")
-def test_sam_weights():
-    smp.create_model("sam", encoder_name="sam-vit_b", encoder_weights=None, weights="sa-1b")
-
-
 @pytest.mark.skip(reason="Run this test manually as it needs to download weights")
 def test_sam_encoder_weights():
     smp.create_model(
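
As a usage note (again not part of the commit), the requires_grad_ override exercised by test_sam_encoder_trainable lends itself to fine-tuning only the new intermediate necks while the pretrained ViT stays frozen. A sketch, assuming the same get_encoder import as above; the optimizer choice is illustrative:

    # Sketch: freeze the ViT weights but keep training the intermediate necks.
    import torch

    from segmentation_models_pytorch.encoders import get_encoder

    encoder = get_encoder("sam-vit_b", depth=4)
    encoder.requires_grad_(False)  # intermediate_necks remain trainable by design

    trainable = [p for p in encoder.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=1e-4)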
