Skip to content

Commit 2ed775d

Browse files
author
Rustem Galiullin
committed
set unused sam modules to None
1 parent c1a9319 commit 2ed775d

File tree

1 file changed

+3
-5
lines changed
  • segmentation_models_pytorch/decoders/sam

1 file changed

+3
-5
lines changed

segmentation_models_pytorch/decoders/sam/model.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ def __init__(
9393
input_image_size=(image_size, image_size),
9494
mask_in_chans=16,
9595
)
96+
self.prompt_encoder.point_embeddings = None
97+
self.prompt_encoder.mask_downscaling = None
98+
self.not_a_point_embed = None
9699

97100
self.decoder = MaskDecoder(
98101
num_multimask_outputs=3,
@@ -198,8 +201,3 @@ def forward(self, x):
198201
masks = masks * iou_predictions.view(-1, masks.size(1), 1, 1)
199202
output = self.segmentation_head(masks)
200203
return output
201-
202-
def train(self, mode: bool = True):
203-
super(SAM, self).train(mode)
204-
self.prompt_encoder.point_embeddings.requires_grad = False
205-
self.prompt_encoder.mask_downscaling.requires_grad = False

0 commit comments

Comments
 (0)