Skip to content

Commit bc0f5eb

Browse files
authored
Merge pull request #1 from aws/pytorch
Add Pytorch estimator and model
2 parents 9362bba + ac28b7c commit bc0f5eb

File tree

14 files changed

+810
-2
lines changed

14 files changed

+810
-2
lines changed

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def read(fname):
1111

1212

1313
setup(name="sagemaker",
14-
version="1.2.3",
14+
version="1.3.dev",
1515
description="Open source library for training and deploying models on Amazon SageMaker.",
1616
packages=find_packages('src'),
1717
package_dir={'': 'src'},

src/sagemaker/fw_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def framework_name_from_image(image_name):
145145
else:
146146
# extract framework, python version and image tag
147147
# We must support both the legacy and current image name format.
148-
name_pattern = re.compile('^sagemaker-(tensorflow|mxnet):(.*?)-(.*?)-(py2|py3)$')
148+
name_pattern = re.compile('^sagemaker-(tensorflow|mxnet|pytorch):(.*?)-(.*?)-(py2|py3)$')
149149
legacy_name_pattern = re.compile('^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$')
150150
name_match = name_pattern.match(sagemaker_match.group(8))
151151
legacy_match = legacy_name_pattern.match(sagemaker_match.group(8))

src/sagemaker/pytorch/__init__.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from sagemaker.pytorch.estimator import PyTorch
14+
from sagemaker.pytorch.model import PyTorchModel, PyTorchPredictor
15+
16+
# ``__all__`` must contain the exported names as strings; listing the objects themselves makes
# ``from sagemaker.pytorch import *`` fail with a TypeError.
__all__ = ['PyTorch', 'PyTorchModel', 'PyTorchPredictor']

src/sagemaker/pytorch/defaults.py

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
# Default PyTorch framework version, used when a caller does not pass ``framework_version``.
PYTORCH_VERSION = '0.3'
# Default Python version, used when a caller does not pass ``py_version``.
PYTHON_VERSION = 'py3'

src/sagemaker/pytorch/estimator.py

+111
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from sagemaker.estimator import Framework
14+
from sagemaker.fw_utils import create_image_uri, framework_name_from_image, framework_version_from_tag
15+
from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION
16+
from sagemaker.pytorch.model import PyTorchModel
17+
18+
19+
class PyTorch(Framework):
    """Handle end-to-end training and deployment of custom PyTorch code."""

    __framework_name__ = "pytorch"

    def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_version=PYTHON_VERSION,
                 framework_version=PYTORCH_VERSION, **kwargs):
        """
        This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment, within a SageMaker
        Training Job. The managed PyTorch environment is an Amazon-built Docker container that executes functions
        defined in the supplied ``entry_point`` Python script.

        Training is started by calling :meth:`~sagemaker.estimator.Framework.fit` on this Estimator.
        After training is complete, calling :meth:`~sagemaker.estimator.Framework.deploy` creates a
        hosted SageMaker endpoint and returns a :class:`~sagemaker.pytorch.model.PyTorchPredictor` instance
        that can be used to perform inference against the hosted model.

        Technical documentation on preparing PyTorch scripts for SageMaker training and using the PyTorch Estimator is
        available on the project home-page: https://github.com/aws/sagemaker-python-sdk

        Args:
            entry_point (str): Path (absolute or relative) to the Python source file which should be executed
                as the entry point to training. This should be compatible with either Python 2.7 or Python 3.5.
            source_dir (str): Path (absolute or relative) to a directory with any other training
                source code dependencies aside from the entry point file (default: None). Structure within this
                directory is preserved when training on Amazon SageMaker.
            hyperparameters (dict): Hyperparameters that will be used for training (default: None).
                The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker.
                For convenience, this accepts other types for keys and values, but ``str()`` will be called
                to convert them before training.
            py_version (str): Python version you want to use for executing your model training code (default: 'py3').
                One of 'py2' or 'py3'.
            framework_version (str): PyTorch version you want to use for executing your model training code.
                List of supported versions https://github.com/aws/sagemaker-python-sdk#pytorch-sagemaker-estimators
            **kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework` constructor.
        """
        super(PyTorch, self).__init__(entry_point, source_dir, hyperparameters, **kwargs)
        self.py_version = py_version
        self.framework_version = framework_version

    def train_image(self):
        """Return the Docker image to use for training.

        The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does the model training, calls this method to
        find the image to use for model training.

        Returns:
            str: The URI of the Docker image.
        """
        return create_image_uri(self.sagemaker_session.boto_session.region_name, self.__framework_name__,
                                self.train_instance_type, framework_version=self.framework_version,
                                py_version=self.py_version)

    def create_model(self, model_server_workers=None):
        """Create a SageMaker ``PyTorchModel`` object that can be deployed to an ``Endpoint``.

        Args:
            model_server_workers (int): Optional. The number of worker processes used by the inference server.
                If None, server will use one worker per vCPU.

        Returns:
            sagemaker.pytorch.model.PyTorchModel: A SageMaker ``PyTorchModel`` object.
                See :func:`~sagemaker.pytorch.model.PyTorchModel` for full details.
        """
        return PyTorchModel(self.model_data, self.role, self.entry_point, source_dir=self.source_dir,
                            enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
                            container_log_level=self.container_log_level, code_location=self.code_location,
                            py_version=self.py_version, framework_version=self.framework_version,
                            model_server_workers=model_server_workers, sagemaker_session=self.sagemaker_session)

    @classmethod
    def _prepare_init_params_from_job_description(cls, job_details):
        """Convert the job description to init params that can be handled by the class constructor.

        Args:
            job_details: the returned job details from a describe_training_job API call.

        Returns:
            dictionary: The transformed init_params, with the ``image`` entry replaced by
                ``py_version`` and ``framework_version`` entries.

        Raises:
            ValueError: If the image recorded for the training job does not belong to this framework.
        """
        init_params = super(PyTorch, cls)._prepare_init_params_from_job_description(job_details)
        framework, py_version, tag = framework_name_from_image(init_params.pop('image'))

        # Reject a foreign image before its tag is parsed, so the caller gets the descriptive error below.
        if framework != cls.__framework_name__:
            training_job_name = init_params['base_job_name']
            raise ValueError("Training job: {} didn't use image for requested framework".format(training_job_name))

        init_params['py_version'] = py_version
        init_params['framework_version'] = framework_version_from_tag(tag)

        return init_params

src/sagemaker/pytorch/model.py

+93
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
import sagemaker
14+
from sagemaker.fw_utils import create_image_uri
15+
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
16+
from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION
17+
from sagemaker.predictor import RealTimePredictor, json_serializer, json_deserializer
18+
from sagemaker.utils import name_from_image
19+
20+
21+
class PyTorchPredictor(RealTimePredictor):
    """A RealTimePredictor for inference against PyTorch Endpoints.

    This is able to serialize Python lists, dictionaries, and numpy arrays to multidimensional tensors for PyTorch
    inference."""

    def __init__(self, endpoint_name, sagemaker_session=None):
        """Initialize a ``PyTorchPredictor``.

        Requests and responses are handled with ``json_serializer`` and ``json_deserializer``.

        Args:
            endpoint_name (str): The name of the endpoint to perform inference on.
            sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
                Amazon SageMaker APIs and any other AWS services needed. If not specified, one is created
                using the default AWS configuration chain.
        """
        super(PyTorchPredictor, self).__init__(endpoint_name, sagemaker_session, json_serializer, json_deserializer)
37+
38+
39+
class PyTorchModel(FrameworkModel):
    """A PyTorch SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""

    __framework_name__ = 'pytorch'

    def __init__(self, model_data, role, entry_point, image=None, py_version=PYTHON_VERSION,
                 framework_version=PYTORCH_VERSION, predictor_cls=PyTorchPredictor,
                 model_server_workers=None, **kwargs):
        """Initialize a PyTorchModel.

        Args:
            model_data (str): The S3 location of a SageMaker model data ``.tar.gz`` file.
            role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs
                that create Amazon SageMaker endpoints use this role to access training data and model artifacts.
                After the endpoint is created, the inference code might use the IAM role,
                if it needs to access an AWS resource.
            entry_point (str): Path (absolute or relative) to the Python source file which should be executed
                as the entry point to model hosting. This should be compatible with either Python 2.7 or Python 3.5.
            image (str): A Docker image URI (default: None). If not specified, a default image for PyTorch will be used.
            py_version (str): Python version you want to use for executing your model hosting code (default: 'py3').
            framework_version (str): PyTorch version you want to use for executing your model hosting code.
            predictor_cls (callable[str, sagemaker.session.Session]): A function to call to create a predictor
                with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the result of
                invoking this function on the created endpoint name.
            model_server_workers (int): Optional. The number of worker processes used by the inference server.
                If None, server will use one worker per vCPU.
            **kwargs: Keyword arguments passed to the ``FrameworkModel`` initializer.
        """
        super(PyTorchModel, self).__init__(model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs)
        self.py_version = py_version
        self.framework_version = framework_version
        self.model_server_workers = model_server_workers

    def prepare_container_def(self, instance_type):
        """Return a container definition with framework configuration set in model environment variables.

        Args:
            instance_type (str): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'.

        Returns:
            dict[str, str]: A container definition object usable with the CreateModel API.
        """
        deploy_image = self.image
        if not deploy_image:
            # No explicit image was given: derive the default PyTorch image for this region and instance type.
            region_name = self.sagemaker_session.boto_session.region_name
            deploy_image = create_image_uri(region_name, self.__framework_name__, instance_type,
                                            self.framework_version, self.py_version)
        deploy_key_prefix = self.key_prefix or self.name or name_from_image(deploy_image)
        self._upload_code(deploy_key_prefix)
        # Copy the user environment so the framework variables do not mutate ``self.env``.
        deploy_env = dict(self.env)
        deploy_env.update(self._framework_env_vars())

        if self.model_server_workers:
            deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
        return sagemaker.container_def(deploy_image, self.model_data, deploy_env)

tests/conftest.py

+10
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,11 @@ def mxnet_version(request):
6666
return request.param
6767

6868

69+
@pytest.fixture(scope='module', params=['0.3', '0.3.1'])
def pytorch_version(request):
    """Module-scoped fixture parametrized over the PyTorch version strings listed in ``params``."""
    return request.param
72+
73+
6974
@pytest.fixture(scope='module', params=['1.4.1', '1.5.0', '1.6.0'])
def tf_full_version(request):
    """Module-scoped fixture parametrized over the full TensorFlow version strings listed in ``params``."""
    return request.param
@@ -74,3 +79,8 @@ def tf_full_version(request):
7479
@pytest.fixture(scope='module', params=['0.12.1', '1.0.0', '1.1.0'])
def mxnet_full_version(request):
    """Module-scoped fixture parametrized over the full MXNet version strings listed in ``params``."""
    return request.param
82+
83+
84+
@pytest.fixture(scope='module', params=['0.3.1'])
def pytorch_full_version(request):
    """Module-scoped fixture parametrized over the full PyTorch version strings listed in ``params``."""
    return request.param
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# TODO(nadiaya): Remove the arguments when they are no longer required
2+
def train(host_rank, master_addr, master_port):
    """For use with integration tests expecting failures.

    The arguments are accepted but never used; the function always raises.

    Raises:
        Exception: Always, with the message 'This failure is expected.'
    """
    raise Exception('This failure is expected.')

0 commit comments

Comments
 (0)