Skip to content

breaking: require framework_version, py_version for mxnet #1559

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 10, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions doc/frameworks/mxnet/using_mxnet.rst
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ The following code sample shows how you train a custom MXNet script "train.py".
mxnet_estimator = MXNet('train.py',
train_instance_type='ml.p2.xlarge',
train_instance_count=1,
framework_version='1.3.0',
framework_version='1.6.0',
py_version='py3',
hyperparameters={'batch-size': 100,
'epochs': 10,
'learning-rate': 0.1})
Expand Down Expand Up @@ -230,10 +231,10 @@ If you use the ``MXNet`` estimator to train the model, you can call ``deploy`` t

# Train my estimator
mxnet_estimator = MXNet('train.py',
train_instance_type='ml.p2.xlarge',
train_instance_count=1,
framework_version='1.6.0',
py_version='py3',
framework_version='1.6.0')
train_instance_type='ml.p2.xlarge',
train_instance_count=1)
mxnet_estimator.fit('s3://my_bucket/my_training_data/')

# Deploy my estimator to an Amazon SageMaker Endpoint and get a Predictor
Expand All @@ -247,8 +248,8 @@ If using a pretrained model, create an ``MXNetModel`` object, and then call ``de
mxnet_model = MXNetModel(model_data='s3://my_bucket/pretrained_model/model.tar.gz',
role=role,
entry_point='inference.py',
py_version='py3',
framework_version='1.6.0')
framework_version='1.6.0',
py_version='py3')
predictor = mxnet_model.deploy(instance_type='ml.m4.xlarge',
initial_instance_count=1)

Expand Down
7 changes: 5 additions & 2 deletions src/sagemaker/cli/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from __future__ import absolute_import

from sagemaker.cli.common import HostCommand, TrainCommand
from sagemaker.mxnet import defaults


def train(args):
Expand All @@ -40,13 +41,14 @@ def create_estimator(self):
from sagemaker.mxnet.estimator import MXNet

return MXNet(
self.script,
entry_point=self.script,
framework_version=defaults.MXNET_VERSION,
py_version=self.python,
role=self.role_name,
base_job_name=self.job_name,
train_instance_count=self.instance_count,
train_instance_type=self.instance_type,
hyperparameters=self.hyperparameters,
py_version=self.python,
)


Expand All @@ -64,6 +66,7 @@ def create_model(self, model_url):
model_data=model_url,
role=self.role_name,
entry_point=self.script,
framework_version=defaults.MXNET_VERSION,
py_version=self.python,
name=self.endpoint_name,
env=self.environment,
Expand Down
21 changes: 21 additions & 0 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,3 +681,24 @@ def _region_supports_debugger(region_name):

"""
return region_name.lower() not in DEBUGGER_UNSUPPORTED_REGIONS


def validate_version_or_image_args(framework_version, py_version, image_name):
    """Checks if version or image arguments are specified.

    Validates framework and model arguments to enforce version or image
    specification: the caller must supply either ``image_name``, or both
    ``framework_version`` and ``py_version``.

    Args:
        framework_version (str): The version of the framework.
        py_version (str): The version of python.
        image_name (str): The URI of the image.

    Raises:
        ValueError: if `image_name` is None and either `framework_version` or `py_version` is
            None.
    """
    # An explicit image makes both version arguments optional.
    if image_name is not None:
        return

    # Only ``None`` counts as "missing"; other falsy values are not rejected here.
    if framework_version is None or py_version is None:
        raise ValueError(
            "framework_version or py_version was None, yet image_name was also None. "
            "Either specify both framework_version and py_version, or specify image_name."
        )
74 changes: 41 additions & 33 deletions src/sagemaker/mxnet/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
from sagemaker.fw_utils import (
framework_name_from_image,
framework_version_from_tag,
empty_framework_version_warning,
is_version_equal_or_higher,
python_deprecation_warning,
parameter_v2_rename_warning,
is_version_equal_or_higher,
validate_version_or_image_args,
warn_if_parameter_server_with_multi_gpu,
)
from sagemaker.mxnet import defaults
Expand All @@ -43,10 +43,10 @@ class MXNet(Framework):
def __init__(
self,
entry_point,
framework_version=None,
py_version=None,
source_dir=None,
hyperparameters=None,
py_version="py2",
framework_version=None,
image_name=None,
distributions=None,
**kwargs
Expand All @@ -73,6 +73,13 @@ def __init__(
file which should be executed as the entry point to training.
If ``source_dir`` is specified, then ``entry_point``
must point to a file located at the root of ``source_dir``.
framework_version (str): MXNet version you want to use for executing
your model training code. Defaults to `None`. Required unless
``image_name`` is provided. List of supported versions:
https://github.com/aws/sagemaker-python-sdk#mxnet-sagemaker-estimators.
py_version (str): Python version you want to use for executing your
model training code. One of 'py2' or 'py3'. Defaults to ``None``. Required
unless ``image_name`` is provided.
source_dir (str): Path (absolute, relative or an S3 URI) to a directory
with any other training source code dependencies aside from the entry
point file (default: None). If ``source_dir`` is an S3 URI, it must
Expand All @@ -84,12 +91,6 @@ def __init__(
SageMaker. For convenience, this accepts other types for keys
and values, but ``str()`` will be called to convert them before
training.
py_version (str): Python version you want to use for executing your
model training code (default: 'py2'). One of 'py2' or 'py3'.
framework_version (str): MXNet version you want to use for executing
your model training code. List of supported versions
https://github.com/aws/sagemaker-python-sdk#mxnet-sagemaker-estimators.
If not specified, this will default to 1.2.1.
image_name (str): If specified, the estimator will use this image for training and
hosting, instead of selecting the appropriate SageMaker official image based on
framework_version and py_version. It can be an ECR url or dockerhub image and tag.
Expand All @@ -98,6 +99,9 @@ def __init__(
* ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
* ``custom-image:latest``

If ``framework_version`` or ``py_version`` is ``None``, then
``image_name`` is required. If ``image_name`` is also ``None``, then a
``ValueError`` will be raised.
distributions (dict): A dictionary with information on how to run distributed
training (default: None). To have parameter servers launched for training,
set this value to be ``{'parameter_server': {'enabled': True}}``.
Expand All @@ -110,34 +114,32 @@ def __init__(
:class:`~sagemaker.estimator.Framework` and
:class:`~sagemaker.estimator.EstimatorBase`.
"""
if framework_version is None:
validate_version_or_image_args(framework_version, py_version, image_name)
if py_version and py_version == "py2":
logger.warning(
empty_framework_version_warning(defaults.MXNET_VERSION, self.LATEST_VERSION)
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
)
self.framework_version = framework_version or defaults.MXNET_VERSION
self.framework_version = framework_version
self.py_version = py_version

if "enable_sagemaker_metrics" not in kwargs:
# enable sagemaker metrics for MXNet v1.6 or greater:
if is_version_equal_or_higher([1, 6], self.framework_version):
if self.framework_version and is_version_equal_or_higher(
[1, 6], self.framework_version
):
kwargs["enable_sagemaker_metrics"] = True

super(MXNet, self).__init__(
entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs
)

if py_version == "py2":
logger.warning(
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
)

if distributions is not None:
logger.warning(parameter_v2_rename_warning("distributions", "distribution"))
train_instance_type = kwargs.get("train_instance_type")
warn_if_parameter_server_with_multi_gpu(
training_instance_type=train_instance_type, distributions=distributions
)

self.py_version = py_version
self._configure_distribution(distributions)

def _configure_distribution(self, distributions):
Expand All @@ -148,7 +150,10 @@ def _configure_distribution(self, distributions):
if distributions is None:
return

if self.framework_version.split(".") < self._LOWEST_SCRIPT_MODE_VERSION:
if (
self.framework_version
and self.framework_version.split(".") < self._LOWEST_SCRIPT_MODE_VERSION
):
raise ValueError(
"The distributions option is valid for only versions {} and higher".format(
".".join(self._LOWEST_SCRIPT_MODE_VERSION)
Expand Down Expand Up @@ -221,12 +226,12 @@ def create_model(
self.model_data,
role or self.role,
entry_point or self.entry_point,
framework_version=self.framework_version,
py_version=self.py_version,
source_dir=(source_dir or self._model_source_dir()),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
container_log_level=self.container_log_level,
code_location=self.code_location,
py_version=self.py_version,
framework_version=self.framework_version,
model_server_workers=model_server_workers,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
Expand Down Expand Up @@ -254,22 +259,25 @@ class constructor
image_name = init_params.pop("image")
framework, py_version, tag, _ = framework_name_from_image(image_name)

# We switched image tagging scheme from regular image version (e.g. '1.0') to more
# expressive containing framework version, device type and python version
# (e.g. '0.12-gpu-py2'). For backward compatibility map deprecated image tag '1.0' to a
# '0.12' framework version otherwise extract framework version from the tag itself.
if tag is None:
framework_version = None
elif tag == "1.0":
framework_version = "0.12"
else:
framework_version = framework_version_from_tag(tag)
init_params["framework_version"] = framework_version
init_params["py_version"] = py_version

if not framework:
# If we were unable to parse the framework name from the image it is not one of our
# officially supported images, in this case just add the image to the init params.
init_params["image_name"] = image_name
return init_params

init_params["py_version"] = py_version

# We switched image tagging scheme from regular image version (e.g. '1.0') to more
# expressive containing framework version, device type and python version
# (e.g. '0.12-gpu-py2'). For backward compatibility map deprecated image tag '1.0' to a
# '0.12' framework version otherwise extract framework version from the tag itself.
init_params["framework_version"] = (
"0.12" if tag == "1.0" else framework_version_from_tag(tag)
)

training_job_name = init_params["base_job_name"]

if framework != cls.__framework_name__:
Expand Down
38 changes: 20 additions & 18 deletions src/sagemaker/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
create_image_uri,
model_code_key_prefix,
python_deprecation_warning,
empty_framework_version_warning,
validate_version_or_image_args,
)
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.mxnet import defaults
Expand Down Expand Up @@ -65,9 +65,9 @@ def __init__(
model_data,
role,
entry_point,
image=None,
py_version="py2",
framework_version=None,
py_version=None,
image=None,
predictor_cls=MXNetPredictor,
model_server_workers=None,
**kwargs
Expand All @@ -86,12 +86,18 @@ def __init__(
file which should be executed as the entry point to model
hosting. If ``source_dir`` is specified, then ``entry_point``
must point to a file located at the root of ``source_dir``.
framework_version (str): MXNet version you want to use for executing
your model training code. Defaults to ``None``. Required unless
``image`` is provided.
py_version (str): Python version you want to use for executing your
model training code. Defaults to ``None``. Required unless
``image`` is provided.
image (str): A Docker image URI (default: None). If not specified, a
default image for MXNet will be used.
py_version (str): Python version you want to use for executing your
model training code (default: 'py2').
framework_version (str): MXNet version you want to use for executing
your model training code.

If ``framework_version`` or ``py_version`` is ``None``, then
``image`` is required. If ``image`` is also ``None``, then a ``ValueError``
will be raised.
predictor_cls (callable[str, sagemaker.session.Session]): A function
to call to create a predictor with an endpoint name and
SageMaker ``Session``. If specified, ``deploy()`` returns the
Expand All @@ -108,22 +114,18 @@ def __init__(
:class:`~sagemaker.model.FrameworkModel` and
:class:`~sagemaker.model.Model`.
"""
super(MXNetModel, self).__init__(
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
)

if py_version == "py2":
validate_version_or_image_args(framework_version, py_version, image)
if py_version and py_version == "py2":
logger.warning(
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
)
self.framework_version = framework_version
self.py_version = py_version

if framework_version is None:
logger.warning(
empty_framework_version_warning(defaults.MXNET_VERSION, defaults.LATEST_VERSION)
)
super(MXNetModel, self).__init__(
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
)

self.py_version = py_version
self.framework_version = framework_version or defaults.MXNET_VERSION
self.model_server_workers = model_server_workers

def prepare_container_def(self, instance_type, accelerator_type=None):
Expand Down
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,11 @@ def mxnet_version(request):
return request.param


@pytest.fixture(scope="module", params=["py2", "py3"])
def mxnet_py_version(request):
    """Module-scoped fixture parametrized over the values ``"py2"`` and ``"py3"``.

    Returns:
        str: the parameter value for the current parametrization.
    """
    return request.param


@pytest.fixture(scope="module", params=["0.4", "0.4.0", "1.0", "1.0.0"])
def pytorch_version(request):
    """Module-scoped fixture parametrized over the listed version strings.

    Returns:
        str: the parameter value for the current parametrization.
    """
    return request.param
Expand Down
3 changes: 3 additions & 0 deletions tests/integ/test_local_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def _create_model(output_path):
train_instance_type="local",
output_path=output_path,
framework_version=mxnet_full_version,
py_version=PYTHON_VERSION,
sagemaker_session=sagemaker_local_session,
)

Expand Down Expand Up @@ -188,6 +189,7 @@ def test_mxnet_local_data_local_script(mxnet_full_version):
train_instance_count=1,
train_instance_type="local",
framework_version=mxnet_full_version,
py_version=PYTHON_VERSION,
sagemaker_session=LocalNoS3Session(),
)

Expand Down Expand Up @@ -242,6 +244,7 @@ def test_local_transform_mxnet(
train_instance_count=1,
train_instance_type="local",
framework_version=mxnet_full_version,
py_version=PYTHON_VERSION,
sagemaker_session=sagemaker_local_session,
)

Expand Down
1 change: 1 addition & 0 deletions tests/integ/test_neo_mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def test_inferentia_deploy_model(
role,
entry_point=script_path,
framework_version=INF_MXNET_VERSION,
py_version=PYTHON_VERSION,
sagemaker_session=sagemaker_session,
)

Expand Down
1 change: 1 addition & 0 deletions tests/integ/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def mxnet_estimator(sagemaker_session, mxnet_full_version, cpu_instance_type):
train_instance_type=cpu_instance_type,
sagemaker_session=sagemaker_session,
framework_version=mxnet_full_version,
py_version=PYTHON_VERSION,
)

train_input = mx.sagemaker_session.upload_data(
Expand Down
Loading