Skip to content

Commit a32893b

Browse files
author
Deng
committed
feature: pytorch 1.3.1 eia
1 parent 2199e50 commit a32893b

File tree

7 files changed

+93
-3
lines changed

7 files changed

+93
-3
lines changed

src/sagemaker/fw_utils.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,14 @@
5353
)
5454

5555
VALID_PY_VERSIONS = ["py2", "py3"]
56-
VALID_EIA_FRAMEWORKS = ["tensorflow", "tensorflow-serving", "mxnet", "mxnet-serving"]
56+
VALID_EIA_FRAMEWORKS = [
57+
"tensorflow",
58+
"tensorflow-serving",
59+
"mxnet",
60+
"mxnet-serving",
61+
"pytorch",
62+
"pytorch-serving",
63+
]
5764
VALID_ACCOUNTS_BY_REGION = {"us-gov-west-1": "246785580436", "us-iso-east-1": "744548109606"}
5865
ASIMOV_VALID_ACCOUNTS_BY_REGION = {"us-iso-east-1": "886529160074"}
5966
OPT_IN_ACCOUNTS_BY_REGION = {"ap-east-1": "057415533634", "me-south-1": "724002660598"}
@@ -71,6 +78,7 @@
7178
"mxnet-serving-eia": "mxnet-inference-eia",
7279
"pytorch": "pytorch-training",
7380
"pytorch-serving": "pytorch-inference",
81+
"pytorch-serving-eia": "pytorch-inference-eia",
7482
}
7583

7684
MERGED_FRAMEWORKS_LOWEST_VERSIONS = {
@@ -82,6 +90,7 @@
8290
"mxnet-serving-eia": [1, 4, 1],
8391
"pytorch": [1, 2, 0],
8492
"pytorch-serving": [1, 2, 0],
93+
"pytorch-serving-eia": {"py3": [1, 3, 1]},
8594
}
8695

8796
DEBUGGER_UNSUPPORTED_REGIONS = ["us-gov-west-1", "us-iso-east-1"]
@@ -117,6 +126,8 @@ def _is_dlc_version(framework, framework_version, py_version):
117126
"""
118127
lowest_version_list = MERGED_FRAMEWORKS_LOWEST_VERSIONS.get(framework)
119128
if isinstance(lowest_version_list, dict):
129+
if py_version not in lowest_version_list:
130+
raise ValueError("{} is not supported in {}.".format(framework, py_version))
120131
lowest_version_list = lowest_version_list[py_version]
121132

122133
if lowest_version_list:

src/sagemaker/pytorch/README.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ With PyTorch Estimators and Models, you can train and host PyTorch models on Ama
66

77
Supported versions of PyTorch: ``0.4.0``, ``1.0.0``, ``1.1.0``, ``1.2.0``, ``1.3.1``, ``1.4.0``.
88

9+
Supported versions of TensorFlow for Elastic Inference: ``1.3.1``.
10+
911
We recommend that you use the latest supported version, because that's where we focus most of our development efforts.
1012

1113
You can visit the PyTorch repository at https://github.com/pytorch/pytorch.

tests/data/pytorch_eia/mnist.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
# This file is intentionally left blank to invoke default model_fn and predict_fn
129 KB
Binary file not shown.

tests/integ/test_pytorch_train.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@
2727
MNIST_DIR = os.path.join(DATA_DIR, "pytorch_mnist")
2828
MNIST_SCRIPT = os.path.join(MNIST_DIR, "mnist.py")
2929

30+
EIA_DIR = os.path.join(DATA_DIR, "pytorch_eia")
31+
EIA_MODEL = os.path.join(EIA_DIR, "model_mnist.tar.gz")
32+
EIA_SCRIPT = os.path.join(EIA_DIR, "mnist.py")
33+
3034

3135
@pytest.fixture(scope="module", name="pytorch_training_job")
3236
def fixture_training_job(sagemaker_session, pytorch_full_version, cpu_instance_type):
@@ -115,6 +119,35 @@ def test_deploy_model(pytorch_training_job, sagemaker_session, cpu_instance_type
115119
assert output.shape == (batch_size, 10)
116120

117121

122+
@pytest.mark.skipif(
123+
PYTHON_VERSION == "py2",
124+
reason="PyTorch EIA does not support Python 2.",
125+
)
126+
def test_deploy_model_with_accelerator(sagemaker_session, cpu_instance_type):
127+
endpoint_name = "test-pytorch-deploy-eia-{}".format(sagemaker_timestamp())
128+
model_data = sagemaker_session.upload_data(path=EIA_MODEL)
129+
pytorch = PyTorchModel(
130+
model_data,
131+
"SageMakerRole",
132+
framework_version="1.3.1",
133+
entry_point=EIA_SCRIPT,
134+
sagemaker_session=sagemaker_session,
135+
)
136+
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
137+
predictor = pytorch.deploy(
138+
initial_instance_count=1,
139+
instance_type=cpu_instance_type,
140+
accelerator_type="ml.eia1.large",
141+
endpoint_name=endpoint_name,
142+
)
143+
144+
batch_size = 100
145+
data = numpy.random.rand(batch_size, 1, 28, 28).astype(numpy.float32)
146+
output = predictor.predict(data)
147+
148+
assert output.shape == (batch_size, 10)
149+
150+
118151
def _upload_training_data(pytorch):
119152
return pytorch.sagemaker_session.upload_data(
120153
path=os.path.join(MNIST_DIR, "training"),

tests/unit/test_fw_utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,37 @@ def test_mxnet_eia_images():
311311
)
312312

313313

314+
def test_pytorch_eia_images():
315+
image_uri = fw_utils.create_image_uri(
316+
"us-east-1",
317+
"pytorch-serving",
318+
"ml.c4.2xlarge",
319+
"1.3.1",
320+
"py3",
321+
accelerator_type="ml.eia1.large",
322+
)
323+
assert (
324+
image_uri
325+
== "{}.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference-eia:1.3.1-cpu-py3".format(
326+
fw_utils.ASIMOV_PROD_ACCOUNT
327+
)
328+
)
329+
330+
331+
def test_pytorch_eia_py2_error():
332+
error_message = "pytorch-serving-eia is not supported in py2."
333+
with pytest.raises(ValueError) as error:
334+
fw_utils.create_image_uri(
335+
"us-east-1",
336+
"pytorch-serving",
337+
"ml.c4.2xlarge",
338+
"1.3.1",
339+
"py2",
340+
accelerator_type="ml.eia1.large",
341+
)
342+
assert error_message in str(error)
343+
344+
314345
def test_create_image_uri_override_account():
315346
image_uri = fw_utils.create_image_uri(
316347
"us-west-1", MOCK_FRAMEWORK, "ml.p3.2xlarge", "1.0rc", "py3", account="fake"

tests/unit/test_pytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,8 +348,8 @@ def test_model_image_accelerator(sagemaker_session):
348348
model = PyTorchModel(
349349
MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, sagemaker_session=sagemaker_session
350350
)
351-
with pytest.raises(ValueError):
352-
model.prepare_container_def(INSTANCE_TYPE, accelerator_type=ACCELERATOR_TYPE)
351+
predictor = model.deploy(1, CPU)
352+
assert isinstance(predictor, PyTorchPredictor)
353353

354354

355355
def test_train_image_default(sagemaker_session):

0 commit comments

Comments
 (0)