diff --git a/tests/data/pytorch_neo/code/inference.py b/tests/data/pytorch_neo/code/inference.py index 984005039f..5b89c2bebc 100644 --- a/tests/data/pytorch_neo/code/inference.py +++ b/tests/data/pytorch_neo/code/inference.py @@ -71,8 +71,8 @@ def model_fn(model_dir): logger.info("model_fn") neopytorch.config(model_dir=model_dir, neo_runtime=True) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - # The compiled model is saved as "compiled.pt" - model = torch.jit.load(os.path.join(model_dir, "compiled.pt"), map_location=device) + # The compiled model is saved as "model.pth" + model = torch.jit.load(os.path.join(model_dir, "model.pth"), map_location=device) # It is recommended to run warm-up inference during model load sample_input_path = os.path.join(model_dir, "sample_input.pkl")