Skip to content

Commit 4681b90

Browse files
authored
Merge branch 'zwei' into require-framework-version-xgboost
2 parents 7bc6392 + 09336f7 commit 4681b90

14 files changed

+265
-255
lines changed

doc/frameworks/mxnet/using_mxnet.rst

+7-6
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,8 @@ The following code sample shows how you train a custom MXNet script "train.py".
183183
mxnet_estimator = MXNet('train.py',
184184
train_instance_type='ml.p2.xlarge',
185185
train_instance_count=1,
186-
framework_version='1.3.0',
186+
framework_version='1.6.0',
187+
py_version='py3',
187188
hyperparameters={'batch-size': 100,
188189
'epochs': 10,
189190
'learning-rate': 0.1})
@@ -230,10 +231,10 @@ If you use the ``MXNet`` estimator to train the model, you can call ``deploy`` t
230231
231232
# Train my estimator
232233
mxnet_estimator = MXNet('train.py',
233-
train_instance_type='ml.p2.xlarge',
234-
train_instance_count=1,
234+
framework_version='1.6.0',
235235
py_version='py3',
236-
framework_version='1.6.0')
236+
train_instance_type='ml.p2.xlarge',
237+
train_instance_count=1)
237238
mxnet_estimator.fit('s3://my_bucket/my_training_data/')
238239
239240
# Deploy my estimator to an Amazon SageMaker Endpoint and get a Predictor
@@ -247,8 +248,8 @@ If using a pretrained model, create an ``MXNetModel`` object, and then call ``de
247248
mxnet_model = MXNetModel(model_data='s3://my_bucket/pretrained_model/model.tar.gz',
248249
role=role,
249250
entry_point='inference.py',
250-
py_version='py3',
251-
framework_version='1.6.0')
251+
framework_version='1.6.0',
252+
py_version='py3')
252253
predictor = mxnet_model.deploy(instance_type='ml.m4.xlarge',
253254
initial_instance_count=1)
254255

src/sagemaker/cli/mxnet.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.cli.common import HostCommand, TrainCommand
17+
from sagemaker.mxnet import defaults
1718

1819

1920
def train(args):
@@ -40,13 +41,14 @@ def create_estimator(self):
4041
from sagemaker.mxnet.estimator import MXNet
4142

4243
return MXNet(
43-
self.script,
44+
entry_point=self.script,
45+
framework_version=defaults.MXNET_VERSION,
46+
py_version=self.python,
4447
role=self.role_name,
4548
base_job_name=self.job_name,
4649
train_instance_count=self.instance_count,
4750
train_instance_type=self.instance_type,
4851
hyperparameters=self.hyperparameters,
49-
py_version=self.python,
5052
)
5153

5254

@@ -64,6 +66,7 @@ def create_model(self, model_url):
6466
model_data=model_url,
6567
role=self.role_name,
6668
entry_point=self.script,
69+
framework_version=defaults.MXNET_VERSION,
6770
py_version=self.python,
6871
name=self.endpoint_name,
6972
env=self.environment,

src/sagemaker/fw_utils.py

+21
Original file line numberDiff line numberDiff line change
@@ -681,3 +681,24 @@ def _region_supports_debugger(region_name):
681681
682682
"""
683683
return region_name.lower() not in DEBUGGER_UNSUPPORTED_REGIONS
684+
685+
686+
def validate_version_or_image_args(framework_version, py_version, image_name):
687+
"""Checks if version or image arguments are specified.
688+
689+
Validates framework and model arguments to enforce version or image specification.
690+
691+
Args:
692+
framework_version (str): The version of the framework.
693+
py_version (str): The version of Python.
694+
image_name (str): The URI of the image.
695+
696+
Raises:
697+
ValueError: if `image_name` is None and either `framework_version` or `py_version` is
698+
None.
699+
"""
700+
if (framework_version is None or py_version is None) and image_name is None:
701+
raise ValueError(
702+
"framework_version or py_version was None, yet image_name was also None. "
703+
"Either specify both framework_version and py_version, or specify image_name."
704+
)

src/sagemaker/mxnet/estimator.py

+41-33
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919
from sagemaker.fw_utils import (
2020
framework_name_from_image,
2121
framework_version_from_tag,
22-
empty_framework_version_warning,
22+
is_version_equal_or_higher,
2323
python_deprecation_warning,
2424
parameter_v2_rename_warning,
25-
is_version_equal_or_higher,
25+
validate_version_or_image_args,
2626
warn_if_parameter_server_with_multi_gpu,
2727
)
2828
from sagemaker.mxnet import defaults
@@ -43,10 +43,10 @@ class MXNet(Framework):
4343
def __init__(
4444
self,
4545
entry_point,
46+
framework_version=None,
47+
py_version=None,
4648
source_dir=None,
4749
hyperparameters=None,
48-
py_version="py2",
49-
framework_version=None,
5050
image_name=None,
5151
distributions=None,
5252
**kwargs
@@ -73,6 +73,13 @@ def __init__(
7373
file which should be executed as the entry point to training.
7474
If ``source_dir`` is specified, then ``entry_point``
7575
must point to a file located at the root of ``source_dir``.
76+
framework_version (str): MXNet version you want to use for executing
77+
your model training code. Defaults to ``None``. Required unless
78+
``image_name`` is provided. List of supported versions:
79+
https://github.com/aws/sagemaker-python-sdk#mxnet-sagemaker-estimators.
80+
py_version (str): Python version you want to use for executing your
81+
model training code. One of 'py2' or 'py3'. Defaults to ``None``. Required
82+
unless ``image_name`` is provided.
7683
source_dir (str): Path (absolute, relative or an S3 URI) to a directory
7784
with any other training source code dependencies aside from the entry
7885
point file (default: None). If ``source_dir`` is an S3 URI, it must
@@ -84,12 +91,6 @@ def __init__(
8491
SageMaker. For convenience, this accepts other types for keys
8592
and values, but ``str()`` will be called to convert them before
8693
training.
87-
py_version (str): Python version you want to use for executing your
88-
model training code (default: 'py2'). One of 'py2' or 'py3'.
89-
framework_version (str): MXNet version you want to use for executing
90-
your model training code. List of supported versions
91-
https://github.com/aws/sagemaker-python-sdk#mxnet-sagemaker-estimators.
92-
If not specified, this will default to 1.2.1.
9394
image_name (str): If specified, the estimator will use this image for training and
9495
hosting, instead of selecting the appropriate SageMaker official image based on
9596
framework_version and py_version. It can be an ECR url or dockerhub image and tag.
@@ -98,6 +99,9 @@ def __init__(
9899
* ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
99100
* ``custom-image:latest``
100101
102+
If ``framework_version`` or ``py_version`` is ``None``, then
103+
``image_name`` is required. If ``image_name`` is also ``None``, then a
104+
``ValueError`` will be raised.
101105
distributions (dict): A dictionary with information on how to run distributed
102106
training (default: None). To have parameter servers launched for training,
103107
set this value to be ``{'parameter_server': {'enabled': True}}``.
@@ -110,34 +114,32 @@ def __init__(
110114
:class:`~sagemaker.estimator.Framework` and
111115
:class:`~sagemaker.estimator.EstimatorBase`.
112116
"""
113-
if framework_version is None:
117+
validate_version_or_image_args(framework_version, py_version, image_name)
118+
if py_version and py_version == "py2":
114119
logger.warning(
115-
empty_framework_version_warning(defaults.MXNET_VERSION, self.LATEST_VERSION)
120+
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
116121
)
117-
self.framework_version = framework_version or defaults.MXNET_VERSION
122+
self.framework_version = framework_version
123+
self.py_version = py_version
118124

119125
if "enable_sagemaker_metrics" not in kwargs:
120126
# enable sagemaker metrics for MXNet v1.6 or greater:
121-
if is_version_equal_or_higher([1, 6], self.framework_version):
127+
if self.framework_version and is_version_equal_or_higher(
128+
[1, 6], self.framework_version
129+
):
122130
kwargs["enable_sagemaker_metrics"] = True
123131

124132
super(MXNet, self).__init__(
125133
entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs
126134
)
127135

128-
if py_version == "py2":
129-
logger.warning(
130-
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
131-
)
132-
133136
if distributions is not None:
134137
logger.warning(parameter_v2_rename_warning("distributions", "distribution"))
135138
train_instance_type = kwargs.get("train_instance_type")
136139
warn_if_parameter_server_with_multi_gpu(
137140
training_instance_type=train_instance_type, distributions=distributions
138141
)
139142

140-
self.py_version = py_version
141143
self._configure_distribution(distributions)
142144

143145
def _configure_distribution(self, distributions):
@@ -148,7 +150,10 @@ def _configure_distribution(self, distributions):
148150
if distributions is None:
149151
return
150152

151-
if self.framework_version.split(".") < self._LOWEST_SCRIPT_MODE_VERSION:
153+
if (
154+
self.framework_version
155+
and self.framework_version.split(".") < self._LOWEST_SCRIPT_MODE_VERSION
156+
):
152157
raise ValueError(
153158
"The distributions option is valid for only versions {} and higher".format(
154159
".".join(self._LOWEST_SCRIPT_MODE_VERSION)
@@ -221,12 +226,12 @@ def create_model(
221226
self.model_data,
222227
role or self.role,
223228
entry_point or self.entry_point,
229+
framework_version=self.framework_version,
230+
py_version=self.py_version,
224231
source_dir=(source_dir or self._model_source_dir()),
225232
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
226233
container_log_level=self.container_log_level,
227234
code_location=self.code_location,
228-
py_version=self.py_version,
229-
framework_version=self.framework_version,
230235
model_server_workers=model_server_workers,
231236
sagemaker_session=self.sagemaker_session,
232237
vpc_config=self.get_vpc_config(vpc_config_override),
@@ -254,22 +259,25 @@ class constructor
254259
image_name = init_params.pop("image")
255260
framework, py_version, tag, _ = framework_name_from_image(image_name)
256261

262+
# The image tagging scheme changed from a plain image version (e.g. '1.0') to a more
263+
# expressive one containing framework version, device type and Python version
264+
# (e.g. '0.12-gpu-py2'). For backward compatibility, map the deprecated tag '1.0' to
265+
# framework version '0.12'; otherwise extract the framework version from the tag itself.
266+
if tag is None:
267+
framework_version = None
268+
elif tag == "1.0":
269+
framework_version = "0.12"
270+
else:
271+
framework_version = framework_version_from_tag(tag)
272+
init_params["framework_version"] = framework_version
273+
init_params["py_version"] = py_version
274+
257275
if not framework:
258276
# If we were unable to parse the framework name from the image it is not one of our
259277
# officially supported images, in this case just add the image to the init params.
260278
init_params["image_name"] = image_name
261279
return init_params
262280

263-
init_params["py_version"] = py_version
264-
265-
# We switched image tagging scheme from regular image version (e.g. '1.0') to more
266-
# expressive containing framework version, device type and python version
267-
# (e.g. '0.12-gpu-py2'). For backward compatibility map deprecated image tag '1.0' to a
268-
# '0.12' framework version otherwise extract framework version from the tag itself.
269-
init_params["framework_version"] = (
270-
"0.12" if tag == "1.0" else framework_version_from_tag(tag)
271-
)
272-
273281
training_job_name = init_params["base_job_name"]
274282

275283
if framework != cls.__framework_name__:

src/sagemaker/mxnet/model.py

+20-18
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
create_image_uri,
2323
model_code_key_prefix,
2424
python_deprecation_warning,
25-
empty_framework_version_warning,
25+
validate_version_or_image_args,
2626
)
2727
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2828
from sagemaker.mxnet import defaults
@@ -65,9 +65,9 @@ def __init__(
6565
model_data,
6666
role,
6767
entry_point,
68-
image=None,
69-
py_version="py2",
7068
framework_version=None,
69+
py_version=None,
70+
image=None,
7171
predictor_cls=MXNetPredictor,
7272
model_server_workers=None,
7373
**kwargs
@@ -86,12 +86,18 @@ def __init__(
8686
file which should be executed as the entry point to model
8787
hosting. If ``source_dir`` is specified, then ``entry_point``
8888
must point to a file located at the root of ``source_dir``.
89+
framework_version (str): MXNet version you want to use for executing
90+
your model hosting code. Defaults to ``None``. Required unless
91+
``image`` is provided.
92+
py_version (str): Python version you want to use for executing your
93+
model hosting code. Defaults to ``None``. Required unless
94+
``image`` is provided.
8995
image (str): A Docker image URI (default: None). If not specified, a
9096
default image for MXNet will be used.
91-
py_version (str): Python version you want to use for executing your
92-
model training code (default: 'py2').
93-
framework_version (str): MXNet version you want to use for executing
94-
your model training code.
97+
98+
If ``framework_version`` or ``py_version`` is ``None``, then
99+
``image`` is required. If ``image`` is also ``None``, then a
100+
``ValueError`` will be raised.
95101
predictor_cls (callable[str, sagemaker.session.Session]): A function
96102
to call to create a predictor with an endpoint name and
97103
SageMaker ``Session``. If specified, ``deploy()`` returns the
@@ -108,22 +114,18 @@ def __init__(
108114
:class:`~sagemaker.model.FrameworkModel` and
109115
:class:`~sagemaker.model.Model`.
110116
"""
111-
super(MXNetModel, self).__init__(
112-
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
113-
)
114-
115-
if py_version == "py2":
117+
validate_version_or_image_args(framework_version, py_version, image)
118+
if py_version and py_version == "py2":
116119
logger.warning(
117120
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
118121
)
122+
self.framework_version = framework_version
123+
self.py_version = py_version
119124

120-
if framework_version is None:
121-
logger.warning(
122-
empty_framework_version_warning(defaults.MXNET_VERSION, defaults.LATEST_VERSION)
123-
)
125+
super(MXNetModel, self).__init__(
126+
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
127+
)
124128

125-
self.py_version = py_version
126-
self.framework_version = framework_version or defaults.MXNET_VERSION
127129
self.model_server_workers = model_server_workers
128130

129131
def prepare_container_def(self, instance_type, accelerator_type=None):

tests/conftest.py

+5
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,11 @@ def mxnet_version(request):
163163
return request.param
164164

165165

166+
@pytest.fixture(scope="module", params=["py2", "py3"])
167+
def mxnet_py_version(request):
168+
return request.param
169+
170+
166171
@pytest.fixture(scope="module", params=["0.4", "0.4.0", "1.0", "1.0.0"])
167172
def pytorch_version(request):
168173
return request.param

tests/integ/test_local_mode.py

+3
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def _create_model(output_path):
6666
train_instance_type="local",
6767
output_path=output_path,
6868
framework_version=mxnet_full_version,
69+
py_version=PYTHON_VERSION,
6970
sagemaker_session=sagemaker_local_session,
7071
)
7172

@@ -188,6 +189,7 @@ def test_mxnet_local_data_local_script(mxnet_full_version):
188189
train_instance_count=1,
189190
train_instance_type="local",
190191
framework_version=mxnet_full_version,
192+
py_version=PYTHON_VERSION,
191193
sagemaker_session=LocalNoS3Session(),
192194
)
193195

@@ -242,6 +244,7 @@ def test_local_transform_mxnet(
242244
train_instance_count=1,
243245
train_instance_type="local",
244246
framework_version=mxnet_full_version,
247+
py_version=PYTHON_VERSION,
245248
sagemaker_session=sagemaker_local_session,
246249
)
247250

tests/integ/test_neo_mxnet.py

+1
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def test_inferentia_deploy_model(
131131
role,
132132
entry_point=script_path,
133133
framework_version=INF_MXNET_VERSION,
134+
py_version=PYTHON_VERSION,
134135
sagemaker_session=sagemaker_session,
135136
)
136137

tests/integ/test_transformer.py

+1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def mxnet_estimator(sagemaker_session, mxnet_full_version, cpu_instance_type):
5151
train_instance_type=cpu_instance_type,
5252
sagemaker_session=sagemaker_session,
5353
framework_version=mxnet_full_version,
54+
py_version=PYTHON_VERSION,
5455
)
5556

5657
train_input = mx.sagemaker_session.upload_data(

0 commit comments

Comments
 (0)