diff --git a/src/sagemaker_pytorch_serving_container/default_inference_handler.py b/src/sagemaker_pytorch_serving_container/default_inference_handler.py
index 176d3afc..f2533709 100644
--- a/src/sagemaker_pytorch_serving_container/default_inference_handler.py
+++ b/src/sagemaker_pytorch_serving_container/default_inference_handler.py
@@ -39,7 +39,8 @@ def default_model_fn(self, model_dir):
             if not os.path.exists(model_path):
                 raise FileNotFoundError("Failed to load model with default model_fn: missing file {}."
                                         .format(DEFAULT_MODEL_FILENAME))
-            return torch.jit.load(model_path)
+            # The client framework is CPU-only, but the model will run on the Elastic Inference server with CUDA.
+            return torch.jit.load(model_path, map_location=torch.device('cpu'))
         else:
             raise NotImplementedError(textwrap.dedent("""
             Please provide a model_fn implementation.
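
For context on the failure mode this fixes, here is a minimal sketch of the load path, assuming a TorchScript artifact saved from a GPU machine; the path and input shape below are illustrative, not taken from this diff:

```python
import torch

# Illustrative path only: in this container, model_dir is typically /opt/ml/model
# and DEFAULT_MODEL_FILENAME is the expected TorchScript artifact inside it.
model_path = "/opt/ml/model/model.pt"

# A TorchScript archive saved on a GPU machine records CUDA device locations
# for its tensors. Without map_location, torch.jit.load would try to restore
# them onto cuda:0 and fail with a RuntimeError on a CPU-only host.
model = torch.jit.load(model_path, map_location=torch.device("cpu"))

# Inference is then invoked with CPU tensors; with Elastic Inference attached,
# the accelerator performs the heavy computation server-side.
example_input = torch.rand(1, 3, 224, 224)  # assumed input shape, for illustration
with torch.no_grad():
    output = model(example_input)
```

Passing `map_location` is a no-op when the archive was saved on CPU, so the change is safe for both cases.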