Skip to content

Commit c1a9319

Browse files
author
Rustem Galiullin
committed
set unused sam modules to require grad False
1 parent 4d1144e commit c1a9319

File tree

1 file changed

+5
-0
lines changed
  • segmentation_models_pytorch/decoders/sam

1 file changed

+5
-0
lines changed

segmentation_models_pytorch/decoders/sam/model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,3 +198,8 @@ def forward(self, x):
198198
masks = masks * iou_predictions.view(-1, masks.size(1), 1, 1)
199199
output = self.segmentation_head(masks)
200200
return output
201+
202+
def train(self, mode: bool = True):
203+
super(SAM, self).train(mode)
204+
self.prompt_encoder.point_embeddings.requires_grad = False
205+
self.prompt_encoder.mask_downscaling.requires_grad = False

0 commit comments

Comments
 (0)