We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent c1a9319 commit 2ed775dCopy full SHA for 2ed775d
segmentation_models_pytorch/decoders/sam/model.py
@@ -93,6 +93,9 @@ def __init__(
93
input_image_size=(image_size, image_size),
94
mask_in_chans=16,
95
)
96
+ self.prompt_encoder.point_embeddings = None
97
+ self.prompt_encoder.mask_downscaling = None
98
+ self.not_a_point_embed = None
99
100
self.decoder = MaskDecoder(
101
num_multimask_outputs=3,
@@ -198,8 +201,3 @@ def forward(self, x):
198
201
masks = masks * iou_predictions.view(-1, masks.size(1), 1, 1)
199
202
output = self.segmentation_head(masks)
200
203
return output
-
- def train(self, mode: bool = True):
- super(SAM, self).train(mode)
204
- self.prompt_encoder.point_embeddings.requires_grad = False
205
- self.prompt_encoder.mask_downscaling.requires_grad = False
0 commit comments