@@ -17,6 +17,7 @@ def __init__(
17
17
log_loss : bool = False ,
18
18
from_logits : bool = True ,
19
19
smooth : float = 0.0 ,
20
+ ignore_index : Optional [int ] = None ,
20
21
eps : float = 1e-7 ,
21
22
):
22
23
"""Jaccard loss for image segmentation task.
@@ -51,6 +52,7 @@ def __init__(
51
52
self .classes = classes
52
53
self .from_logits = from_logits
53
54
self .smooth = smooth
55
+ self .ignore_index = ignore_index
54
56
self .eps = eps
55
57
self .log_loss = log_loss
56
58
@@ -74,17 +76,36 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
74
76
y_true = y_true .view (bs , 1 , - 1 )
75
77
y_pred = y_pred .view (bs , 1 , - 1 )
76
78
79
+ if self .ignore_index is not None :
80
+ mask = y_true != self .ignore_index
81
+ y_pred = y_pred * mask
82
+ y_true = y_true * mask
83
+
77
84
if self .mode == MULTICLASS_MODE :
78
85
y_true = y_true .view (bs , - 1 )
79
86
y_pred = y_pred .view (bs , num_classes , - 1 )
80
87
81
- y_true = F .one_hot (y_true , num_classes ) # N,H*W -> N,H*W, C
82
- y_true = y_true .permute (0 , 2 , 1 ) # H, C, H*W
88
+ if self .ignore_index is not None :
89
+ mask = y_true != self .ignore_index
90
+ y_pred = y_pred * mask .unsqueeze (1 )
91
+
92
+ y_true = F .one_hot (
93
+ (y_true * mask ).to (torch .long ), num_classes
94
+ ) # N,H*W -> N,H*W, C
95
+ y_true = y_true .permute (0 , 2 , 1 ) * mask .unsqueeze (1 ) # N, C, H*W
96
+ else :
97
+ y_true = F .one_hot (y_true , num_classes ) # N,H*W -> N,H*W, C
98
+ y_true = y_true .permute (0 , 2 , 1 ) # N, C, H*W
83
99
84
100
if self .mode == MULTILABEL_MODE :
85
101
y_true = y_true .view (bs , num_classes , - 1 )
86
102
y_pred = y_pred .view (bs , num_classes , - 1 )
87
103
104
+ if self .ignore_index is not None :
105
+ mask = y_true != self .ignore_index
106
+ y_pred = y_pred * mask
107
+ y_true = y_true * mask
108
+
88
109
scores = soft_jaccard_score (
89
110
y_pred ,
90
111
y_true .type (y_pred .dtype ),
0 commit comments