From c1a93198bf7df380f3d191e6c6ce01b73ebf9c14 Mon Sep 17 00:00:00 2001
From: Rustem Galiullin
Date: Tue, 9 May 2023 10:08:28 +0400
Subject: [PATCH 1/3] set unused sam modules to require grad False

---
 segmentation_models_pytorch/decoders/sam/model.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/segmentation_models_pytorch/decoders/sam/model.py b/segmentation_models_pytorch/decoders/sam/model.py
index cc6eadf1..fec4f313 100644
--- a/segmentation_models_pytorch/decoders/sam/model.py
+++ b/segmentation_models_pytorch/decoders/sam/model.py
@@ -198,3 +198,8 @@ def forward(self, x):
         masks = masks * iou_predictions.view(-1, masks.size(1), 1, 1)
         output = self.segmentation_head(masks)
         return output
+
+    def train(self, mode: bool = True):
+        super().train(mode)
+        self.prompt_encoder.point_embeddings.requires_grad_(False)
+        self.prompt_encoder.mask_downscaling.requires_grad_(False)

From 2ed775d5df7e5142d4ee0c2afd3cbb1c888289f4 Mon Sep 17 00:00:00 2001
From: Rustem Galiullin
Date: Tue, 9 May 2023 10:22:57 +0400
Subject: [PATCH 2/3] set unused sam modules to None

---
 segmentation_models_pytorch/decoders/sam/model.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/segmentation_models_pytorch/decoders/sam/model.py b/segmentation_models_pytorch/decoders/sam/model.py
index fec4f313..4091c282 100644
--- a/segmentation_models_pytorch/decoders/sam/model.py
+++ b/segmentation_models_pytorch/decoders/sam/model.py
@@ -93,6 +93,9 @@ def __init__(
             input_image_size=(image_size, image_size),
             mask_in_chans=16,
         )
+        self.prompt_encoder.point_embeddings = None
+        self.prompt_encoder.mask_downscaling = None
+        self.prompt_encoder.not_a_point_embed = None

         self.decoder = MaskDecoder(
             num_multimask_outputs=3,
@@ -198,8 +201,3 @@ def forward(self, x):
         masks = masks * iou_predictions.view(-1, masks.size(1), 1, 1)
         output = self.segmentation_head(masks)
         return output
-
-    def train(self, mode: bool = True):
-        super().train(mode)
-        self.prompt_encoder.point_embeddings.requires_grad_(False)
-        self.prompt_encoder.mask_downscaling.requires_grad_(False)

From 9c93eb435f69b6780da3a7258a08774f4187c00c Mon Sep 17 00:00:00 2001
From: Rustem Galiullin
Date: Tue, 9 May 2023 11:38:11 +0400
Subject: [PATCH 3/3] remove prompt encoder from sam

---
 .../decoders/sam/model.py                   | 32 ++++++++++++-------
 segmentation_models_pytorch/encoders/sam.py |  5 +++
 2 files changed, 26 insertions(+), 11 deletions(-)

diff --git a/segmentation_models_pytorch/decoders/sam/model.py b/segmentation_models_pytorch/decoders/sam/model.py
index 4091c282..c7d19a41 100644
--- a/segmentation_models_pytorch/decoders/sam/model.py
+++ b/segmentation_models_pytorch/decoders/sam/model.py
@@ -3,6 +3,8 @@
 import torch

 from segment_anything.modeling import MaskDecoder, TwoWayTransformer, PromptEncoder
+from segment_anything.modeling.prompt_encoder import PositionEmbeddingRandom
+from torch import nn
 from torch.nn import functional as F
 from torch.utils import model_zoo

@@ -86,16 +88,12 @@ def __init__(
             out_chans=decoder_channels,
         )

+        # these params are used instead of the prompt_encoder
         image_embedding_size = image_size // vit_patch_size
-        self.prompt_encoder = PromptEncoder(
-            embed_dim=decoder_channels,
-            image_embedding_size=(image_embedding_size, image_embedding_size),
-            input_image_size=(image_size, image_size),
-            mask_in_chans=16,
-        )
-        self.prompt_encoder.point_embeddings = None
-        self.prompt_encoder.mask_downscaling = None
-        self.prompt_encoder.not_a_point_embed = None
+        self.embed_dim = decoder_channels
+        self.image_embedding_size = (image_embedding_size, image_embedding_size)
+        self.pe_layer = PositionEmbeddingRandom(decoder_channels // 2)
+        self.no_mask_embed = nn.Embedding(1, decoder_channels)

         self.decoder = MaskDecoder(
             num_multimask_outputs=3,
@@ -188,10 +186,11 @@ def forward(self, x):
         img_size = x.shape[-2:]
         x = torch.stack([self.preprocess(img) for img in x])
         features = self.encoder(x)
-        sparse_embeddings, dense_embeddings = self.prompt_encoder(points=None, boxes=None, masks=None)
+        # sparse_embeddings, dense_embeddings = self.prompt_encoder(points=None, boxes=None, masks=None)
+        sparse_embeddings, dense_embeddings = self._get_dummy_prompt_encoder_output(x.size(0))
         low_res_masks, iou_predictions = self.decoder(
             image_embeddings=features,
-            image_pe=self.prompt_encoder.get_dense_pe(),
+            image_pe=self._get_dense_pe(),
             sparse_prompt_embeddings=sparse_embeddings,
             dense_prompt_embeddings=dense_embeddings,
             multimask_output=self._decoder_multiclass_output,
@@ -201,3 +200,14 @@ def forward(self, x):
         masks = masks * iou_predictions.view(-1, masks.size(1), 1, 1)
         output = self.segmentation_head(masks)
         return output
+
+    def _get_dummy_prompt_encoder_output(self, bs):
+        """Use this dummy output as we're training without prompts."""
+        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self.no_mask_embed.weight.device)
+        dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+            bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+        )
+        return sparse_embeddings, dense_embeddings
+
+    def _get_dense_pe(self):
+        return self.pe_layer(self.image_embedding_size).unsqueeze(0)
diff --git a/segmentation_models_pytorch/encoders/sam.py b/segmentation_models_pytorch/encoders/sam.py
index 4518111b..af86bfae 100644
--- a/segmentation_models_pytorch/encoders/sam.py
+++ b/segmentation_models_pytorch/encoders/sam.py
@@ -8,6 +8,11 @@ def __init__(self, name: str, **kwargs):
         super().__init__(**kwargs)
         self._name = name
         self._depth = kwargs["depth"]
+        self._out_chans = kwargs.get("out_chans", 256)
+
+    @property
+    def out_channels(self):
+        return [-1, self._out_chans]


 sam_vit_encoders = {
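
---
A minimal sanity check, not part of the patches above: it verifies that the dummy
embeddings built by _get_dummy_prompt_encoder_output in PATCH 3/3 match what the
removed PromptEncoder returns when called with no prompts. It assumes
segment-anything is installed; the sizes (a 1024px input, 256-dim embeddings,
hence a 64x64 embedding grid) are illustrative assumptions, not values taken
from the patches.

    import torch
    from segment_anything.modeling import PromptEncoder

    embed_dim = 256
    image_embedding_size = (64, 64)  # assumed grid for a 1024px input
    prompt_encoder = PromptEncoder(
        embed_dim=embed_dim,
        image_embedding_size=image_embedding_size,
        input_image_size=(1024, 1024),
        mask_in_chans=16,
    )

    # Reference: the call that PATCH 3/3 removes; with no prompts it
    # returns batch-size-1 sparse and dense embeddings.
    sparse_ref, dense_ref = prompt_encoder(points=None, boxes=None, masks=None)

    # Mirror of _get_dummy_prompt_encoder_output for an arbitrary batch size:
    # an empty sparse tensor and the no-mask embedding broadcast over the grid.
    bs = 2
    sparse = torch.empty((bs, 0, embed_dim), device=prompt_encoder.no_mask_embed.weight.device)
    dense = prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
        bs, -1, image_embedding_size[0], image_embedding_size[1]
    )

    assert sparse.shape[1:] == sparse_ref.shape[1:]  # (0, embed_dim) per sample
    assert torch.allclose(dense[0], dense_ref[0])    # same no-mask dense embedding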