
Commit 9731e8f

Merge pull request #1 from Rusteam/sam-ddp
Make sam changes to enable DDP training
2 parents 4d1144e + 9c93eb4

2 files changed (+26, -8 lines)

segmentation_models_pytorch/decoders/sam/model.py (+21, -8)
@@ -3,6 +3,8 @@

 import torch
 from segment_anything.modeling import MaskDecoder, TwoWayTransformer, PromptEncoder
+from segment_anything.modeling.prompt_encoder import PositionEmbeddingRandom
+from torch import nn
 from torch.nn import functional as F
 from torch.utils import model_zoo

@@ -86,13 +88,12 @@ def __init__(
             out_chans=decoder_channels,
         )

+        # these params are used instead of prompt_encoder
         image_embedding_size = image_size // vit_patch_size
-        self.prompt_encoder = PromptEncoder(
-            embed_dim=decoder_channels,
-            image_embedding_size=(image_embedding_size, image_embedding_size),
-            input_image_size=(image_size, image_size),
-            mask_in_chans=16,
-        )
+        self.embed_dim = decoder_channels
+        self.image_embedding_size = (image_embedding_size, image_embedding_size)
+        self.pe_layer = PositionEmbeddingRandom(decoder_channels // 2)
+        self.no_mask_embed = nn.Embedding(1, decoder_channels)

         self.decoder = MaskDecoder(
             num_multimask_outputs=3,
@@ -185,10 +186,11 @@ def forward(self, x):
         img_size = x.shape[-2:]
         x = torch.stack([self.preprocess(img) for img in x])
         features = self.encoder(x)
-        sparse_embeddings, dense_embeddings = self.prompt_encoder(points=None, boxes=None, masks=None)
+        # sparse_embeddings, dense_embeddings = self.prompt_encoder(points=None, boxes=None, masks=None)
+        sparse_embeddings, dense_embeddings = self._get_dummy_promp_encoder_output(x.size(0))
         low_res_masks, iou_predictions = self.decoder(
             image_embeddings=features,
-            image_pe=self.prompt_encoder.get_dense_pe(),
+            image_pe=self._get_dense_pe(),
             sparse_prompt_embeddings=sparse_embeddings,
             dense_prompt_embeddings=dense_embeddings,
             multimask_output=self._decoder_multiclass_output,
@@ -198,3 +200,14 @@ def forward(self, x):
         masks = masks * iou_predictions.view(-1, masks.size(1), 1, 1)
         output = self.segmentation_head(masks)
         return output
+
+    def _get_dummy_promp_encoder_output(self, bs):
+        """Use this dummy output as we're training without prompts."""
+        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self.no_mask_embed.weight.device)
+        dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+            bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+        )
+        return sparse_embeddings, dense_embeddings
+
+    def _get_dense_pe(self):
+        return self.pe_layer(self.image_embedding_size).unsqueeze(0)
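
Why this change enables DDP: torch.nn.parallel.DistributedDataParallel expects every registered parameter to receive a gradient during backward (unless find_unused_parameters=True is passed, which re-walks the autograd graph every step). Since this model trains without prompts, the full PromptEncoder carried point, box, and mask-downscaling parameters that never got gradients; the commit keeps only the two pieces the forward pass actually touches, no_mask_embed and the buffer-only PositionEmbeddingRandom. The sketch below reproduces the failure mode and the fix with toy stand-ins for the SAM model, not the library's classes, and uses a single-process group purely for illustration:

import os

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP


class WithUnusedPromptEncoder(nn.Module):
    """Stand-in for the old layout: a submodule whose params never get grads."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.prompt_encoder = nn.Linear(8, 8)  # registered but never called

    def forward(self, x):
        return self.backbone(x)


class WithDummyEmbeddings(nn.Module):
    """Stand-in for the new layout: only parameters that join the loss."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.no_mask_embed = nn.Embedding(1, 8)

    def forward(self, x):
        return self.backbone(x) + self.no_mask_embed.weight


if __name__ == "__main__":
    # Single-process group for illustration; real runs use torchrun/launch.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    x = torch.randn(4, 8)

    # Old layout: prompt_encoder's gradient bucket never finishes reduction,
    # and DDP raises a RuntimeError (typically on the following step).
    broken = DDP(WithUnusedPromptEncoder())
    try:
        for _ in range(2):
            broken(x).sum().backward()
    except RuntimeError as err:
        print("DDP rejected unused params:", str(err)[:60])

    # New layout: every parameter gets a gradient, so default DDP is happy.
    fixed = DDP(WithDummyEmbeddings())
    for _ in range(2):
        fixed(x).sum().backward()

    dist.destroy_process_group()

The alternative, passing find_unused_parameters=True to DDP, also works but pays a graph-traversal cost on every iteration, so pruning the dead parameters is the cheaper fix.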

segmentation_models_pytorch/encoders/sam.py (+5)
@@ -8,6 +8,11 @@ def __init__(self, name: str, **kwargs):
         super().__init__(**kwargs)
         self._name = name
         self._depth = kwargs["depth"]
+        self._out_chans = kwargs.get("out_chans", 256)
+
+    @property
+    def out_channels(self):
+        return [-1, self._out_chans]


 sam_vit_encoders = {
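
The new out_channels property fills in segmentation_models_pytorch's encoder contract, where decoders read per-stage channel widths from encoder.out_channels; the leading -1 appears to be a placeholder for a stage the SAM decoder never consumes, and the last entry carries the ViT's out_chans. A self-contained sketch of that contract follows; the class and the name string are stand-ins, not the library's actual encoder:

# Stand-in encoder; mirrors the kwargs handling shown in the diff above.
class DummySamEncoder:
    def __init__(self, name: str, **kwargs):
        self._name = name
        self._out_chans = kwargs.get("out_chans", 256)

    @property
    def out_channels(self):
        return [-1, self._out_chans]


encoder = DummySamEncoder("sam-vit_b", depth=12, out_chans=256)
assert encoder.out_channels == [-1, 256]  # decoders consume the last entry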
