Skip to content

Commit f14d86c

Browse files
author
Dan
authored
feature: Add support for PyTorch 1.2.0 (#1091)
1 parent 0e97997 commit f14d86c

File tree

8 files changed

+126
-9
lines changed

8 files changed

+126
-9
lines changed

README.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ PyTorch SageMaker Estimators
222222

223223
With PyTorch SageMaker Estimators, you can train and host PyTorch models on Amazon SageMaker.
224224

225-
Supported versions of PyTorch: ``0.4.0``, ``1.0.0``, ``1.1.0``.
225+
Supported versions of PyTorch: ``0.4.0``, ``1.0.0``, ``1.1.0``, ``1.2.0``.
226226

227227
We recommend that you use the latest supported version, because that's where we focus most of our development efforts.
228228

src/sagemaker/fw_utils.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@
6767
"tensorflow-serving-eia": "tensorflow-inference-eia",
6868
"mxnet": "mxnet-training",
6969
"mxnet-serving": "mxnet-inference",
70+
"pytorch": "pytorch-training",
71+
"pytorch-serving": "pytorch-inference",
7072
"mxnet-serving-eia": "mxnet-inference-eia",
7173
}
7274

@@ -76,6 +78,8 @@
7678
"tensorflow-serving-eia": [1, 14, 0],
7779
"mxnet": [1, 4, 1],
7880
"mxnet-serving": [1, 4, 1],
81+
"pytorch": [1, 2, 0],
82+
"pytorch-serving": [1, 2, 0],
7983
"mxnet-serving-eia": [1, 4, 1],
8084
}
8185

@@ -119,10 +123,15 @@ def _using_merged_images(region, framework, py_version, framework_version):
119123
is_gov_region = region in VALID_ACCOUNTS_BY_REGION
120124
is_py3 = py_version == "py3" or py_version is None
121125
is_merged_versions = _is_merged_versions(framework, framework_version)
126+
122127
return (
123128
((not is_gov_region) or region in ASIMOV_VALID_ACCOUNTS_BY_REGION)
124129
and is_merged_versions
125-
and (is_py3 or _is_tf_14_or_later(framework, framework_version))
130+
and (
131+
is_py3
132+
or _is_tf_14_or_later(framework, framework_version)
133+
or _is_pt_12_or_later(framework, framework_version)
134+
)
126135
)
127136

128137

@@ -140,6 +149,19 @@ def _is_tf_14_or_later(framework, framework_version):
140149
)
141150

142151

152+
def _is_pt_12_or_later(framework, framework_version):
153+
"""
154+
Args:
155+
framework: Name of the framework
156+
framework_version: framework version
157+
"""
158+
# Asimov team now owns PyTorch 1.2.0 py2 and py3
159+
asimov_lowest_pt = [1, 2, 0]
160+
version = [int(s) for s in framework_version.split(".")]
161+
is_pytorch = framework in ("pytorch", "pytorch-serving")
162+
return is_pytorch and version >= asimov_lowest_pt[0 : len(version)]
163+
164+
143165
def _registry_id(region, framework, py_version, account, framework_version):
144166
"""
145167
Args:

src/sagemaker/pytorch/README.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ SageMaker PyTorch Estimators and Models
44

55
With PyTorch Estimators and Models, you can train and host PyTorch models on Amazon SageMaker.
66

7-
Supported versions of PyTorch: ``0.4.0``, ``1.0.0``, ``1.1.0``.
7+
Supported versions of PyTorch: ``0.4.0``, ``1.0.0``, ``1.1.0``, ``1.2.0``.
88

99
We recommend that you use the latest supported version, because that's where we focus most of our development efforts.
1010

src/sagemaker/pytorch/estimator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class PyTorch(Framework):
3434

3535
__framework_name__ = "pytorch"
3636

37-
LATEST_VERSION = "1.1"
37+
LATEST_VERSION = "1.2.0"
3838
"""The latest version of PyTorch included in the SageMaker pre-built Docker images."""
3939

4040
def __init__(

src/sagemaker/pytorch/model.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
import logging
17+
import pkg_resources
1718

1819
import sagemaker
1920
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning
@@ -53,6 +54,7 @@ class PyTorchModel(FrameworkModel):
5354
"""
5455

5556
__framework_name__ = "pytorch"
57+
_LOWEST_MMS_VERSION = "1.2"
5658

5759
def __init__(
5860
self,
@@ -122,22 +124,33 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
122124
dict[str, str]: A container definition object usable with the
123125
CreateModel API.
124126
"""
127+
lowest_mms_version = pkg_resources.parse_version(self._LOWEST_MMS_VERSION)
128+
framework_version = pkg_resources.parse_version(self.framework_version)
129+
is_mms_version = framework_version >= lowest_mms_version
130+
125131
deploy_image = self.image
126132
if not deploy_image:
127133
region_name = self.sagemaker_session.boto_session.region_name
134+
135+
framework_name = self.__framework_name__
136+
if is_mms_version:
137+
framework_name += "-serving"
138+
128139
deploy_image = create_image_uri(
129140
region_name,
130-
self.__framework_name__,
141+
framework_name,
131142
instance_type,
132143
self.framework_version,
133144
self.py_version,
134145
accelerator_type=accelerator_type,
135146
)
136147
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
137-
self._upload_code(deploy_key_prefix)
148+
self._upload_code(deploy_key_prefix, repack=is_mms_version)
138149
deploy_env = dict(self.env)
139150
deploy_env.update(self._framework_env_vars())
140151

141152
if self.model_server_workers:
142153
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
143-
return sagemaker.container_def(deploy_image, self.model_data, deploy_env)
154+
return sagemaker.container_def(
155+
deploy_image, self.repacked_model_data or self.model_data, deploy_env
156+
)

tests/integ/test_pytorch_train.py

+25
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,31 @@ def test_sync_fit_deploy(pytorch_training_job, sagemaker_session, cpu_instance_t
5454
assert output.shape == (batch_size, 10)
5555

5656

57+
@pytest.mark.local_mode
58+
def test_fit_deploy(sagemaker_local_session, pytorch_full_version):
59+
pytorch = PyTorch(
60+
entry_point=MNIST_SCRIPT,
61+
role="SageMakerRole",
62+
framework_version=pytorch_full_version,
63+
py_version="py3",
64+
train_instance_count=1,
65+
train_instance_type="local",
66+
sagemaker_session=sagemaker_local_session,
67+
)
68+
69+
pytorch.fit({"training": "file://" + os.path.join(MNIST_DIR, "training")})
70+
71+
predictor = pytorch.deploy(1, "local")
72+
try:
73+
batch_size = 100
74+
data = numpy.random.rand(batch_size, 1, 28, 28).astype(numpy.float32)
75+
output = predictor.predict(data)
76+
77+
assert output.shape == (batch_size, 10)
78+
finally:
79+
predictor.delete_endpoint()
80+
81+
5782
def test_deploy_model(pytorch_training_job, sagemaker_session, cpu_instance_type):
5883
endpoint_name = "test-pytorch-deploy-model-{}".format(sagemaker_timestamp())
5984

tests/unit/test_fw_utils.py

+22
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,28 @@ def test_create_image_uri_merged_gov_regions():
313313
)
314314

315315

316+
def test_create_image_uri_merged_pytorch():
317+
318+
image_uri = fw_utils.create_image_uri("us-west-2", "pytorch", "ml.p3.2xlarge", "1.2", "py2")
319+
assert image_uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.2-gpu-py2"
320+
321+
image_uri = fw_utils.create_image_uri("us-west-2", "pytorch", "ml.p3.2xlarge", "1.1", "py2")
322+
assert image_uri == "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-pytorch:1.1-gpu-py2"
323+
324+
image_uri = fw_utils.create_image_uri(
325+
"us-west-2", "pytorch-serving", "ml.c4.2xlarge", "1.2", "py2"
326+
)
327+
assert image_uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.2-cpu-py2"
328+
329+
image_uri = fw_utils.create_image_uri(
330+
"us-west-2", "pytorch-serving", "ml.c4.2xlarge", "1.1", "py2"
331+
)
332+
assert (
333+
image_uri
334+
== "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-pytorch-serving:1.1-cpu-py2"
335+
)
336+
337+
316338
def test_create_image_uri_accelerator_tf():
317339
image_uri = fw_utils.create_image_uri(
318340
MOCK_REGION, "tensorflow", "ml.p3.2xlarge", "1.0", "py3", accelerator_type="ml.eia1.medium"

tests/unit/test_pytorch.py

+37-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717
import os
1818
import pytest
1919
import sys
20-
from mock import MagicMock, Mock
21-
from mock import patch
20+
from mock import ANY, MagicMock, Mock, patch
2221

2322
from sagemaker.pytorch import defaults
2423
from sagemaker.pytorch import PyTorch
@@ -296,6 +295,42 @@ def test_model(sagemaker_session):
296295
assert isinstance(predictor, PyTorchPredictor)
297296

298297

298+
@patch("sagemaker.utils.create_tar_file", MagicMock())
299+
@patch("sagemaker.utils.repack_model")
300+
def test_mms_model(repack_model, sagemaker_session):
301+
PyTorchModel(
302+
MODEL_DATA,
303+
role=ROLE,
304+
entry_point=SCRIPT_PATH,
305+
sagemaker_session=sagemaker_session,
306+
framework_version="1.2",
307+
).deploy(1, GPU)
308+
309+
repack_model.assert_called_with(
310+
dependencies=[],
311+
inference_script=SCRIPT_PATH,
312+
kms_key=None,
313+
model_uri="s3://some/data.tar.gz",
314+
repacked_model_uri=ANY,
315+
sagemaker_session=sagemaker_session,
316+
source_directory=None,
317+
)
318+
319+
320+
@patch("sagemaker.utils.create_tar_file", MagicMock())
321+
@patch("sagemaker.utils.repack_model")
322+
def test_non_mms_model(repack_model, sagemaker_session):
323+
PyTorchModel(
324+
MODEL_DATA,
325+
role=ROLE,
326+
entry_point=SCRIPT_PATH,
327+
sagemaker_session=sagemaker_session,
328+
framework_version="1.1",
329+
).deploy(1, GPU)
330+
331+
repack_model.assert_not_called()
332+
333+
299334
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
300335
def test_model_image_accelerator(sagemaker_session):
301336
model = PyTorchModel(

0 commit comments

Comments
 (0)