diff --git a/examples/binary_segmentation_buildings.py b/examples/binary_segmentation_buildings.py index 33636477..1dd2cf0a 100644 --- a/examples/binary_segmentation_buildings.py +++ b/examples/binary_segmentation_buildings.py @@ -269,7 +269,7 @@ def train_and_evaluate_one_epoch( # Set the model to evaluation mode model.eval() val_loss = 0 - with torch.no_grad(): + with torch.inference_mode(): for batch in tqdm(valid_dataloader, desc="Evaluating"): images, masks = batch images, masks = images.to(device), masks.to(device) @@ -325,7 +325,7 @@ def test_model(model, output_dir, test_dataloader, loss_fn, device): model.eval() test_loss = 0 tp, fp, fn, tn = 0, 0, 0, 0 - with torch.no_grad(): + with torch.inference_mode(): for batch in tqdm(test_dataloader, desc="Evaluating"): images, masks = batch images, masks = images.to(device), masks.to(device) diff --git a/examples/binary_segmentation_intro.ipynb b/examples/binary_segmentation_intro.ipynb index c69dd697..bbdf329d 100644 --- a/examples/binary_segmentation_intro.ipynb +++ b/examples/binary_segmentation_intro.ipynb @@ -1026,7 +1026,7 @@ ], "source": [ "batch = next(iter(test_dataloader))\n", - "with torch.no_grad():\n", + "with torch.inference_mode():\n", " model.eval()\n", " logits = model(batch[\"image\"])\n", "pr_masks = logits.sigmoid()\n", diff --git a/examples/camvid_segmentation_multiclass.ipynb b/examples/camvid_segmentation_multiclass.ipynb index c918167b..43763df8 100644 --- a/examples/camvid_segmentation_multiclass.ipynb +++ b/examples/camvid_segmentation_multiclass.ipynb @@ -1683,7 +1683,7 @@ "images, masks = next(iter(test_loader))\n", "\n", "# Switch the model to evaluation mode\n", - "with torch.no_grad():\n", + "with torch.inference_mode():\n", " model.eval()\n", " logits = model(images) # Get raw logits from the model\n", "\n", diff --git a/examples/cars segmentation (camvid).ipynb b/examples/cars segmentation (camvid).ipynb index 00c22b31..a9b41a68 100644 --- a/examples/cars segmentation (camvid).ipynb +++ b/examples/cars segmentation (camvid).ipynb @@ -1209,7 +1209,7 @@ ], "source": [ "images, masks = next(iter(test_loader))\n", - "with torch.no_grad():\n", + "with torch.inference_mode():\n", " model.eval()\n", " logits = model(images)\n", "pr_masks = logits.sigmoid()\n", diff --git a/examples/convert_to_onnx.ipynb b/examples/convert_to_onnx.ipynb index abd063a0..fc34d9b5 100644 --- a/examples/convert_to_onnx.ipynb +++ b/examples/convert_to_onnx.ipynb @@ -189,7 +189,7 @@ ], "source": [ "# compute PyTorch output prediction\n", - "with torch.no_grad():\n", + "with torch.inference_mode():\n", " torch_out = model(sample)\n", "\n", "# compare ONNX Runtime and PyTorch results\n", diff --git a/examples/dpt_inference_pretrained.ipynb b/examples/dpt_inference_pretrained.ipynb index adfb5a15..e7365f3b 100644 --- a/examples/dpt_inference_pretrained.ipynb +++ b/examples/dpt_inference_pretrained.ipynb @@ -70,7 +70,7 @@ "input_tensor = input_tensor.to(device)\n", "\n", "# Perform inference\n", - "with torch.no_grad():\n", + "with torch.inference_mode():\n", " output_mask = model(input_tensor)\n", "\n", "# Postprocess mask\n", diff --git a/examples/segformer_inference_pretrained.ipynb b/examples/segformer_inference_pretrained.ipynb index d2d195fd..4ea44987 100644 --- a/examples/segformer_inference_pretrained.ipynb +++ b/examples/segformer_inference_pretrained.ipynb @@ -63,7 +63,7 @@ "input_tensor = input_tensor.to(device)\n", "\n", "# Perform inference\n", - "with torch.no_grad():\n", + "with torch.inference_mode():\n", " output_mask = model(input_tensor)\n", "\n", "# Postprocess mask\n", diff --git a/examples/upernet_inference_pretrained.ipynb b/examples/upernet_inference_pretrained.ipynb index aa644858..85512595 100644 --- a/examples/upernet_inference_pretrained.ipynb +++ b/examples/upernet_inference_pretrained.ipynb @@ -85,7 +85,7 @@ "input_tensor = input_tensor.to(device)\n", "\n", "# Perform inference\n", - "with torch.no_grad():\n", + "with torch.inference_mode():\n", " output_mask = model(input_tensor)\n", "\n", "# Postprocess mask\n", diff --git a/misc/generate_test_models.py b/misc/generate_test_models.py index a26cbc66..0422f230 100644 --- a/misc/generate_test_models.py +++ b/misc/generate_test_models.py @@ -52,7 +52,7 @@ def save_and_push(model, inputs, outputs, model_name, encoder_name): torch.manual_seed(423553) sample = torch.rand(1, 3, 256, 256) - with torch.no_grad(): + with torch.inference_mode(): output = model(sample) save_and_push(model, sample, output, model_name, encoder_name) diff --git a/scripts/models-conversions/segformer-original-decoder-to-smp.py b/scripts/models-conversions/segformer-original-decoder-to-smp.py index e433c256..a91c6fc9 100644 --- a/scripts/models-conversions/segformer-original-decoder-to-smp.py +++ b/scripts/models-conversions/segformer-original-decoder-to-smp.py @@ -107,7 +107,7 @@ def main(args): tensor = torch.tensor(normalized_image).permute(2, 0, 1).unsqueeze(0).float() # Forward pass - with torch.no_grad(): + with torch.inference_mode(): mask = model(tensor) # Postprocessing diff --git a/scripts/models-conversions/upernet-hf-to-smp.py b/scripts/models-conversions/upernet-hf-to-smp.py index 8cd3162f..08f4c224 100644 --- a/scripts/models-conversions/upernet-hf-to-smp.py +++ b/scripts/models-conversions/upernet-hf-to-smp.py @@ -207,7 +207,7 @@ def convert_model(model_name: str, push_to_hub: bool = False): print("Verifying model with test inference...") smp_model.eval() sample = torch.ones(1, 3, 512, 512) - with torch.no_grad(): + with torch.inference_mode(): output = smp_model(sample) print(f"Test inference successful. Output shape: {output.shape}") diff --git a/segmentation_models_pytorch/base/model.py b/segmentation_models_pytorch/base/model.py index 9b0db714..71322cf0 100644 --- a/segmentation_models_pytorch/base/model.py +++ b/segmentation_models_pytorch/base/model.py @@ -74,9 +74,9 @@ def forward(self, x): return masks - @torch.no_grad() + @torch.inference_mode() def predict(self, x): - """Inference method. Switch model to `eval` mode, call `.forward(x)` with `torch.no_grad()` + """Inference method. Switch model to `eval` mode, call `.forward(x)` with `torch.inference_mode()` Args: x: 4D torch tensor with shape (batch_size, channels, height, width) diff --git a/segmentation_models_pytorch/metrics/functional.py b/segmentation_models_pytorch/metrics/functional.py index c0755787..5fd75cad 100644 --- a/segmentation_models_pytorch/metrics/functional.py +++ b/segmentation_models_pytorch/metrics/functional.py @@ -175,7 +175,7 @@ def get_stats( return tp, fp, fn, tn -@torch.no_grad() +@torch.inference_mode() def _get_stats_multiclass( output: torch.LongTensor, target: torch.LongTensor, @@ -221,7 +221,7 @@ def _get_stats_multiclass( return tp_count, fp_count, fn_count, tn_count -@torch.no_grad() +@torch.inference_mode() def _get_stats_multilabel( output: torch.LongTensor, target: torch.LongTensor ) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor]: diff --git a/segmentation_models_pytorch/utils/train.py b/segmentation_models_pytorch/utils/train.py index 8c087c6b..a7b8e63b 100644 --- a/segmentation_models_pytorch/utils/train.py +++ b/segmentation_models_pytorch/utils/train.py @@ -110,7 +110,7 @@ def on_epoch_start(self): self.model.eval() def batch_update(self, x, y): - with torch.no_grad(): + with torch.inference_mode(): prediction = self.model.forward(x) loss = self.loss(prediction, y) return loss, prediction