File tree Expand file tree Collapse file tree 1 file changed +3
-5
lines changed
segmentation_models_pytorch/decoders/sam Expand file tree Collapse file tree 1 file changed +3
-5
lines changed Original file line number Diff line number Diff line change @@ -93,6 +93,9 @@ def __init__(
93
93
input_image_size = (image_size , image_size ),
94
94
mask_in_chans = 16 ,
95
95
)
96
+ self .prompt_encoder .point_embeddings = None
97
+ self .prompt_encoder .mask_downscaling = None
98
+ self .not_a_point_embed = None
96
99
97
100
self .decoder = MaskDecoder (
98
101
num_multimask_outputs = 3 ,
@@ -198,8 +201,3 @@ def forward(self, x):
198
201
masks = masks * iou_predictions .view (- 1 , masks .size (1 ), 1 , 1 )
199
202
output = self .segmentation_head (masks )
200
203
return output
201
-
202
- def train (self , mode : bool = True ):
203
- super (SAM , self ).train (mode )
204
- self .prompt_encoder .point_embeddings .requires_grad = False
205
- self .prompt_encoder .mask_downscaling .requires_grad = False
You can’t perform that action at this time.
0 commit comments