Skip to content

Commit 05db6b5

Browse files
authored
Add ignore_index to Jaccard loss (#1151)
1 parent 30f19b2 commit 05db6b5

File tree

1 file changed

+23
-2
lines changed

1 file changed

+23
-2
lines changed

segmentation_models_pytorch/losses/jaccard.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def __init__(
1717
log_loss: bool = False,
1818
from_logits: bool = True,
1919
smooth: float = 0.0,
20+
ignore_index: Optional[int] = None,
2021
eps: float = 1e-7,
2122
):
2223
"""Jaccard loss for image segmentation task.
@@ -51,6 +52,7 @@ def __init__(
5152
self.classes = classes
5253
self.from_logits = from_logits
5354
self.smooth = smooth
55+
self.ignore_index = ignore_index
5456
self.eps = eps
5557
self.log_loss = log_loss
5658

@@ -74,17 +76,36 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
7476
y_true = y_true.view(bs, 1, -1)
7577
y_pred = y_pred.view(bs, 1, -1)
7678

79+
if self.ignore_index is not None:
80+
mask = y_true != self.ignore_index
81+
y_pred = y_pred * mask
82+
y_true = y_true * mask
83+
7784
if self.mode == MULTICLASS_MODE:
7885
y_true = y_true.view(bs, -1)
7986
y_pred = y_pred.view(bs, num_classes, -1)
8087

81-
y_true = F.one_hot(y_true, num_classes) # N,H*W -> N,H*W, C
82-
y_true = y_true.permute(0, 2, 1) # H, C, H*W
88+
if self.ignore_index is not None:
89+
mask = y_true != self.ignore_index
90+
y_pred = y_pred * mask.unsqueeze(1)
91+
92+
y_true = F.one_hot(
93+
(y_true * mask).to(torch.long), num_classes
94+
) # N,H*W -> N,H*W, C
95+
y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1) # N, C, H*W
96+
else:
97+
y_true = F.one_hot(y_true, num_classes) # N,H*W -> N,H*W, C
98+
y_true = y_true.permute(0, 2, 1) # N, C, H*W
8399

84100
if self.mode == MULTILABEL_MODE:
85101
y_true = y_true.view(bs, num_classes, -1)
86102
y_pred = y_pred.view(bs, num_classes, -1)
87103

104+
if self.ignore_index is not None:
105+
mask = y_true != self.ignore_index
106+
y_pred = y_pred * mask
107+
y_true = y_true * mask
108+
88109
scores = soft_jaccard_score(
89110
y_pred,
90111
y_true.type(y_pred.dtype),

0 commit comments

Comments
 (0)