We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 4d1144e commit c1a9319Copy full SHA for c1a9319
segmentation_models_pytorch/decoders/sam/model.py
@@ -198,3 +198,8 @@ def forward(self, x):
198
masks = masks * iou_predictions.view(-1, masks.size(1), 1, 1)
199
output = self.segmentation_head(masks)
200
return output
201
+
202
+ def train(self, mode: bool = True):
203
+ super(SAM, self).train(mode)
204
+ self.prompt_encoder.point_embeddings.requires_grad = False
205
+ self.prompt_encoder.mask_downscaling.requires_grad = False
0 commit comments