Skip to content

Commit b963649

Browse files
author
Theiss Heilker
committed
Possible 1.5.1 incompatibility fix for elastic inference. See here https://docs.aws.amazon.com/elastic-inference/latest/developerguide/ei-pytorch-using.html
1 parent 6610a41 commit b963649

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
 import os
 import textwrap

-import torch
+import torch, torcheia
 from sagemaker_inference import (
     content_types,
     decoder,
@@ -47,7 +47,9 @@ def default_model_fn(self, model_dir):
             raise FileNotFoundError("Failed to load model with default model_fn: missing file {}."
                                     .format(DEFAULT_MODEL_FILENAME))
         # Client-framework is CPU only. But model will run in Elastic Inference server with CUDA.
-        return torch.jit.load(model_path, map_location=torch.device('cpu'))
+        model=torch.jit.load(model_path, map_location=torch.device('cpu'))
+        # Attached EIA to model as required for PyTorch 1.5.1
+        return torcheia.jit.attach_eia(model, 0)
     else:
         raise NotImplementedError(textwrap.dedent("""
             Please provide a model_fn implementation.

0 commit comments

Comments
 (0)