@@ -93,7 +93,7 @@ def test_soft_tversky_score(y_true, y_pred, expected, eps, alpha, beta):
93
93
assert float (actual ) == pytest .approx (expected , eps )
94
94
95
95
96
- @torch .no_grad ()
96
+ @torch .inference_mode ()
97
97
def test_dice_loss_binary ():
98
98
eps = 1e-5
99
99
criterion = DiceLoss (mode = smp .losses .BINARY_MODE , from_logits = False )
@@ -131,7 +131,7 @@ def test_dice_loss_binary():
131
131
assert float (loss ) == pytest .approx (1.0 , abs = eps )
132
132
133
133
134
- @torch .no_grad ()
134
+ @torch .inference_mode ()
135
135
def test_tversky_loss_binary ():
136
136
eps = 1e-5
137
137
# with alpha=0.5; beta=0.5 it is equal to DiceLoss
@@ -172,7 +172,7 @@ def test_tversky_loss_binary():
172
172
assert float (loss ) == pytest .approx (1.0 , abs = eps )
173
173
174
174
175
- @torch .no_grad ()
175
+ @torch .inference_mode ()
176
176
def test_binary_jaccard_loss ():
177
177
eps = 1e-5
178
178
criterion = JaccardLoss (mode = smp .losses .BINARY_MODE , from_logits = False )
@@ -210,7 +210,7 @@ def test_binary_jaccard_loss():
210
210
assert float (loss ) == pytest .approx (1.0 , eps )
211
211
212
212
213
- @torch .no_grad ()
213
+ @torch .inference_mode ()
214
214
def test_multiclass_jaccard_loss ():
215
215
eps = 1e-5
216
216
criterion = JaccardLoss (mode = smp .losses .MULTICLASS_MODE , from_logits = False )
@@ -237,7 +237,7 @@ def test_multiclass_jaccard_loss():
237
237
assert float (loss ) == pytest .approx (1.0 - 1.0 / 3.0 , abs = eps )
238
238
239
239
240
- @torch .no_grad ()
240
+ @torch .inference_mode ()
241
241
def test_multilabel_jaccard_loss ():
242
242
eps = 1e-5
243
243
criterion = JaccardLoss (mode = smp .losses .MULTILABEL_MODE , from_logits = False )
@@ -263,7 +263,7 @@ def test_multilabel_jaccard_loss():
263
263
assert float (loss ) == pytest .approx (1.0 - 1.0 / 3.0 , abs = eps )
264
264
265
265
266
- @torch .no_grad ()
266
+ @torch .inference_mode ()
267
267
def test_soft_ce_loss ():
268
268
criterion = SoftCrossEntropyLoss (smooth_factor = 0.1 , ignore_index = - 100 )
269
269
@@ -276,7 +276,7 @@ def test_soft_ce_loss():
276
276
assert float (loss ) == pytest .approx (1.0125 , abs = 0.0001 )
277
277
278
278
279
- @torch .no_grad ()
279
+ @torch .inference_mode ()
280
280
def test_soft_bce_loss ():
281
281
criterion = SoftBCEWithLogitsLoss (smooth_factor = 0.1 , ignore_index = - 100 )
282
282
@@ -287,7 +287,7 @@ def test_soft_bce_loss():
287
287
assert float (loss ) == pytest .approx (0.7201 , abs = 0.0001 )
288
288
289
289
290
- @torch .no_grad ()
290
+ @torch .inference_mode ()
291
291
def test_binary_mcc_loss ():
292
292
eps = 1e-5
293
293
criterion = MCCLoss (eps = eps )
0 commit comments