Skip to content

Commit 29e1665

Browse files
author
Dan
authored
Merge branch 'master' into fix/typo-create_monitoring_schedule
2 parents 52101c6 + 2a0cf1b commit 29e1665

File tree

9 files changed

+1358
-9
lines changed

9 files changed

+1358
-9
lines changed

src/sagemaker/image_uri_config/neo-pytorch.json

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,73 @@
22
"processors": ["cpu", "gpu"],
33
"scope": ["inference"],
44
"version_aliases": {
5-
"0.4.0": "1.4.0",
6-
"1.0.0": "1.4.0",
7-
"1.1.0": "1.4.0",
8-
"1.2.0": "1.4.0",
9-
"1.3.0": "1.4.0"
5+
"0.4.0": "1.4",
6+
"1.0.0": "1.4",
7+
"1.1.0": "1.4",
8+
"1.2.0": "1.4",
9+
"1.3.0": "1.4",
10+
"1.4.0": "1.4"
1011
},
1112
"versions": {
12-
"1.4.0": {
13+
"1.4": {
14+
"py_versions": ["py3"],
15+
"registries": {
16+
"af-south-1": "774647643957",
17+
"ap-east-1": "110948597952",
18+
"ap-northeast-1": "941853720454",
19+
"ap-northeast-2": "151534178276",
20+
"ap-south-1": "763008648453",
21+
"ap-southeast-1": "324986816169",
22+
"ap-southeast-2": "355873309152",
23+
"ca-central-1": "464438896020",
24+
"cn-north-1": "472730292857",
25+
"cn-northwest-1": "474822919863",
26+
"eu-central-1": "746233611703",
27+
"eu-north-1": "601324751636",
28+
"eu-south-1": "966458181534",
29+
"eu-west-1": "802834080501",
30+
"eu-west-2": "205493899709",
31+
"eu-west-3": "254080097072",
32+
"me-south-1": "836785723513",
33+
"sa-east-1": "756306329178",
34+
"us-east-1": "785573368785",
35+
"us-east-2": "007439368137",
36+
"us-gov-west-1": "263933020539",
37+
"us-west-1": "710691900526",
38+
"us-west-2": "301217895009"
39+
},
40+
"repository": "sagemaker-inference-pytorch"
41+
},
42+
"1.5": {
43+
"py_versions": ["py3"],
44+
"registries": {
45+
"af-south-1": "774647643957",
46+
"ap-east-1": "110948597952",
47+
"ap-northeast-1": "941853720454",
48+
"ap-northeast-2": "151534178276",
49+
"ap-south-1": "763008648453",
50+
"ap-southeast-1": "324986816169",
51+
"ap-southeast-2": "355873309152",
52+
"ca-central-1": "464438896020",
53+
"cn-north-1": "472730292857",
54+
"cn-northwest-1": "474822919863",
55+
"eu-central-1": "746233611703",
56+
"eu-north-1": "601324751636",
57+
"eu-south-1": "966458181534",
58+
"eu-west-1": "802834080501",
59+
"eu-west-2": "205493899709",
60+
"eu-west-3": "254080097072",
61+
"me-south-1": "836785723513",
62+
"sa-east-1": "756306329178",
63+
"us-east-1": "785573368785",
64+
"us-east-2": "007439368137",
65+
"us-gov-west-1": "263933020539",
66+
"us-west-1": "710691900526",
67+
"us-west-2": "301217895009"
68+
},
69+
"repository": "sagemaker-inference-pytorch"
70+
},
71+
"1.6": {
1372
"py_versions": ["py3"],
1473
"registries": {
1574
"af-south-1": "774647643957",

src/sagemaker/model.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import json
1717
import logging
1818
import os
19+
import re
1920

2021
import sagemaker
2122
from sagemaker import (
@@ -398,6 +399,7 @@ def _compilation_job_config(
398399
target_platform_arch=None,
399400
target_platform_accelerator=None,
400401
compiler_options=None,
402+
framework_version=None,
401403
):
402404
"""Placeholder Docstring"""
403405
input_model_config = {
@@ -407,6 +409,14 @@ def _compilation_job_config(
407409
else input_shape,
408410
"Framework": framework.upper(),
409411
}
412+
413+
if (
414+
framework.lower() == "pytorch"
415+
and re.match("(?=^ml_)(?!ml_inf)", target_instance_type) is not None
416+
and framework_version is not None
417+
):
418+
input_model_config["FrameworkVersion"] = utils.get_short_version(framework_version)
419+
410420
role = self.sagemaker_session.expand_role(role)
411421
output_model_config = {
412422
"S3OutputLocation": output_path,
@@ -572,7 +582,8 @@ def compile(
572582
framework (str): The framework that is used to train the original
573583
model. Allowed values: 'mxnet', 'tensorflow', 'keras', 'pytorch',
574584
'onnx', 'xgboost'
575-
framework_version (str):
585+
framework_version (str): The version of framework, for example:
586+
'1.5' for PyTorch
576587
target_platform_os (str): Target Platform OS, for example: 'LINUX'.
577588
For allowed strings see
578589
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
@@ -626,11 +637,11 @@ def compile(
626637
target_platform_arch,
627638
target_platform_accelerator,
628639
compiler_options,
640+
framework_version,
629641
)
630642
self.sagemaker_session.compile_model(**config)
631643
job_status = self.sagemaker_session.wait_for_compilation_job(job_name)
632644
self.model_data = job_status["ModelArtifacts"]["S3ModelArtifacts"]
633-
634645
if target_instance_family is not None:
635646
if target_instance_family.startswith("ml_"):
636647
self.image_uri = self._compilation_image_uri(

tests/conftest.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,26 @@ def pytorch_eia_py_version():
182182
return "py3"
183183

184184

185+
@pytest.fixture(scope="module")
186+
def neo_pytorch_latest_py_version():
187+
return "py3"
188+
189+
190+
@pytest.fixture(scope="module")
191+
def neo_pytorch_compilation_job_name():
192+
return utils.name_from_base("pytorch-neo-model")
193+
194+
195+
@pytest.fixture(scope="module")
196+
def neo_pytorch_target_device():
197+
return "ml_c5"
198+
199+
200+
@pytest.fixture(scope="module")
201+
def neo_pytorch_cpu_instance_type():
202+
return "ml.c5.xlarge"
203+
204+
185205
@pytest.fixture(scope="module")
186206
def xgboost_framework_version(xgboost_version):
187207
if xgboost_version in ("1", "latest"):

tests/data/pytorch_neo/cat.jpg

114 KB
Loading
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Copyright 2018-2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
import io
14+
import json
15+
import logging
16+
import os
17+
import pickle
18+
19+
import numpy as np
20+
import torch
21+
import neopytorch
22+
import torchvision.transforms as transforms
23+
from PIL import Image # Training container doesn't have this package
24+
25+
logger = logging.getLogger(__name__)
26+
logger.setLevel(logging.DEBUG)
27+
28+
29+
def transform_fn(model, payload, request_content_type, response_content_type):
30+
31+
logger.info("Invoking user-defined transform function")
32+
33+
if request_content_type != "application/octet-stream":
34+
raise RuntimeError(
35+
"Content type must be application/octet-stream. Provided: {0}".format(
36+
request_content_type
37+
)
38+
)
39+
40+
# preprocess image
41+
decoded = Image.open(io.BytesIO(payload))
42+
preprocess = transforms.Compose(
43+
[
44+
transforms.Resize(256),
45+
transforms.CenterCrop(224),
46+
transforms.ToTensor(),
47+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
48+
]
49+
)
50+
normalized = preprocess(decoded)
51+
batchified = normalized.unsqueeze(0)
52+
53+
# predict
54+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
55+
batchified = batchified.to(device)
56+
result = model.forward(batchified)
57+
58+
# Softmax (assumes batch size 1)
59+
result = np.squeeze(result.cpu().numpy())
60+
result_exp = np.exp(result - np.max(result))
61+
result = result_exp / np.sum(result_exp)
62+
63+
response_body = json.dumps(result.tolist())
64+
content_type = "application/json"
65+
66+
return response_body, content_type
67+
68+
69+
def model_fn(model_dir):
70+
71+
logger.info("model_fn")
72+
neopytorch.config(model_dir=model_dir, neo_runtime=True)
73+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
74+
# The compiled model is saved as "compiled.pt"
75+
model = torch.jit.load(os.path.join(model_dir, "compiled.pt"), map_location=device)
76+
77+
# It is recommended to run warm-up inference during model load
78+
sample_input_path = os.path.join(model_dir, "sample_input.pkl")
79+
with open(sample_input_path, "rb") as input_file:
80+
model_input = pickle.load(input_file)
81+
if torch.is_tensor(model_input):
82+
model_input = model_input.to(device)
83+
model(model_input)
84+
elif isinstance(model_input, tuple):
85+
model_input = (inp.to(device) for inp in model_input if torch.is_tensor(inp))
86+
model(*model_input)
87+
else:
88+
print("Only supports a torch tensor or a tuple of torch tensors")
89+
90+
return model

0 commit comments

Comments
 (0)