
Commit 4894d50

change: Enable default model fn for cpu and gpu (#107)
1 parent f498c2f commit 4894d50

File tree

12 files changed: +386 −10 lines

buildspec.yml (+1 −1)

@@ -39,7 +39,7 @@ phases:
     - GENERIC_TAG="$FRAMEWORK_VERSION-pytorch-$BUILD_ID"
     - DLC_CPU_TAG="$FRAMEWORK_VERSION-dlc-cpu-$BUILD_ID"
     - DLC_GPU_TAG="$FRAMEWORK_VERSION-dlc-gpu-$BUILD_ID"
-    - DLC_EIA_TAG="$FRAMEWORK_VERSION-dlc-eia-$BUILD_ID"
+    - DLC_EIA_TAG="$EIA_FRAMEWORK_VERSION-dlc-eia-$BUILD_ID"

     # run local CPU integration tests (build and push the image to ECR repo)
     - test_cmd="IGNORE_COVERAGE=- tox -e py36 -- test/integration/local --build-image --push-image --dockerfile-type pytorch --region $AWS_DEFAULT_REGION --docker-base-name $ECR_REPO --aws-id $ACCOUNT --framework-version $FRAMEWORK_VERSION --processor cpu --tag $GENERIC_TAG"

src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py (+35 −6)

@@ -13,7 +13,6 @@
 from __future__ import absolute_import

 import os
-import textwrap

 import torch
 from sagemaker_inference import (
@@ -29,9 +28,21 @@
 DEFAULT_MODEL_FILENAME = "model.pt"


+class ModelLoadError(Exception):
+    pass
+
+
 class DefaultPytorchInferenceHandler(default_inference_handler.DefaultInferenceHandler):
     VALID_CONTENT_TYPES = (content_types.JSON, content_types.NPY)

+    @staticmethod
+    def _is_model_file(filename):
+        is_model_file = False
+        if os.path.isfile(filename):
+            _, ext = os.path.splitext(filename)
+            is_model_file = ext in [".pt", ".pth"]
+        return is_model_file
+
     def default_model_fn(self, model_dir):
         """Loads a model. For PyTorch, a default function to load a model only if Elastic Inference is used.
         In other cases, users should provide customized model_fn() in script.
@@ -47,12 +58,30 @@ def default_model_fn(self, model_dir):
                 raise FileNotFoundError("Failed to load model with default model_fn: missing file {}."
                                         .format(DEFAULT_MODEL_FILENAME))
             # Client-framework is CPU only. But model will run in Elastic Inference server with CUDA.
-            return torch.jit.load(model_path, map_location=torch.device('cpu'))
+            try:
+                return torch.jit.load(model_path, map_location=torch.device('cpu'))
+            except RuntimeError as e:
+                raise ModelLoadError(
+                    "Failed to load {}. Please ensure model is saved using torchscript.".format(model_path)
+                ) from e
         else:
-            raise NotImplementedError(textwrap.dedent("""
-            Please provide a model_fn implementation.
-            See documentation for model_fn at https://github.com/aws/sagemaker-python-sdk
-            """))
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            model_path = os.path.join(model_dir, DEFAULT_MODEL_FILENAME)
+            if not os.path.exists(model_path):
+                model_files = [file for file in os.listdir(model_dir) if self._is_model_file(file)]
+                if len(model_files) != 1:
+                    raise ValueError(
+                        "Exactly one .pth or .pt file is required for PyTorch models: {}".format(model_files)
+                    )
+                model_path = os.path.join(model_dir, model_files[0])
+            try:
+                model = torch.jit.load(model_path, map_location=device)
+            except RuntimeError as e:
+                raise ModelLoadError(
+                    "Failed to load {}. Please ensure model is saved using torchscript.".format(model_path)
+                ) from e
+            model = model.to(device)
+            return model

     def default_input_fn(self, input_data, content_type):
         """A default input_fn that can handle JSON, CSV and NPZ formats.

test/integration/__init__.py (+26)

@@ -18,12 +18,16 @@
 resources_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'resources'))
 mnist_path = os.path.join(resources_path, 'mnist')
+resnet18_path = os.path.join(resources_path, 'resnet18')
 data_dir = os.path.join(mnist_path, 'data')
 training_dir = os.path.join(data_dir, 'training')
 cpu_sub_dir = 'model_cpu'
 gpu_sub_dir = 'model_gpu'
 eia_sub_dir = 'model_eia'
 code_sub_dir = 'code'
+default_sub_dir = 'default_model'
+default_sub_eia_dir = 'default_model_eia'
+default_sub_traced_resnet_dir = 'default_traced_resnet'

 model_cpu_dir = os.path.join(mnist_path, cpu_sub_dir)
 mnist_cpu_script = os.path.join(model_cpu_dir, code_sub_dir, 'mnist.py')
@@ -59,6 +63,28 @@
                                           "model_call_model_fn_once.tar.gz",
                                           script_path="code")

+default_model_dir = os.path.join(resnet18_path, default_sub_dir)
+default_model_script = os.path.join(default_model_dir, code_sub_dir, "resnet18.py")
+default_model_tar = file_utils.make_tarfile(
+    default_model_script, os.path.join(default_model_dir, "model.pt"), default_model_dir, script_path="code"
+)
+
+default_traced_resnet_dir = os.path.join(resnet18_path, default_sub_traced_resnet_dir)
+default_traced_resnet_script = os.path.join(default_traced_resnet_dir, code_sub_dir, "resnet18.py")
+default_model_traced_resnet18_tar = file_utils.make_tarfile(
+    default_traced_resnet_script,
+    os.path.join(default_traced_resnet_dir, "traced_resnet18.pt"),
+    default_traced_resnet_dir,
+    filename="traced_resnet18.tar.gz",
+    script_path="code",
+)
+
+default_model_eia_dir = os.path.join(mnist_path, default_sub_eia_dir)
+default_model_eia_script = os.path.join(default_model_eia_dir, code_sub_dir, "mnist.py")
+default_model_eia_tar = file_utils.make_tarfile(
+    default_model_eia_script, os.path.join(default_model_eia_dir, "model.pt"), default_model_eia_dir
+)
+
 ROLE = 'dummy/unused-role'
 DEFAULT_TIMEOUT = 20
 PYTHON3 = 'py3'
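Each of these calls packages a TorchScript artifact plus an entry-point script into a model archive. The helper's implementation is not part of this diff, so the sketch below is an assumption based on the standard SageMaker layout (model at the archive root, script under code/ when script_path="code"):

    # Sketch of what file_utils.make_tarfile presumably produces; the helper's
    # internals are not shown in this commit, so this layout is an assumption:
    #
    #   model.tar.gz
    #   ├── model.pt          # TorchScript artifact at the archive root
    #   └── code/
    #       └── resnet18.py   # entry-point script when script_path="code"
    import os
    import tarfile

    def make_tarfile_sketch(script, model, output_dir, filename="model.tar.gz", script_path=None):
        output_path = os.path.join(output_dir, filename)
        with tarfile.open(output_path, "w:gz") as tar:
            # model artifact goes at the root of the archive
            tar.add(model, arcname=os.path.basename(model))
            # entry-point script goes under script_path, if given
            script_arcname = os.path.basename(script)
            if script_path:
                script_arcname = os.path.join(script_path, script_arcname)
            tar.add(script, arcname=script_arcname)
        return output_path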
New file (+130 lines) — SageMaker integration tests for the default model_fn:

# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import json

import numpy as np
import pytest
import requests
import sagemaker
from sagemaker.predictor import RealTimePredictor
from sagemaker.pytorch import PyTorchModel, PyTorchPredictor

from integration import (
    default_model_script,
    default_model_tar,
    default_traced_resnet_script,
    default_model_traced_resnet18_tar,
    default_model_eia_script,
    default_model_eia_tar,
)
from integration.sagemaker.timeout import timeout_and_delete_endpoint


@pytest.mark.cpu_test
def test_default_inference_cpu(sagemaker_session, image_uri, instance_type):
    instance_type = instance_type or "ml.c4.xlarge"
    # Scripted model is serialized with torch.jit.save().
    # Default inference test doesn't need to instantiate model definition.
    _test_default_inference(
        sagemaker_session, image_uri, instance_type, default_model_tar, default_model_script
    )


@pytest.mark.gpu_test
def test_default_inference_gpu(sagemaker_session, image_uri, instance_type):
    instance_type = instance_type or "ml.p2.xlarge"
    # Scripted model is serialized with torch.jit.save().
    # Default inference test doesn't need to instantiate model definition.
    _test_default_inference(
        sagemaker_session, image_uri, instance_type, default_model_tar, default_model_script
    )


@pytest.mark.skip(
    reason="Latest EIA version - 1.5.1 uses mms. Enable when EIA images use torchserve"
)
@pytest.mark.eia_test
def test_default_inference_eia(sagemaker_session, image_uri, instance_type, accelerator_type):
    instance_type = instance_type or "ml.c4.xlarge"
    # Scripted model is serialized with torch.jit.save().
    # Default inference test doesn't need to instantiate model definition.
    _test_default_inference(
        sagemaker_session,
        image_uri,
        instance_type,
        default_model_eia_tar,
        default_model_eia_script,
        accelerator_type=accelerator_type,
    )


@pytest.mark.gpu_test
def test_default_inference_any_model_name_gpu(sagemaker_session, image_uri, instance_type):
    instance_type = instance_type or "ml.p2.xlarge"
    # Scripted model is serialized with torch.jit.save().
    # Default inference test doesn't need to instantiate model definition.
    _test_default_inference(
        sagemaker_session,
        image_uri,
        instance_type,
        default_model_traced_resnet18_tar,
        default_traced_resnet_script,
    )


def _test_default_inference(
    sagemaker_session, image_uri, instance_type, model_tar, mnist_script, accelerator_type=None
):
    endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-pytorch-serving")

    model_data = sagemaker_session.upload_data(
        path=model_tar,
        key_prefix="sagemaker-pytorch-serving/models",
    )

    pytorch = PyTorchModel(
        model_data=model_data,
        role="SageMakerRole",
        predictor_cls=RealTimePredictor if not accelerator_type else PyTorchPredictor,
        entry_point=mnist_script,
        image=image_uri,
        sagemaker_session=sagemaker_session,
    )
    with timeout_and_delete_endpoint(endpoint_name, sagemaker_session, minutes=30):
        predictor = pytorch.deploy(
            initial_instance_count=1,
            instance_type=instance_type,
            accelerator_type=accelerator_type,
            endpoint_name=endpoint_name,
        )

        if accelerator_type:
            batch_size = 100
            data = np.random.rand(batch_size, 1, 28, 28).astype(np.float32)
            output = predictor.predict(data)
            assert output.shape == (batch_size, 10)
        else:
            image_url = (
                "https://raw.githubusercontent.com/aws/amazon-sagemaker-examples/master/"
                "sagemaker_neo_compilation_jobs/pytorch_torchvision/cat.jpg"
            )
            img_data = requests.get(image_url).content
            with open("cat.jpg", "wb") as file_obj:
                file_obj.write(img_data)
            with open("cat.jpg", "rb") as f:
                payload = f.read()
                payload = bytearray(payload)
            response = predictor.predict(payload)
            result = json.loads(response.decode())
            assert len(result) == 1000

test/integration/sagemaker/test_mnist.py (+1 −1)

@@ -34,7 +34,7 @@ def test_mnist_gpu(sagemaker_session, image_uri, instance_type):
     _test_mnist_distributed(sagemaker_session, image_uri, instance_type, model_gpu_tar, mnist_gpu_script)


-@pytest.mark.skip(reason="Latest EIA version is too old - 1.3.1. Remove this after a new DLC release")
+@pytest.mark.skip(reason="Latest EIA version - 1.5.1 uses mms. Enable when EIA images use torchserve")
 @pytest.mark.eia_test
 def test_mnist_eia(sagemaker_session, image_uri, instance_type, accelerator_type):
     instance_type = instance_type or 'ml.c4.xlarge'
New file (+35 lines) — the EIA entry-point script (mnist.py) referenced above, with a predict_fn for Elastic Inference:

# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import logging
import os
import sys

import torch

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))


def predict_fn(input_data, model):
    logger.info('Performing EIA inference with Torch JIT context with input of size {}'.format(input_data.shape))
    # With EI, client instance should be CPU for cost-efficiency.
    # Sub-graphs with unsupported arguments run locally. Server runs with CUDA.
    device = torch.device('cpu')
    model = model.to(device)
    input_data = input_data.to(device)
    with torch.no_grad():
        # Set the target device to the accelerator ordinal
        with torch.jit.optimized_execution(True, {'target_device': 'eia:0'}):
            return model(input_data)
Binary file not shown.
New file (+51 lines) — the resnet18.py entry-point script referenced above, with a transform_fn:

import io
import json
import logging

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image  # Training container doesn't have this package

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def transform_fn(model, payload, request_content_type, response_content_type):

    logger.info("Invoking user-defined transform function")

    if request_content_type and request_content_type != "application/octet-stream":
        raise RuntimeError(
            "Content type must be application/octet-stream. Provided: {0}".format(
                request_content_type
            )
        )

    # preprocess
    decoded = Image.open(io.BytesIO(payload))
    preprocess = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    normalized = preprocess(decoded)
    batchified = normalized.unsqueeze(0)

    # predict
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batchified = batchified.to(device)
    result = model.forward(batchified)

    # Softmax (assumes batch size 1)
    result = np.squeeze(result.cpu().detach().numpy())
    result_exp = np.exp(result - np.max(result))
    result = result_exp / np.sum(result_exp)

    response_body = json.dumps(result.tolist())
    content_type = "application/json"

    return response_body, content_type
Binary file (44.7 MB) not shown.
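For a local sanity check, the transform_fn above can be exercised outside the container. A minimal sketch, assuming a traced resnet18 saved as model.pt and a cat.jpg in the working directory (neither file is part of this commit):

    # Sketch: drive transform_fn locally with a TorchScript resnet18.
    # model.pt and cat.jpg are assumed to exist; they are not in this commit.
    import json

    import torch

    from resnet18 import transform_fn  # the script above, importable locally

    model = torch.jit.load("model.pt", map_location=torch.device("cpu"))

    with open("cat.jpg", "rb") as f:
        payload = f.read()

    body, content_type = transform_fn(model, payload, "application/octet-stream", "application/json")
    probabilities = json.loads(body)
    assert len(probabilities) == 1000  # softmax over the ImageNet classes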
