Skip to content

Commit b3c8bb1

Browse files
verdimrcajaykarpurathewseyPanigrahiahsan-z-khan
authored
feature: processors that support multiple Python files, requirements.txt, and dependencies. (#2251)
Co-authored-by: Ajay Karpur <[email protected]> Co-authored-by: Alex Thewsey <[email protected]> Co-authored-by: Panigrahi <[email protected]> Co-authored-by: Ahsan Khan <[email protected]> Co-authored-by: Shreya Pandit <[email protected]>
1 parent cb8bc65 commit b3c8bb1

31 files changed

+2102
-195
lines changed

src/sagemaker/huggingface/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@
1515

1616
from sagemaker.huggingface.estimator import HuggingFace # noqa: F401
1717
from sagemaker.huggingface.model import HuggingFaceModel, HuggingFacePredictor # noqa: F401
18+
from sagemaker.huggingface.processing import HuggingFaceProcessor # noqa:F401
+132
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains code related to HuggingFace Processors which are used for Processing jobs.
14+
15+
These jobs let customers perform data pre-processing, post-processing, feature engineering,
16+
data validation, and model evaluation and interpretation on SageMaker.
17+
"""
18+
from __future__ import absolute_import
19+
20+
from sagemaker.processing import FrameworkProcessor
21+
from sagemaker.huggingface.estimator import HuggingFace
22+
23+
24+
class HuggingFaceProcessor(FrameworkProcessor):
25+
"""Handles Amazon SageMaker processing tasks for jobs using HuggingFace containers."""
26+
27+
estimator_cls = HuggingFace
28+
29+
def __init__(
30+
self,
31+
role,
32+
instance_count,
33+
instance_type,
34+
transformers_version=None,
35+
tensorflow_version=None,
36+
pytorch_version=None,
37+
py_version="py36",
38+
image_uri=None,
39+
command=None,
40+
volume_size_in_gb=30,
41+
volume_kms_key=None,
42+
output_kms_key=None,
43+
code_location=None,
44+
max_runtime_in_seconds=None,
45+
base_job_name=None,
46+
sagemaker_session=None,
47+
env=None,
48+
tags=None,
49+
network_config=None,
50+
):
51+
"""This processor executes a Python script in a HuggingFace execution environment.
52+
53+
Unless ``image_uri`` is specified, the environment is an Amazon-built Docker container
54+
that executes functions defined in the supplied ``code`` Python script.
55+
56+
The arguments have the same meaning as in ``FrameworkProcessor``, with the following
57+
exceptions.
58+
59+
Args:
60+
transformers_version (str): Transformers version you want to use for
61+
executing your model training code. Defaults to ``None``. Required unless
62+
``image_uri`` is provided. The current supported version is ``4.4.2``.
63+
tensorflow_version (str): TensorFlow version you want to use for
64+
executing your model training code. Defaults to ``None``. Required unless
65+
``pytorch_version`` is provided. The current supported version is ``1.6.0``.
66+
pytorch_version (str): PyTorch version you want to use for
67+
executing your model training code. Defaults to ``None``. Required unless
68+
``tensorflow_version`` is provided. The current supported version is ``2.4.1``.
69+
py_version (str): Python version you want to use for executing your model training
70+
code. Defaults to ``None``. Required unless ``image_uri`` is provided. If
71+
using PyTorch, the current supported version is ``py36``. If using TensorFlow,
72+
the current supported version is ``py37``.
73+
74+
.. tip::
75+
76+
You can find additional parameters for initializing this class at
77+
:class:`~sagemaker.processing.FrameworkProcessor`.
78+
"""
79+
self.pytorch_version = pytorch_version
80+
self.tensorflow_version = tensorflow_version
81+
super().__init__(
82+
self.estimator_cls,
83+
transformers_version,
84+
role,
85+
instance_count,
86+
instance_type,
87+
py_version,
88+
image_uri,
89+
command,
90+
volume_size_in_gb,
91+
volume_kms_key,
92+
output_kms_key,
93+
code_location,
94+
max_runtime_in_seconds,
95+
base_job_name,
96+
sagemaker_session,
97+
env,
98+
tags,
99+
network_config,
100+
)
101+
102+
def _create_estimator(
103+
self,
104+
entry_point="",
105+
source_dir=None,
106+
dependencies=None,
107+
git_config=None,
108+
):
109+
"""Override default estimator factory function for HuggingFace's different parameters
110+
111+
HuggingFace estimators have 3 framework version parameters instead of one: The version for
112+
Transformers, PyTorch, and TensorFlow.
113+
"""
114+
return self.estimator_cls(
115+
transformers_version=self.framework_version,
116+
tensorflow_version=self.tensorflow_version,
117+
pytorch_version=self.pytorch_version,
118+
py_version=self.py_version,
119+
entry_point=entry_point,
120+
source_dir=source_dir,
121+
dependencies=dependencies,
122+
git_config=git_config,
123+
code_location=self.code_location,
124+
enable_network_isolation=False,
125+
image_uri=self.image_uri,
126+
role=self.role,
127+
instance_count=self.instance_count,
128+
instance_type=self.instance_type,
129+
sagemaker_session=self.sagemaker_session,
130+
debugger_hook_config=False,
131+
disable_profiler=True,
132+
)

src/sagemaker/local/local_session.py

+24-2
Original file line numberDiff line numberDiff line change
@@ -478,10 +478,30 @@ def invoke_endpoint(
478478

479479

480480
class LocalSession(Session):
481-
"""A LocalSession class definition."""
481+
"""A SageMaker ``Session`` class for Local Mode.
482482
483-
def __init__(self, boto_session=None, s3_endpoint_url=None):
483+
This class provides alternative Local Mode implementations for the functionality of
484+
:class:`~sagemaker.session.Session`.
485+
"""
486+
487+
def __init__(self, boto_session=None, s3_endpoint_url=None, disable_local_code=False):
488+
"""Create a Local SageMaker Session.
489+
490+
Args:
491+
boto_session (boto3.session.Session): The underlying Boto3 session which AWS service
492+
calls are delegated to (default: None). If not provided, one is created with
493+
default AWS configuration chain.
494+
s3_endpoint_url (str): Override the default endpoint URL for Amazon S3, if set
495+
(default: None).
496+
disable_local_code (bool): Set ``True`` to override the default AWS configuration
497+
chain to disable the ``local.local_code`` setting, which may not be supported for
498+
some SDK features (default: False).
499+
"""
484500
self.s3_endpoint_url = s3_endpoint_url
501+
# We use this local variable to avoid disrupting the __init__->_initialize API of the
502+
# parent class... But overwriting it after constructor won't do anything, so prefix _ to
503+
# discourage external use:
504+
self._disable_local_code = disable_local_code
485505

486506
super(LocalSession, self).__init__(boto_session)
487507

@@ -533,6 +553,8 @@ def _initialize(
533553
raise e
534554

535555
self.config = yaml.load(open(sagemaker_config_file, "r"))
556+
if self._disable_local_code and "local" in self.config:
557+
self.config["local"]["local_code"] = False
536558

537559
def logs_for_job(self, job_name, wait=False, poll=5, log_type="All"):
538560
"""A no-op method meant to override the sagemaker client.

src/sagemaker/mxnet/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Placeholder docstring"""
13+
"""Classes for using MXNet with Amazon SageMaker."""
1414
from __future__ import absolute_import # noqa: F401
1515

1616
from sagemaker.mxnet.estimator import MXNet # noqa: F401
1717
from sagemaker.mxnet.model import MXNetModel, MXNetPredictor # noqa: F401
18+
from sagemaker.mxnet.processing import MXNetProcessor # noqa: F401

src/sagemaker/mxnet/processing.py

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains code related to MXNet Processors which are used for Processing jobs.
14+
15+
These jobs let customers perform data pre-processing, post-processing, feature engineering,
16+
data validation, and model evaluation and interpretation on SageMaker.
17+
"""
18+
from __future__ import absolute_import
19+
20+
from sagemaker.mxnet.estimator import MXNet
21+
from sagemaker.processing import FrameworkProcessor
22+
23+
24+
class MXNetProcessor(FrameworkProcessor):
25+
"""Handles Amazon SageMaker processing tasks for jobs using MXNet containers."""
26+
27+
estimator_cls = MXNet
28+
29+
def __init__(
30+
self,
31+
framework_version, # New arg
32+
role,
33+
instance_count,
34+
instance_type,
35+
py_version="py3", # New kwarg
36+
image_uri=None,
37+
command=None,
38+
volume_size_in_gb=30,
39+
volume_kms_key=None,
40+
output_kms_key=None,
41+
code_location=None, # New arg
42+
max_runtime_in_seconds=None,
43+
base_job_name=None,
44+
sagemaker_session=None,
45+
env=None,
46+
tags=None,
47+
network_config=None,
48+
):
49+
"""This processor executes a Python script in a managed MXNet execution environment.
50+
51+
Unless ``image_uri`` is specified, the MXNet environment is an
52+
Amazon-built Docker container that executes functions defined in the supplied
53+
``code`` Python script.
54+
55+
The arguments have the exact same meaning as in ``FrameworkProcessor``.
56+
57+
.. tip::
58+
59+
You can find additional parameters for initializing this class at
60+
:class:`~sagemaker.processing.FrameworkProcessor`.
61+
"""
62+
super().__init__(
63+
self.estimator_cls,
64+
framework_version,
65+
role,
66+
instance_count,
67+
instance_type,
68+
py_version,
69+
image_uri,
70+
command,
71+
volume_size_in_gb,
72+
volume_kms_key,
73+
output_kms_key,
74+
code_location,
75+
max_runtime_in_seconds,
76+
base_job_name,
77+
sagemaker_session,
78+
env,
79+
tags,
80+
network_config,
81+
)

0 commit comments

Comments
 (0)