Skip to content

Commit 4ef5386

Browse files
authored
feature: pytorch 1.3.1 eia support (#1328)
* feature: pytorch 1.3.1 eia
1 parent 7c31493 commit 4ef5386

File tree

10 files changed

+128
-11
lines changed

10 files changed

+128
-11
lines changed

README.rst

+4-2
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ TensorFlow SageMaker Estimators
176176

177177
By using TensorFlow SageMaker Estimators, you can train and host TensorFlow models on Amazon SageMaker.
178178

179-
Supported versions of TensorFlow: ``1.4.1``, ``1.5.0``, ``1.6.0``, ``1.7.0``, ``1.8.0``, ``1.9.0``, ``1.10.0``, ``1.11.0``, ``1.12.0``, ``1.13.1``, ``1.14.0``, ``1.15.0``, ``2.0.0``.
179+
Supported versions of TensorFlow: ``1.4.1``, ``1.5.0``, ``1.6.0``, ``1.7.0``, ``1.8.0``, ``1.9.0``, ``1.10.0``, ``1.11.0``, ``1.12.0``, ``1.13.1``, ``1.14.0``, ``1.15.0``, ``1.15.2``, ``2.0.0``, ``2.0.1``, ``2.1.0``.
180180

181181
Supported versions of TensorFlow for Elastic Inference: ``1.11.0``, ``1.12.0``, ``1.13.1``, ``1.14.0``.
182182

@@ -208,7 +208,9 @@ PyTorch SageMaker Estimators
208208

209209
With PyTorch SageMaker Estimators, you can train and host PyTorch models on Amazon SageMaker.
210210

211-
Supported versions of PyTorch: ``0.4.0``, ``1.0.0``, ``1.1.0``, ``1.2.0``, ``1.3.1``.
211+
Supported versions of PyTorch: ``0.4.0``, ``1.0.0``, ``1.1.0``, ``1.2.0``, ``1.3.1``, ``1.4.0``.
212+
213+
Supported versions of PyTorch for Elastic Inference: ``1.3.1``.
212214

213215
We recommend that you use the latest supported version, because that's where we focus most of our development efforts.
214216

doc/using_pytorch.rst

+15
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ With PyTorch Estimators and Models, you can train and host PyTorch models on Ama
66

77
Supported versions of PyTorch: ``0.4.0``, ``1.0.0``, ``1.1.0``, ``1.2.0``, ``1.3.1``.
88

9+
Supported versions of PyTorch for Elastic Inference: ``1.3.1``.
10+
911
We recommend that you use the latest supported version, because that's where we focus most of our development efforts.
1012

1113
You can visit the PyTorch repository at https://github.com/pytorch/pytorch.
@@ -250,6 +252,14 @@ You use the SageMaker PyTorch model server to host your PyTorch model when you c
250252
Estimator. The model server runs inside a SageMaker Endpoint, which your call to ``deploy`` creates.
251253
You can access the name of the Endpoint by the ``name`` property on the returned ``Predictor``.
252254

255+
PyTorch on Amazon SageMaker has support for `Elastic Inference <https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html>`_, which allows for inference acceleration to a hosted endpoint for a fraction of the cost of using a full GPU instance.
256+
In order to attach an Elastic Inference accelerator to your endpoint provide the accelerator type to ``accelerator_type`` to your ``deploy`` call.
257+
258+
.. code:: python
259+
260+
predictor = pytorch_estimator.deploy(instance_type='ml.m4.xlarge',
261+
initial_instance_count=1,
262+
accelerator_type='ml.eia2.medium')
253263
254264
The SageMaker PyTorch Model Server
255265
==================================
@@ -291,6 +301,11 @@ It loads the model parameters from a ``model.pth`` file in the SageMaker model d
291301
model.load_state_dict(torch.load(f))
292302
return model
293303
304+
However, if you are using PyTorch Elastic Inference, you do not have to provide a ``model_fn`` since the PyTorch serving
305+
container has a default one for you. But please note that if you are utilizing the default ``model_fn``, please save
306+
yor parameter file as ``model.pt`` instead of ``model.pth``. For more information on inference script, please refer to:
307+
`SageMaker PyTorch Default Inference Handler <https://github.com/aws/sagemaker-pytorch-serving-container/blob/master/src/sagemaker_pytorch_serving_container/default_inference_handler.py>`_.
308+
294309
Serve a PyTorch Model
295310
---------------------
296311

src/sagemaker/fw_utils.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,14 @@
5353
)
5454

5555
VALID_PY_VERSIONS = ["py2", "py3"]
56-
VALID_EIA_FRAMEWORKS = ["tensorflow", "tensorflow-serving", "mxnet", "mxnet-serving"]
56+
VALID_EIA_FRAMEWORKS = [
57+
"tensorflow",
58+
"tensorflow-serving",
59+
"mxnet",
60+
"mxnet-serving",
61+
"pytorch-serving",
62+
]
63+
PY2_RESTRICTED_EIA_FRAMEWORKS = ["pytorch-serving"]
5764
VALID_ACCOUNTS_BY_REGION = {"us-gov-west-1": "246785580436", "us-iso-east-1": "744548109606"}
5865
ASIMOV_VALID_ACCOUNTS_BY_REGION = {"us-iso-east-1": "886529160074"}
5966
OPT_IN_ACCOUNTS_BY_REGION = {"ap-east-1": "057415533634", "me-south-1": "724002660598"}
@@ -71,6 +78,7 @@
7178
"mxnet-serving-eia": "mxnet-inference-eia",
7279
"pytorch": "pytorch-training",
7380
"pytorch-serving": "pytorch-inference",
81+
"pytorch-serving-eia": "pytorch-inference-eia",
7482
}
7583

7684
MERGED_FRAMEWORKS_LOWEST_VERSIONS = {
@@ -82,6 +90,7 @@
8290
"mxnet-serving-eia": [1, 4, 1],
8391
"pytorch": [1, 2, 0],
8492
"pytorch-serving": [1, 2, 0],
93+
"pytorch-serving-eia": [1, 3, 1],
8594
}
8695

8796
DEBUGGER_UNSUPPORTED_REGIONS = ["us-gov-west-1", "us-iso-east-1"]
@@ -207,6 +216,7 @@ def create_image_uri(
207216

208217
if _accelerator_type_valid_for_framework(
209218
framework=framework,
219+
py_version=py_version,
210220
accelerator_type=accelerator_type,
211221
optimized_families=optimized_families,
212222
):
@@ -259,21 +269,27 @@ def create_image_uri(
259269

260270

261271
def _accelerator_type_valid_for_framework(
262-
framework, accelerator_type=None, optimized_families=None
272+
framework, py_version, accelerator_type=None, optimized_families=None
263273
):
264274
"""
265275
Args:
266276
framework:
277+
py_version:
267278
accelerator_type:
268279
optimized_families:
269280
"""
270281
if accelerator_type is None:
271282
return False
272283

284+
if py_version == "py2" and framework in PY2_RESTRICTED_EIA_FRAMEWORKS:
285+
raise ValueError(
286+
"{} is not supported with Amazon Elastic Inference in Python 2.".format(framework)
287+
)
288+
273289
if framework not in VALID_EIA_FRAMEWORKS:
274290
raise ValueError(
275291
"{} is not supported with Amazon Elastic Inference. Currently only "
276-
"Python-based TensorFlow and MXNet are supported.".format(framework)
292+
"Python-based TensorFlow, MXNet, PyTorch are supported.".format(framework)
277293
)
278294

279295
if optimized_families:

src/sagemaker/pytorch/README.rst

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ With PyTorch Estimators and Models, you can train and host PyTorch models on Ama
66

77
Supported versions of PyTorch: ``0.4.0``, ``1.0.0``, ``1.1.0``, ``1.2.0``, ``1.3.1``, ``1.4.0``.
88

9+
Supported versions of PyTorch for Elastic Inference: ``1.3.1``.
10+
911
We recommend that you use the latest supported version, because that's where we focus most of our development efforts.
1012

1113
You can visit the PyTorch repository at https://github.com/pytorch/pytorch.

src/sagemaker/pytorch/model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
136136
For example, 'ml.p2.xlarge'.
137137
accelerator_type (str): The Elastic Inference accelerator type to
138138
deploy to the instance for loading and making inferences to the
139-
model. Currently unsupported with PyTorch.
139+
model.
140140
141141
Returns:
142142
dict[str, str]: A container definition object usable with the
@@ -169,7 +169,7 @@ def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
169169
(cpu/gpu/family-specific optimized).
170170
accelerator_type (str): The Elastic Inference accelerator type to
171171
deploy to the instance for loading and making inferences to the
172-
model. Currently unsupported with PyTorch.
172+
model.
173173
174174
Returns:
175175
str: The appropriate image URI based on the given parameters.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
# This file is intentionally left blank to invoke default model_fn and predict_fn
129 KB
Binary file not shown.

tests/integ/test_pytorch_train.py

+30
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@
2727
MNIST_DIR = os.path.join(DATA_DIR, "pytorch_mnist")
2828
MNIST_SCRIPT = os.path.join(MNIST_DIR, "mnist.py")
2929

30+
EIA_DIR = os.path.join(DATA_DIR, "pytorch_eia")
31+
EIA_MODEL = os.path.join(EIA_DIR, "model_mnist.tar.gz")
32+
EIA_SCRIPT = os.path.join(EIA_DIR, "empty_inference_script.py")
33+
3034

3135
@pytest.fixture(scope="module", name="pytorch_training_job")
3236
def fixture_training_job(sagemaker_session, pytorch_full_version, cpu_instance_type):
@@ -115,6 +119,32 @@ def test_deploy_model(pytorch_training_job, sagemaker_session, cpu_instance_type
115119
assert output.shape == (batch_size, 10)
116120

117121

122+
@pytest.mark.skipif(PYTHON_VERSION == "py2", reason="PyTorch EIA does not support Python 2.")
123+
def test_deploy_model_with_accelerator(sagemaker_session, cpu_instance_type):
124+
endpoint_name = "test-pytorch-deploy-eia-{}".format(sagemaker_timestamp())
125+
model_data = sagemaker_session.upload_data(path=EIA_MODEL)
126+
pytorch = PyTorchModel(
127+
model_data,
128+
"SageMakerRole",
129+
framework_version="1.3.1",
130+
entry_point=EIA_SCRIPT,
131+
sagemaker_session=sagemaker_session,
132+
)
133+
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
134+
predictor = pytorch.deploy(
135+
initial_instance_count=1,
136+
instance_type=cpu_instance_type,
137+
accelerator_type="ml.eia2.medium",
138+
endpoint_name=endpoint_name,
139+
)
140+
141+
batch_size = 100
142+
data = numpy.random.rand(batch_size, 1, 28, 28).astype(numpy.float32)
143+
output = predictor.predict(data)
144+
145+
assert output.shape == (batch_size, 10)
146+
147+
118148
def _upload_training_data(pytorch):
119149
return pytorch.sagemaker_session.upload_data(
120150
path=os.path.join(MNIST_DIR, "training"),

tests/unit/test_fw_utils.py

+31
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,37 @@ def test_mxnet_eia_images():
311311
)
312312

313313

314+
def test_pytorch_eia_images():
315+
image_uri = fw_utils.create_image_uri(
316+
"us-east-1",
317+
"pytorch-serving",
318+
"ml.c4.2xlarge",
319+
"1.3.1",
320+
"py3",
321+
accelerator_type="ml.eia1.large",
322+
)
323+
assert (
324+
image_uri
325+
== "{}.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference-eia:1.3.1-cpu-py3".format(
326+
fw_utils.ASIMOV_PROD_ACCOUNT
327+
)
328+
)
329+
330+
331+
def test_pytorch_eia_py2_error():
332+
error_message = "pytorch-serving is not supported with Amazon Elastic Inference in Python 2."
333+
with pytest.raises(ValueError) as error:
334+
fw_utils.create_image_uri(
335+
"us-east-1",
336+
"pytorch-serving",
337+
"ml.c4.2xlarge",
338+
"1.3.1",
339+
"py2",
340+
accelerator_type="ml.eia1.large",
341+
)
342+
assert error_message in str(error)
343+
344+
314345
def test_create_image_uri_override_account():
315346
image_uri = fw_utils.create_image_uri(
316347
"us-west-1", MOCK_FRAMEWORK, "ml.p3.2xlarge", "1.0rc", "py3", account="fake"

tests/unit/test_pytorch.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -345,11 +345,19 @@ def test_non_mms_model(repack_model, sagemaker_session):
345345

346346
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
347347
def test_model_image_accelerator(sagemaker_session):
348-
model = PyTorchModel(
349-
MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, sagemaker_session=sagemaker_session
348+
with pytest.raises(ValueError) as error:
349+
model = PyTorchModel(
350+
MODEL_DATA,
351+
role=ROLE,
352+
entry_point=SCRIPT_PATH,
353+
sagemaker_session=sagemaker_session,
354+
framework_version="1.3.1",
355+
py_version="py2",
356+
)
357+
model.deploy(1, CPU, accelerator_type=ACCELERATOR_TYPE)
358+
assert "pytorch-serving is not supported with Amazon Elastic Inference in Python 2." in str(
359+
error
350360
)
351-
with pytest.raises(ValueError):
352-
model.prepare_container_def(INSTANCE_TYPE, accelerator_type=ACCELERATOR_TYPE)
353361

354362

355363
def test_train_image_default(sagemaker_session):

0 commit comments

Comments
 (0)