diff --git a/segmentation_models_pytorch/decoders/sam/model.py b/segmentation_models_pytorch/decoders/sam/model.py
index cc6eadf1..c7d19a41 100644
--- a/segmentation_models_pytorch/decoders/sam/model.py
+++ b/segmentation_models_pytorch/decoders/sam/model.py
@@ -3,6 +3,8 @@
 
 import torch
 from segment_anything.modeling import MaskDecoder, TwoWayTransformer, PromptEncoder
+from segment_anything.modeling.prompt_encoder import PositionEmbeddingRandom
+from torch import nn
 from torch.nn import functional as F
 from torch.utils import model_zoo
 
@@ -86,13 +88,12 @@ def __init__(
             out_chans=decoder_channels,
         )
 
+        # these params are used instead of prompt_encoder
         image_embedding_size = image_size // vit_patch_size
-        self.prompt_encoder = PromptEncoder(
-            embed_dim=decoder_channels,
-            image_embedding_size=(image_embedding_size, image_embedding_size),
-            input_image_size=(image_size, image_size),
-            mask_in_chans=16,
-        )
+        self.embed_dim = decoder_channels
+        self.image_embedding_size = (image_embedding_size, image_embedding_size)
+        self.pe_layer = PositionEmbeddingRandom(decoder_channels // 2)
+        self.no_mask_embed = nn.Embedding(1, decoder_channels)
 
         self.decoder = MaskDecoder(
             num_multimask_outputs=3,
@@ -185,10 +186,11 @@ def forward(self, x):
         img_size = x.shape[-2:]
         x = torch.stack([self.preprocess(img) for img in x])
         features = self.encoder(x)
-        sparse_embeddings, dense_embeddings = self.prompt_encoder(points=None, boxes=None, masks=None)
+        # sparse_embeddings, dense_embeddings = self.prompt_encoder(points=None, boxes=None, masks=None)
+        sparse_embeddings, dense_embeddings = self._get_dummy_prompt_encoder_output(x.size(0))
         low_res_masks, iou_predictions = self.decoder(
             image_embeddings=features,
-            image_pe=self.prompt_encoder.get_dense_pe(),
+            image_pe=self._get_dense_pe(),
             sparse_prompt_embeddings=sparse_embeddings,
             dense_prompt_embeddings=dense_embeddings,
             multimask_output=self._decoder_multiclass_output,
@@ -198,3 +200,14 @@
         masks = masks * iou_predictions.view(-1, masks.size(1), 1, 1)
         output = self.segmentation_head(masks)
         return output
+
+    def _get_dummy_prompt_encoder_output(self, bs):
+        """Return dummy prompt-encoder output, as we're training without prompts."""
+        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self.no_mask_embed.weight.device)
+        dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+            bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+        )
+        return sparse_embeddings, dense_embeddings
+
+    def _get_dense_pe(self):
+        return self.pe_layer(self.image_embedding_size).unsqueeze(0)
diff --git a/segmentation_models_pytorch/encoders/sam.py b/segmentation_models_pytorch/encoders/sam.py
index 4518111b..af86bfae 100644
--- a/segmentation_models_pytorch/encoders/sam.py
+++ b/segmentation_models_pytorch/encoders/sam.py
@@ -8,6 +8,11 @@ def __init__(self, name: str, **kwargs):
         super().__init__(**kwargs)
         self._name = name
         self._depth = kwargs["depth"]
+        self._out_chans = kwargs.get("out_chans", 256)
+
+    @property
+    def out_channels(self):
+        return [-1, self._out_chans]
 
 
 sam_vit_encoders = {
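
Note on the dummy prompt-encoder output: when called with points=None, boxes=None, masks=None, the upstream PromptEncoder returns an empty sparse embedding and the no_mask_embed weight broadcast over the embedding grid, but only for a batch of one; the helper added above reproduces that output for an arbitrary batch size. A minimal sanity check of this equivalence, assuming SAM ViT-B defaults (embed_dim 256, a 64x64 embedding grid, 1024x1024 input) and not part of the patch:

import torch
from segment_anything.modeling import PromptEncoder

embed_dim, grid = 256, (64, 64)
pe = PromptEncoder(
    embed_dim=embed_dim,
    image_embedding_size=grid,
    input_image_size=(1024, 1024),
    mask_in_chans=16,
)
# With no prompts, the upstream encoder yields a batch of one.
sparse, dense = pe(points=None, boxes=None, masks=None)
assert sparse.shape == (1, 0, embed_dim)
# The dense output is exactly no_mask_embed broadcast over the grid,
# which is what _get_dummy_prompt_encoder_output recreates for bs images.
expected = pe.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(1, -1, *grid)
assert torch.equal(dense, expected)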