Skip to content

Commit 032789c

Browse files
committed
breaking: require framework_version, py_version for mxnet
* framework_version, py_version required for framework MXNet * framework_version, py_version required for framework MXNetModel * image_name required if either framework_version or py_version None * re-order of non-default args, convention to follow entry_point * unit and integ testing updates * doc updates * ignore coverage results for py27 env due to v2 migration scripts
1 parent 97cd594 commit 032789c

15 files changed

+264
-255
lines changed

doc/frameworks/mxnet/using_mxnet.rst

+7-6
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,8 @@ The following code sample shows how you train a custom MXNet script "train.py".
345345
mxnet_estimator = MXNet('train.py',
346346
train_instance_type='ml.p2.xlarge',
347347
train_instance_count=1,
348-
framework_version='1.3.0',
348+
framework_version='1.6.0',
349+
py_version='py3',
349350
hyperparameters={'batch-size': 100,
350351
'epochs': 10,
351352
'learning-rate': 0.1})
@@ -392,10 +393,10 @@ If you use the ``MXNet`` estimator to train the model, you can call ``deploy`` t
392393
393394
# Train my estimator
394395
mxnet_estimator = MXNet('train.py',
395-
train_instance_type='ml.p2.xlarge',
396-
train_instance_count=1,
396+
framework_version='1.6.0',
397397
py_version='py3',
398-
framework_version='1.6.0')
398+
train_instance_type='ml.p2.xlarge',
399+
train_instance_count=1)
399400
mxnet_estimator.fit('s3://my_bucket/my_training_data/')
400401
401402
# Deploy my estimator to an Amazon SageMaker Endpoint and get a Predictor
@@ -409,8 +410,8 @@ If using a pretrained model, create an ``MXNetModel`` object, and then call ``de
409410
mxnet_model = MXNetModel(model_data='s3://my_bucket/pretrained_model/model.tar.gz',
410411
role=role,
411412
entry_point='inference.py',
412-
py_version='py3',
413-
framework_version='1.6.0')
413+
framework_version='1.6.0',
414+
py_version='py3')
414415
predictor = mxnet_model.deploy(instance_type='ml.m4.xlarge',
415416
initial_instance_count=1)
416417

src/sagemaker/cli/mxnet.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.cli.common import HostCommand, TrainCommand
17+
from sagemaker.mxnet import defaults
1718

1819

1920
def train(args):
@@ -40,13 +41,14 @@ def create_estimator(self):
4041
from sagemaker.mxnet.estimator import MXNet
4142

4243
return MXNet(
43-
self.script,
44+
entry_point=self.script,
45+
framework_version=defaults.MXNET_VERSION,
46+
py_version=self.python,
4447
role=self.role_name,
4548
base_job_name=self.job_name,
4649
train_instance_count=self.instance_count,
4750
train_instance_type=self.instance_type,
4851
hyperparameters=self.hyperparameters,
49-
py_version=self.python,
5052
)
5153

5254

@@ -64,6 +66,7 @@ def create_model(self, model_url):
6466
model_data=model_url,
6567
role=self.role_name,
6668
entry_point=self.script,
69+
framework_version=defaults.MXNET_VERSION,
6770
py_version=self.python,
6871
name=self.endpoint_name,
6972
env=self.environment,

src/sagemaker/fw_utils.py

+19
Original file line numberDiff line numberDiff line change
@@ -681,3 +681,22 @@ def _region_supports_debugger(region_name):
681681
682682
"""
683683
return region_name.lower() not in DEBUGGER_UNSUPPORTED_REGIONS
684+
685+
686+
def validate_version_or_image_args(framework_version, py_version, image_name):
687+
"""Checks if version or image arguments are specified.
688+
689+
Used to validate framework and model arguments to enforce version or image specification.
690+
Raises ValueError if version or image arguments are not specified.
691+
692+
Args:
693+
framework_version (str): the version of the framework
694+
py_version (str): the version of python
695+
image_name (str): the uri of the image
696+
"""
697+
if (framework_version is None or py_version is None) and image_name is None:
698+
raise ValueError(
699+
"framework_version or py_version was None, yet image_name was also None. "
700+
"Either specify both framework_version and py_version, or specify image_name."
701+
)
702+
return True

src/sagemaker/mxnet/estimator.py

+39-33
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919
from sagemaker.fw_utils import (
2020
framework_name_from_image,
2121
framework_version_from_tag,
22-
empty_framework_version_warning,
22+
is_version_equal_or_higher,
2323
python_deprecation_warning,
2424
parameter_v2_rename_warning,
25-
is_version_equal_or_higher,
25+
validate_version_or_image_args,
2626
warn_if_parameter_server_with_multi_gpu,
2727
)
2828
from sagemaker.mxnet import defaults
@@ -43,10 +43,10 @@ class MXNet(Framework):
4343
def __init__(
4444
self,
4545
entry_point,
46+
framework_version=None,
47+
py_version=None,
4648
source_dir=None,
4749
hyperparameters=None,
48-
py_version="py2",
49-
framework_version=None,
5050
image_name=None,
5151
distributions=None,
5252
**kwargs
@@ -73,6 +73,11 @@ def __init__(
7373
file which should be executed as the entry point to training.
7474
If ``source_dir`` is specified, then ``entry_point``
7575
must point to a file located at the root of ``source_dir``.
76+
framework_version (str): MXNet version you want to use for executing
77+
your model training code. List of supported versions. Defaults to ``None``.
78+
https://github.com/aws/sagemaker-python-sdk#mxnet-sagemaker-estimators.
79+
py_version (str): Python version you want to use for executing your
80+
model training code. One of 'py2' or 'py3'. Defaults to ``None``.
7681
source_dir (str): Path (absolute, relative or an S3 URI) to a directory
7782
with any other training source code dependencies aside from the entry
7883
point file (default: None). If ``source_dir`` is an S3 URI, it must
@@ -84,12 +89,6 @@ def __init__(
8489
SageMaker. For convenience, this accepts other types for keys
8590
and values, but ``str()`` will be called to convert them before
8691
training.
87-
py_version (str): Python version you want to use for executing your
88-
model training code (default: 'py2'). One of 'py2' or 'py3'.
89-
framework_version (str): MXNet version you want to use for executing
90-
your model training code. List of supported versions
91-
https://github.com/aws/sagemaker-python-sdk#mxnet-sagemaker-estimators.
92-
If not specified, this will default to 1.2.1.
9392
image_name (str): If specified, the estimator will use this image for training and
9493
hosting, instead of selecting the appropriate SageMaker official image based on
9594
framework_version and py_version. It can be an ECR url or dockerhub image and tag.
@@ -98,6 +97,9 @@ def __init__(
9897
* ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
9998
* ``custom-image:latest``
10099
100+
If ``framework_version`` or ``py_version`` are ``None``, then
101+
``image_name`` is required. If also ``None``, then a ``ValueError``
102+
will be raised.
101103
distributions (dict): A dictionary with information on how to run distributed
102104
training (default: None). To have parameter servers launched for training,
103105
set this value to be ``{'parameter_server': {'enabled': True}}``.
@@ -110,34 +112,32 @@ def __init__(
110112
:class:`~sagemaker.estimator.Framework` and
111113
:class:`~sagemaker.estimator.EstimatorBase`.
112114
"""
113-
if framework_version is None:
115+
validate_version_or_image_args(framework_version, py_version, image_name)
116+
if py_version and py_version == "py2":
114117
logger.warning(
115-
empty_framework_version_warning(defaults.MXNET_VERSION, self.LATEST_VERSION)
118+
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
116119
)
117-
self.framework_version = framework_version or defaults.MXNET_VERSION
120+
self.framework_version = framework_version
121+
self.py_version = py_version
118122

119123
if "enable_sagemaker_metrics" not in kwargs:
120124
# enable sagemaker metrics for MXNet v1.6 or greater:
121-
if is_version_equal_or_higher([1, 6], self.framework_version):
125+
if self.framework_version and is_version_equal_or_higher(
126+
[1, 6], self.framework_version
127+
):
122128
kwargs["enable_sagemaker_metrics"] = True
123129

124130
super(MXNet, self).__init__(
125131
entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs
126132
)
127133

128-
if py_version == "py2":
129-
logger.warning(
130-
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
131-
)
132-
133134
if distributions is not None:
134135
logger.warning(parameter_v2_rename_warning("distributions", "distribution"))
135136
train_instance_type = kwargs.get("train_instance_type")
136137
warn_if_parameter_server_with_multi_gpu(
137138
training_instance_type=train_instance_type, distributions=distributions
138139
)
139140

140-
self.py_version = py_version
141141
self._configure_distribution(distributions)
142142

143143
def _configure_distribution(self, distributions):
@@ -148,7 +148,10 @@ def _configure_distribution(self, distributions):
148148
if distributions is None:
149149
return
150150

151-
if self.framework_version.split(".") < self._LOWEST_SCRIPT_MODE_VERSION:
151+
if (
152+
self.framework_version
153+
and self.framework_version.split(".") < self._LOWEST_SCRIPT_MODE_VERSION
154+
):
152155
raise ValueError(
153156
"The distributions option is valid for only versions {} and higher".format(
154157
".".join(self._LOWEST_SCRIPT_MODE_VERSION)
@@ -221,12 +224,12 @@ def create_model(
221224
self.model_data,
222225
role or self.role,
223226
entry_point or self.entry_point,
227+
framework_version=self.framework_version,
228+
py_version=self.py_version,
224229
source_dir=(source_dir or self._model_source_dir()),
225230
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
226231
container_log_level=self.container_log_level,
227232
code_location=self.code_location,
228-
py_version=self.py_version,
229-
framework_version=self.framework_version,
230233
model_server_workers=model_server_workers,
231234
sagemaker_session=self.sagemaker_session,
232235
vpc_config=self.get_vpc_config(vpc_config_override),
@@ -254,22 +257,25 @@ class constructor
254257
image_name = init_params.pop("image")
255258
framework, py_version, tag, _ = framework_name_from_image(image_name)
256259

260+
# We switched image tagging scheme from regular image version (e.g. '1.0') to more
261+
# expressive containing framework version, device type and python version
262+
# (e.g. '0.12-gpu-py2'). For backward compatibility map deprecated image tag '1.0' to a
263+
# '0.12' framework version otherwise extract framework version from the tag itself.
264+
if tag is None:
265+
framework_version = None
266+
elif tag == "1.0":
267+
framework_version = "0.12"
268+
else:
269+
framework_version = framework_version_from_tag(tag)
270+
init_params["framework_version"] = framework_version
271+
init_params["py_version"] = py_version
272+
257273
if not framework:
258274
# If we were unable to parse the framework name from the image it is not one of our
259275
# officially supported images, in this case just add the image to the init params.
260276
init_params["image_name"] = image_name
261277
return init_params
262278

263-
init_params["py_version"] = py_version
264-
265-
# We switched image tagging scheme from regular image version (e.g. '1.0') to more
266-
# expressive containing framework version, device type and python version
267-
# (e.g. '0.12-gpu-py2'). For backward compatibility map deprecated image tag '1.0' to a
268-
# '0.12' framework version otherwise extract framework version from the tag itself.
269-
init_params["framework_version"] = (
270-
"0.12" if tag == "1.0" else framework_version_from_tag(tag)
271-
)
272-
273279
training_job_name = init_params["base_job_name"]
274280

275281
if framework != cls.__framework_name__:

src/sagemaker/mxnet/model.py

+19-18
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
create_image_uri,
2323
model_code_key_prefix,
2424
python_deprecation_warning,
25-
empty_framework_version_warning,
25+
validate_version_or_image_args,
2626
)
2727
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2828
from sagemaker.mxnet import defaults
@@ -65,9 +65,9 @@ def __init__(
6565
model_data,
6666
role,
6767
entry_point,
68-
image=None,
69-
py_version="py2",
7068
framework_version=None,
69+
py_version=None,
70+
image=None,
7171
predictor_cls=MXNetPredictor,
7272
model_server_workers=None,
7373
**kwargs
@@ -86,12 +86,16 @@ def __init__(
8686
file which should be executed as the entry point to model
8787
hosting. If ``source_dir`` is specified, then ``entry_point``
8888
must point to a file located at the root of ``source_dir``.
89+
framework_version (str): MXNet version you want to use for executing
90+
your model training code. Defaults to ``None``.
91+
py_version (str): Python version you want to use for executing your
92+
model training code. Defaults to ``None``.
8993
image (str): A Docker image URI (default: None). If not specified, a
9094
default image for MXNet will be used.
91-
py_version (str): Python version you want to use for executing your
92-
model training code (default: 'py2').
93-
framework_version (str): MXNet version you want to use for executing
94-
your model training code.
95+
96+
If ``framework_version`` or ``py_version`` are ``None``, then
97+
``image_name`` is required. If also ``None``, then a ``ValueError``
98+
will be raised.
9599
predictor_cls (callable[str, sagemaker.session.Session]): A function
96100
to call to create a predictor with an endpoint name and
97101
SageMaker ``Session``. If specified, ``deploy()`` returns the
@@ -108,22 +112,19 @@ def __init__(
108112
:class:`~sagemaker.model.FrameworkModel` and
109113
:class:`~sagemaker.model.Model`.
110114
"""
111-
super(MXNetModel, self).__init__(
112-
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
113-
)
114-
115-
if py_version == "py2":
115+
# TODO: rename/unify image attribute to match across code base
116+
validate_version_or_image_args(framework_version, py_version, image)
117+
if py_version and py_version == "py2":
116118
logger.warning(
117119
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
118120
)
121+
self.framework_version = framework_version
122+
self.py_version = py_version
119123

120-
if framework_version is None:
121-
logger.warning(
122-
empty_framework_version_warning(defaults.MXNET_VERSION, defaults.LATEST_VERSION)
123-
)
124+
super(MXNetModel, self).__init__(
125+
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
126+
)
124127

125-
self.py_version = py_version
126-
self.framework_version = framework_version or defaults.MXNET_VERSION
127128
self.model_server_workers = model_server_workers
128129

129130
def prepare_container_def(self, instance_type, accelerator_type=None):

tests/conftest.py

+5
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,11 @@ def mxnet_version(request):
163163
return request.param
164164

165165

166+
@pytest.fixture(scope="module", params=["py2", "py3"])
167+
def mxnet_py_version(request):
168+
return request.param
169+
170+
166171
@pytest.fixture(scope="module", params=["0.4", "0.4.0", "1.0", "1.0.0"])
167172
def pytorch_version(request):
168173
return request.param

tests/integ/test_local_mode.py

+3
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def _create_model(output_path):
6666
train_instance_type="local",
6767
output_path=output_path,
6868
framework_version=mxnet_full_version,
69+
py_version=PYTHON_VERSION,
6970
sagemaker_session=sagemaker_local_session,
7071
)
7172

@@ -188,6 +189,7 @@ def test_mxnet_local_data_local_script(mxnet_full_version):
188189
train_instance_count=1,
189190
train_instance_type="local",
190191
framework_version=mxnet_full_version,
192+
py_version=PYTHON_VERSION,
191193
sagemaker_session=LocalNoS3Session(),
192194
)
193195

@@ -242,6 +244,7 @@ def test_local_transform_mxnet(
242244
train_instance_count=1,
243245
train_instance_type="local",
244246
framework_version=mxnet_full_version,
247+
py_version=PYTHON_VERSION,
245248
sagemaker_session=sagemaker_local_session,
246249
)
247250

tests/integ/test_neo_mxnet.py

+1
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def test_inferentia_deploy_model(
131131
role,
132132
entry_point=script_path,
133133
framework_version=INF_MXNET_VERSION,
134+
py_version=PYTHON_VERSION,
134135
sagemaker_session=sagemaker_session,
135136
)
136137

tests/integ/test_transformer.py

+1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def mxnet_estimator(sagemaker_session, mxnet_full_version, cpu_instance_type):
5151
train_instance_type=cpu_instance_type,
5252
sagemaker_session=sagemaker_session,
5353
framework_version=mxnet_full_version,
54+
py_version=PYTHON_VERSION,
5455
)
5556

5657
train_input = mx.sagemaker_session.upload_data(

tests/unit/test_fw_utils.py

+13
Original file line numberDiff line numberDiff line change
@@ -1272,3 +1272,16 @@ def test_warn_if_parameter_server_with_multi_gpu(caplog):
12721272
training_instance_type=train_instance_type, distributions=distributions
12731273
)
12741274
assert fw_utils.PARAMETER_SERVER_MULTI_GPU_WARNING in caplog.text
1275+
1276+
1277+
def test_validate_version_or_image_args():
1278+
for good_args in [("", "", None), (None, "", ""), ("", None, "")]:
1279+
kwargs = dict(zip(("framework_version", "py_version", "image_name"), good_args))
1280+
assert fw_utils.validate_version_or_image_args(**kwargs)
1281+
1282+
1283+
def test_validate_version_or_image_args_raises():
1284+
for bad_args in [(None, None, None), (None, "", None), ("", None, None)]:
1285+
kwargs = dict(zip(("framework_version", "py_version", "image_name"), bad_args))
1286+
with pytest.raises(ValueError):
1287+
fw_utils.validate_version_or_image_args(**kwargs)

0 commit comments

Comments
 (0)