Skip to content

Commit cbefac1

Browse files
Update no_grad usage to inference_mode
1 parent cf50cd0 commit cbefac1

14 files changed

+17
-17
lines changed

examples/binary_segmentation_buildings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def train_and_evaluate_one_epoch(
269269
# Set the model to evaluation mode
270270
model.eval()
271271
val_loss = 0
272-
with torch.no_grad():
272+
with torch.inference_mode():
273273
for batch in tqdm(valid_dataloader, desc="Evaluating"):
274274
images, masks = batch
275275
images, masks = images.to(device), masks.to(device)
@@ -325,7 +325,7 @@ def test_model(model, output_dir, test_dataloader, loss_fn, device):
325325
model.eval()
326326
test_loss = 0
327327
tp, fp, fn, tn = 0, 0, 0, 0
328-
with torch.no_grad():
328+
with torch.inference_mode():
329329
for batch in tqdm(test_dataloader, desc="Evaluating"):
330330
images, masks = batch
331331
images, masks = images.to(device), masks.to(device)

examples/binary_segmentation_intro.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1026,7 +1026,7 @@
10261026
],
10271027
"source": [
10281028
"batch = next(iter(test_dataloader))\n",
1029-
"with torch.no_grad():\n",
1029+
"with torch.inference_mode():\n",
10301030
" model.eval()\n",
10311031
" logits = model(batch[\"image\"])\n",
10321032
"pr_masks = logits.sigmoid()\n",

examples/camvid_segmentation_multiclass.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1683,7 +1683,7 @@
16831683
"images, masks = next(iter(test_loader))\n",
16841684
"\n",
16851685
"# Switch the model to evaluation mode\n",
1686-
"with torch.no_grad():\n",
1686+
"with torch.inference_mode():\n",
16871687
" model.eval()\n",
16881688
" logits = model(images) # Get raw logits from the model\n",
16891689
"\n",

examples/cars segmentation (camvid).ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1209,7 +1209,7 @@
12091209
],
12101210
"source": [
12111211
"images, masks = next(iter(test_loader))\n",
1212-
"with torch.no_grad():\n",
1212+
"with torch.inference_mode():\n",
12131213
" model.eval()\n",
12141214
" logits = model(images)\n",
12151215
"pr_masks = logits.sigmoid()\n",

examples/convert_to_onnx.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@
189189
],
190190
"source": [
191191
"# compute PyTorch output prediction\n",
192-
"with torch.no_grad():\n",
192+
"with torch.inference_mode():\n",
193193
" torch_out = model(sample)\n",
194194
"\n",
195195
"# compare ONNX Runtime and PyTorch results\n",

examples/dpt_inference_pretrained.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070
"input_tensor = input_tensor.to(device)\n",
7171
"\n",
7272
"# Perform inference\n",
73-
"with torch.no_grad():\n",
73+
"with torch.inference_mode():\n",
7474
" output_mask = model(input_tensor)\n",
7575
"\n",
7676
"# Postprocess mask\n",

examples/segformer_inference_pretrained.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
"input_tensor = input_tensor.to(device)\n",
6464
"\n",
6565
"# Perform inference\n",
66-
"with torch.no_grad():\n",
66+
"with torch.inference_mode():\n",
6767
" output_mask = model(input_tensor)\n",
6868
"\n",
6969
"# Postprocess mask\n",

examples/upernet_inference_pretrained.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@
8585
"input_tensor = input_tensor.to(device)\n",
8686
"\n",
8787
"# Perform inference\n",
88-
"with torch.no_grad():\n",
88+
"with torch.inference_mode():\n",
8989
" output_mask = model(input_tensor)\n",
9090
"\n",
9191
"# Postprocess mask\n",

misc/generate_test_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def save_and_push(model, inputs, outputs, model_name, encoder_name):
5252
torch.manual_seed(423553)
5353
sample = torch.rand(1, 3, 256, 256)
5454

55-
with torch.no_grad():
55+
with torch.inference_mode():
5656
output = model(sample)
5757

5858
save_and_push(model, sample, output, model_name, encoder_name)

scripts/models-conversions/segformer-original-decoder-to-smp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def main(args):
107107
tensor = torch.tensor(normalized_image).permute(2, 0, 1).unsqueeze(0).float()
108108

109109
# Forward pass
110-
with torch.no_grad():
110+
with torch.inference_mode():
111111
mask = model(tensor)
112112

113113
# Postprocessing

scripts/models-conversions/upernet-hf-to-smp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def convert_model(model_name: str, push_to_hub: bool = False):
207207
print("Verifying model with test inference...")
208208
smp_model.eval()
209209
sample = torch.ones(1, 3, 512, 512)
210-
with torch.no_grad():
210+
with torch.inference_mode():
211211
output = smp_model(sample)
212212
print(f"Test inference successful. Output shape: {output.shape}")
213213

segmentation_models_pytorch/base/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,9 @@ def forward(self, x):
7474

7575
return masks
7676

77-
@torch.no_grad()
77+
@torch.inference_mode()
7878
def predict(self, x):
79-
"""Inference method. Switch model to `eval` mode, call `.forward(x)` with `torch.no_grad()`
79+
"""Inference method. Switch model to `eval` mode, call `.forward(x)` with `torch.inference_mode()`
8080
8181
Args:
8282
x: 4D torch tensor with shape (batch_size, channels, height, width)

segmentation_models_pytorch/metrics/functional.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def get_stats(
175175
return tp, fp, fn, tn
176176

177177

178-
@torch.no_grad()
178+
@torch.inference_mode()
179179
def _get_stats_multiclass(
180180
output: torch.LongTensor,
181181
target: torch.LongTensor,
@@ -221,7 +221,7 @@ def _get_stats_multiclass(
221221
return tp_count, fp_count, fn_count, tn_count
222222

223223

224-
@torch.no_grad()
224+
@torch.inference_mode()
225225
def _get_stats_multilabel(
226226
output: torch.LongTensor, target: torch.LongTensor
227227
) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor]:

segmentation_models_pytorch/utils/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def on_epoch_start(self):
110110
self.model.eval()
111111

112112
def batch_update(self, x, y):
113-
with torch.no_grad():
113+
with torch.inference_mode():
114114
prediction = self.model.forward(x)
115115
loss = self.loss(prediction, y)
116116
return loss, prediction

0 commit comments

Comments
 (0)