Skip to content

Commit 4d1144e

Browse files
author
Rustem Galiullin
committed
use iou scaling to avoid errors with torch ddp
1 parent 64a2516 commit 4d1144e

File tree

1 file changed

+2
-0
lines changed
  • segmentation_models_pytorch/decoders/sam

1 file changed

+2
-0
lines changed

segmentation_models_pytorch/decoders/sam/model.py

+2
Original file line numberDiff line numberDiff line change
@@ -194,5 +194,7 @@ def forward(self, x):
194194
multimask_output=self._decoder_multiclass_output,
195195
)
196196
masks = self.postprocess_masks(low_res_masks, input_size=img_size, original_size=img_size)
197+
# use scaling below in order to make it work with torch DDP
198+
masks = masks * iou_predictions.view(-1, masks.size(1), 1, 1)
197199
output = self.segmentation_head(masks)
198200
return output

0 commit comments

Comments
 (0)