Commit b2166ea

Use inference mode in tests
1 parent 0cab989
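
For context, a brief aside on what this swap buys (this note and the sketch below are not part of the commit): torch.inference_mode() is a stricter, faster sibling of torch.no_grad(). Both disable gradient tracking, but inference mode additionally skips autograd's view and version-counter bookkeeping, which is why PyTorch recommends it for pure-inference code paths such as these tests. A minimal Python sketch of the observable difference:

import torch

x = torch.ones(1, 3, 4, 4)

# no_grad() disables gradient tracking; results are ordinary tensors.
with torch.no_grad():
    y = x * 2
print(y.requires_grad)   # False

# inference_mode() also skips view/version-counter tracking and marks
# its results as inference tensors.
with torch.inference_mode():
    z = x * 2
print(z.requires_grad)   # False
print(z.is_inference())  # True

Both APIs also work as decorators, which is the form the tests/test_losses.py hunks below use.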

2 files changed: 9 additions, 9 deletions

tests/models/test_segformer.py (1 addition, 1 deletion)

@@ -21,7 +21,7 @@ def test_load_pretrained(self):
 
         sample = torch.ones([1, 3, 512, 512]).to(default_device)
 
-        with torch.no_grad():
+        with torch.inference_mode():
            output = model(sample)
 
         self.assertEqual(output.shape, (1, 150, 512, 512))

tests/test_losses.py (8 additions, 8 deletions)

@@ -93,7 +93,7 @@ def test_soft_tversky_score(y_true, y_pred, expected, eps, alpha, beta):
     assert float(actual) == pytest.approx(expected, eps)
 
 
-@torch.no_grad()
+@torch.inference_mode()
 def test_dice_loss_binary():
     eps = 1e-5
     criterion = DiceLoss(mode=smp.losses.BINARY_MODE, from_logits=False)
@@ -131,7 +131,7 @@ def test_dice_loss_binary():
     assert float(loss) == pytest.approx(1.0, abs=eps)
 
 
-@torch.no_grad()
+@torch.inference_mode()
 def test_tversky_loss_binary():
     eps = 1e-5
     # with alpha=0.5; beta=0.5 it is equal to DiceLoss
@@ -172,7 +172,7 @@ def test_tversky_loss_binary():
     assert float(loss) == pytest.approx(1.0, abs=eps)
 
 
-@torch.no_grad()
+@torch.inference_mode()
 def test_binary_jaccard_loss():
     eps = 1e-5
     criterion = JaccardLoss(mode=smp.losses.BINARY_MODE, from_logits=False)
@@ -210,7 +210,7 @@ def test_binary_jaccard_loss():
     assert float(loss) == pytest.approx(1.0, eps)
 
 
-@torch.no_grad()
+@torch.inference_mode()
 def test_multiclass_jaccard_loss():
     eps = 1e-5
     criterion = JaccardLoss(mode=smp.losses.MULTICLASS_MODE, from_logits=False)
@@ -237,7 +237,7 @@ def test_multiclass_jaccard_loss():
     assert float(loss) == pytest.approx(1.0 - 1.0 / 3.0, abs=eps)
 
 
-@torch.no_grad()
+@torch.inference_mode()
 def test_multilabel_jaccard_loss():
     eps = 1e-5
     criterion = JaccardLoss(mode=smp.losses.MULTILABEL_MODE, from_logits=False)
@@ -263,7 +263,7 @@ def test_multilabel_jaccard_loss():
     assert float(loss) == pytest.approx(1.0 - 1.0 / 3.0, abs=eps)
 
 
-@torch.no_grad()
+@torch.inference_mode()
 def test_soft_ce_loss():
     criterion = SoftCrossEntropyLoss(smooth_factor=0.1, ignore_index=-100)
 
@@ -276,7 +276,7 @@ def test_soft_ce_loss():
     assert float(loss) == pytest.approx(1.0125, abs=0.0001)
 
 
-@torch.no_grad()
+@torch.inference_mode()
 def test_soft_bce_loss():
     criterion = SoftBCEWithLogitsLoss(smooth_factor=0.1, ignore_index=-100)
 
@@ -287,7 +287,7 @@ def test_soft_bce_loss():
     assert float(loss) == pytest.approx(0.7201, abs=0.0001)
 
 
-@torch.no_grad()
+@torch.inference_mode()
 def test_binary_mcc_loss():
     eps = 1e-5
     criterion = MCCLoss(eps=eps)
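
One behavioral caveat worth knowing when making this swap in tests (an illustrative sketch, not part of the commit): tensors created under inference mode are marked as inference tensors and cannot participate in autograd afterwards, whereas tensors created under no_grad can. None of the tests above reuse their outputs in a backward pass, so the change is safe here.

import torch

with torch.inference_mode():
    t = torch.ones(3)  # t is an inference tensor

w = torch.ones(3, requires_grad=True)
try:
    # Using an inference tensor in an op that records a graph fails.
    (w * t).sum().backward()
except RuntimeError as err:
    print("autograd rejects inference tensors:", err)

# A common workaround, suggested by the error message itself, is to
# clone the tensor outside inference mode to get a normal tensor.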
