diff --git a/doc/frameworks/mxnet/using_mxnet.rst b/doc/frameworks/mxnet/using_mxnet.rst index db7bf1b520..5dcae6ae3b 100644 --- a/doc/frameworks/mxnet/using_mxnet.rst +++ b/doc/frameworks/mxnet/using_mxnet.rst @@ -183,7 +183,8 @@ The following code sample shows how you train a custom MXNet script "train.py". mxnet_estimator = MXNet('train.py', train_instance_type='ml.p2.xlarge', train_instance_count=1, - framework_version='1.3.0', + framework_version='1.6.0', + py_version='py3', hyperparameters={'batch-size': 100, 'epochs': 10, 'learning-rate': 0.1}) @@ -230,10 +231,10 @@ If you use the ``MXNet`` estimator to train the model, you can call ``deploy`` t # Train my estimator mxnet_estimator = MXNet('train.py', - train_instance_type='ml.p2.xlarge', - train_instance_count=1, + framework_version='1.6.0', py_version='py3', - framework_version='1.6.0') + train_instance_type='ml.p2.xlarge', + train_instance_count=1) mxnet_estimator.fit('s3://my_bucket/my_training_data/') # Deploy my estimator to an Amazon SageMaker Endpoint and get a Predictor @@ -247,8 +248,8 @@ If using a pretrained model, create an ``MXNetModel`` object, and then call ``de mxnet_model = MXNetModel(model_data='s3://my_bucket/pretrained_model/model.tar.gz', role=role, entry_point='inference.py', - py_version='py3', - framework_version='1.6.0') + framework_version='1.6.0', + py_version='py3') predictor = mxnet_model.deploy(instance_type='ml.m4.xlarge', initial_instance_count=1) diff --git a/src/sagemaker/cli/mxnet.py b/src/sagemaker/cli/mxnet.py index 244c69d7ee..3406cdef8e 100644 --- a/src/sagemaker/cli/mxnet.py +++ b/src/sagemaker/cli/mxnet.py @@ -14,6 +14,7 @@ from __future__ import absolute_import from sagemaker.cli.common import HostCommand, TrainCommand +from sagemaker.mxnet import defaults def train(args): @@ -40,13 +41,14 @@ def create_estimator(self): from sagemaker.mxnet.estimator import MXNet return MXNet( - self.script, + entry_point=self.script, + framework_version=defaults.MXNET_VERSION, + py_version=self.python, role=self.role_name, base_job_name=self.job_name, train_instance_count=self.instance_count, train_instance_type=self.instance_type, hyperparameters=self.hyperparameters, - py_version=self.python, ) @@ -64,6 +66,7 @@ def create_model(self, model_url): model_data=model_url, role=self.role_name, entry_point=self.script, + framework_version=defaults.MXNET_VERSION, py_version=self.python, name=self.endpoint_name, env=self.environment, diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 9141ae8c72..5a957b76b6 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -681,3 +681,24 @@ def _region_supports_debugger(region_name): """ return region_name.lower() not in DEBUGGER_UNSUPPORTED_REGIONS + + +def validate_version_or_image_args(framework_version, py_version, image_name): + """Checks if version or image arguments are specified. + + Validates framework and model arguments to enforce version or image specification. + + Args: + framework_version (str): The version of the framework. + py_version (str): The version of Python. + image_name (str): The URI of the image. + + Raises: + ValueError: if `image_name` is None and either `framework_version` or `py_version` is + None. + """ + if (framework_version is None or py_version is None) and image_name is None: + raise ValueError( + "framework_version or py_version was None, yet image_name was also None. " + "Either specify both framework_version and py_version, or specify image_name." + ) diff --git a/src/sagemaker/mxnet/estimator.py b/src/sagemaker/mxnet/estimator.py index a1432dc259..06a4c360c3 100644 --- a/src/sagemaker/mxnet/estimator.py +++ b/src/sagemaker/mxnet/estimator.py @@ -19,10 +19,10 @@ from sagemaker.fw_utils import ( framework_name_from_image, framework_version_from_tag, - empty_framework_version_warning, + is_version_equal_or_higher, python_deprecation_warning, parameter_v2_rename_warning, - is_version_equal_or_higher, + validate_version_or_image_args, warn_if_parameter_server_with_multi_gpu, ) from sagemaker.mxnet import defaults @@ -43,10 +43,10 @@ class MXNet(Framework): def __init__( self, entry_point, + framework_version=None, + py_version=None, source_dir=None, hyperparameters=None, - py_version="py2", - framework_version=None, image_name=None, distributions=None, **kwargs @@ -73,6 +73,13 @@ def __init__( file which should be executed as the entry point to training. If ``source_dir`` is specified, then ``entry_point`` must point to a file located at the root of ``source_dir``. + framework_version (str): MXNet version you want to use for executing + your model training code. Defaults to `None`. Required unless + ``image_name`` is provided. List of supported versions. + https://github.com/aws/sagemaker-python-sdk#mxnet-sagemaker-estimators. + py_version (str): Python version you want to use for executing your + model training code. One of 'py2' or 'py3'. Defaults to ``None``. Required + unless ``image_name`` is provided. source_dir (str): Path (absolute, relative or an S3 URI) to a directory with any other training source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must @@ -84,12 +91,6 @@ def __init__( SageMaker. For convenience, this accepts other types for keys and values, but ``str()`` will be called to convert them before training. - py_version (str): Python version you want to use for executing your - model training code (default: 'py2'). One of 'py2' or 'py3'. - framework_version (str): MXNet version you want to use for executing - your model training code. List of supported versions - https://github.com/aws/sagemaker-python-sdk#mxnet-sagemaker-estimators. - If not specified, this will default to 1.2.1. image_name (str): If specified, the estimator will use this image for training and hosting, instead of selecting the appropriate SageMaker official image based on framework_version and py_version. It can be an ECR url or dockerhub image and tag. @@ -98,6 +99,9 @@ def __init__( * ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0`` * ``custom-image:latest`` + If ``framework_version`` or ``py_version`` are ``None``, then + ``image_name`` is required. If also ``None``, then a ``ValueError`` + will be raised. distributions (dict): A dictionary with information on how to run distributed training (default: None). To have parameter servers launched for training, set this value to be ``{'parameter_server': {'enabled': True}}``. @@ -110,26 +114,25 @@ def __init__( :class:`~sagemaker.estimator.Framework` and :class:`~sagemaker.estimator.EstimatorBase`. """ - if framework_version is None: + validate_version_or_image_args(framework_version, py_version, image_name) + if py_version and py_version == "py2": logger.warning( - empty_framework_version_warning(defaults.MXNET_VERSION, self.LATEST_VERSION) + python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION) ) - self.framework_version = framework_version or defaults.MXNET_VERSION + self.framework_version = framework_version + self.py_version = py_version if "enable_sagemaker_metrics" not in kwargs: # enable sagemaker metrics for MXNet v1.6 or greater: - if is_version_equal_or_higher([1, 6], self.framework_version): + if self.framework_version and is_version_equal_or_higher( + [1, 6], self.framework_version + ): kwargs["enable_sagemaker_metrics"] = True super(MXNet, self).__init__( entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs ) - if py_version == "py2": - logger.warning( - python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION) - ) - if distributions is not None: logger.warning(parameter_v2_rename_warning("distributions", "distribution")) train_instance_type = kwargs.get("train_instance_type") @@ -137,7 +140,6 @@ def __init__( training_instance_type=train_instance_type, distributions=distributions ) - self.py_version = py_version self._configure_distribution(distributions) def _configure_distribution(self, distributions): @@ -148,7 +150,10 @@ def _configure_distribution(self, distributions): if distributions is None: return - if self.framework_version.split(".") < self._LOWEST_SCRIPT_MODE_VERSION: + if ( + self.framework_version + and self.framework_version.split(".") < self._LOWEST_SCRIPT_MODE_VERSION + ): raise ValueError( "The distributions option is valid for only versions {} and higher".format( ".".join(self._LOWEST_SCRIPT_MODE_VERSION) @@ -221,12 +226,12 @@ def create_model( self.model_data, role or self.role, entry_point or self.entry_point, + framework_version=self.framework_version, + py_version=self.py_version, source_dir=(source_dir or self._model_source_dir()), enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, container_log_level=self.container_log_level, code_location=self.code_location, - py_version=self.py_version, - framework_version=self.framework_version, model_server_workers=model_server_workers, sagemaker_session=self.sagemaker_session, vpc_config=self.get_vpc_config(vpc_config_override), @@ -254,22 +259,25 @@ class constructor image_name = init_params.pop("image") framework, py_version, tag, _ = framework_name_from_image(image_name) + # We switched image tagging scheme from regular image version (e.g. '1.0') to more + # expressive containing framework version, device type and python version + # (e.g. '0.12-gpu-py2'). For backward compatibility map deprecated image tag '1.0' to a + # '0.12' framework version otherwise extract framework version from the tag itself. + if tag is None: + framework_version = None + elif tag == "1.0": + framework_version = "0.12" + else: + framework_version = framework_version_from_tag(tag) + init_params["framework_version"] = framework_version + init_params["py_version"] = py_version + if not framework: # If we were unable to parse the framework name from the image it is not one of our # officially supported images, in this case just add the image to the init params. init_params["image_name"] = image_name return init_params - init_params["py_version"] = py_version - - # We switched image tagging scheme from regular image version (e.g. '1.0') to more - # expressive containing framework version, device type and python version - # (e.g. '0.12-gpu-py2'). For backward compatibility map deprecated image tag '1.0' to a - # '0.12' framework version otherwise extract framework version from the tag itself. - init_params["framework_version"] = ( - "0.12" if tag == "1.0" else framework_version_from_tag(tag) - ) - training_job_name = init_params["base_job_name"] if framework != cls.__framework_name__: diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index 47691ff2b7..6373a9dfad 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -22,7 +22,7 @@ create_image_uri, model_code_key_prefix, python_deprecation_warning, - empty_framework_version_warning, + validate_version_or_image_args, ) from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME from sagemaker.mxnet import defaults @@ -65,9 +65,9 @@ def __init__( model_data, role, entry_point, - image=None, - py_version="py2", framework_version=None, + py_version=None, + image=None, predictor_cls=MXNetPredictor, model_server_workers=None, **kwargs @@ -86,12 +86,18 @@ def __init__( file which should be executed as the entry point to model hosting. If ``source_dir`` is specified, then ``entry_point`` must point to a file located at the root of ``source_dir``. + framework_version (str): MXNet version you want to use for executing + your model training code. Defaults to ``None``. Required unless + ``image`` is provided. + py_version (str): Python version you want to use for executing your + model training code. Defaults to ``None``. Required unless + ``image`` is provided. image (str): A Docker image URI (default: None). If not specified, a default image for MXNet will be used. - py_version (str): Python version you want to use for executing your - model training code (default: 'py2'). - framework_version (str): MXNet version you want to use for executing - your model training code. + + If ``framework_version`` or ``py_version`` are ``None``, then + ``image`` is required. If also ``None``, then a ``ValueError`` + will be raised. predictor_cls (callable[str, sagemaker.session.Session]): A function to call to create a predictor with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the @@ -108,22 +114,18 @@ def __init__( :class:`~sagemaker.model.FrameworkModel` and :class:`~sagemaker.model.Model`. """ - super(MXNetModel, self).__init__( - model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs - ) - - if py_version == "py2": + validate_version_or_image_args(framework_version, py_version, image) + if py_version and py_version == "py2": logger.warning( python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION) ) + self.framework_version = framework_version + self.py_version = py_version - if framework_version is None: - logger.warning( - empty_framework_version_warning(defaults.MXNET_VERSION, defaults.LATEST_VERSION) - ) + super(MXNetModel, self).__init__( + model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs + ) - self.py_version = py_version - self.framework_version = framework_version or defaults.MXNET_VERSION self.model_server_workers = model_server_workers def prepare_container_def(self, instance_type, accelerator_type=None): diff --git a/tests/conftest.py b/tests/conftest.py index b46fbd035e..8011f612c4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -163,6 +163,11 @@ def mxnet_version(request): return request.param +@pytest.fixture(scope="module", params=["py2", "py3"]) +def mxnet_py_version(request): + return request.param + + @pytest.fixture(scope="module", params=["0.4", "0.4.0", "1.0", "1.0.0"]) def pytorch_version(request): return request.param diff --git a/tests/integ/test_local_mode.py b/tests/integ/test_local_mode.py index dd203810dc..7e448ff542 100644 --- a/tests/integ/test_local_mode.py +++ b/tests/integ/test_local_mode.py @@ -66,6 +66,7 @@ def _create_model(output_path): train_instance_type="local", output_path=output_path, framework_version=mxnet_full_version, + py_version=PYTHON_VERSION, sagemaker_session=sagemaker_local_session, ) @@ -188,6 +189,7 @@ def test_mxnet_local_data_local_script(mxnet_full_version): train_instance_count=1, train_instance_type="local", framework_version=mxnet_full_version, + py_version=PYTHON_VERSION, sagemaker_session=LocalNoS3Session(), ) @@ -242,6 +244,7 @@ def test_local_transform_mxnet( train_instance_count=1, train_instance_type="local", framework_version=mxnet_full_version, + py_version=PYTHON_VERSION, sagemaker_session=sagemaker_local_session, ) diff --git a/tests/integ/test_neo_mxnet.py b/tests/integ/test_neo_mxnet.py index 1786347554..e3778e892c 100644 --- a/tests/integ/test_neo_mxnet.py +++ b/tests/integ/test_neo_mxnet.py @@ -131,6 +131,7 @@ def test_inferentia_deploy_model( role, entry_point=script_path, framework_version=INF_MXNET_VERSION, + py_version=PYTHON_VERSION, sagemaker_session=sagemaker_session, ) diff --git a/tests/integ/test_transformer.py b/tests/integ/test_transformer.py index ef9ad0929c..d925de0af3 100644 --- a/tests/integ/test_transformer.py +++ b/tests/integ/test_transformer.py @@ -51,6 +51,7 @@ def mxnet_estimator(sagemaker_session, mxnet_full_version, cpu_instance_type): train_instance_type=cpu_instance_type, sagemaker_session=sagemaker_session, framework_version=mxnet_full_version, + py_version=PYTHON_VERSION, ) train_input = mx.sagemaker_session.upload_data( diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index d06a415eef..9a4fe32c38 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -1272,3 +1272,16 @@ def test_warn_if_parameter_server_with_multi_gpu(caplog): training_instance_type=train_instance_type, distributions=distributions ) assert fw_utils.PARAMETER_SERVER_MULTI_GPU_WARNING in caplog.text + + +def test_validate_version_or_image_args_not_raises(): + good_args = [("1.0", "py3", None), (None, "py3", "my:uri"), ("1.0", None, "my:uri")] + for framework_version, py_version, image_name in good_args: + fw_utils.validate_version_or_image_args(framework_version, py_version, image_name) + + +def test_validate_version_or_image_args_raises(): + bad_args = [(None, None, None), (None, "py3", None), ("1.0", None, None)] + for framework_version, py_version, image_name in bad_args: + with pytest.raises(ValueError): + fw_utils.validate_version_or_image_args(framework_version, py_version, image_name) diff --git a/tests/unit/test_multidatamodel.py b/tests/unit/test_multidatamodel.py index 84aed3fc58..320332987a 100644 --- a/tests/unit/test_multidatamodel.py +++ b/tests/unit/test_multidatamodel.py @@ -27,7 +27,11 @@ MXNET_MODEL_DATA = "s3://mybucket/mxnet_path/model.tar.gz" MXNET_MODEL_NAME = "dummy-mxnet-model" MXNET_ROLE = "DummyMXNetRole" -MXNET_IMAGE = "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.2-cpu-py2" +MXNET_FRAMEWORK_VERSION = "1.2" +MXNET_PY_VERSION = "py2" +MXNET_IMAGE = "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:{}-cpu-{}".format( + MXNET_FRAMEWORK_VERSION, MXNET_PY_VERSION +) DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") IMAGE = "123456789012.dkr.ecr.dummyregion.amazonaws.com/dummyimage:latest" @@ -100,8 +104,10 @@ def multi_data_model(sagemaker_session): def mxnet_model(sagemaker_session): return MXNetModel( MXNET_MODEL_DATA, - role=MXNET_ROLE, entry_point=ENTRY_POINT, + framework_version=MXNET_FRAMEWORK_VERSION, + py_version=MXNET_PY_VERSION, + role=MXNET_ROLE, sagemaker_session=sagemaker_session, name=MXNET_MODEL_NAME, enable_network_isolation=True, diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index e073462a3f..ecd9b3146d 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -37,12 +37,9 @@ INSTANCE_COUNT = 1 INSTANCE_TYPE = "ml.c4.4xlarge" ACCELERATOR_TYPE = "ml.eia.medium" -IMAGE_REPO_NAME = "sagemaker-mxnet" -IMAGE_REPO_SERVING_NAME = "sagemaker-mxnet-serving" -JOB_NAME = "{}-{}".format(IMAGE_REPO_NAME, TIMESTAMP) +IMAGE = "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.4.0-cpu-py3" COMPILATION_JOB_NAME = "{}-{}".format("compilation-sagemaker-mxnet", TIMESTAMP) FRAMEWORK = "mxnet" -FULL_IMAGE_URI = "520713654638.dkr.ecr.us-west-2.amazonaws.com/{}:{}-{}-{}" ROLE = "Dummy" REGION = "us-west-2" GPU = "ml.p2.xlarge" @@ -88,29 +85,25 @@ def sagemaker_session(): return session +def _is_mms_version(mxnet_version): + return parse_version(MXNetModel._LOWEST_MMS_VERSION) <= parse_version(mxnet_version) + + @pytest.fixture() def skip_if_mms_version(mxnet_version): - if parse_version(MXNetModel._LOWEST_MMS_VERSION) <= parse_version(mxnet_version): + if _is_mms_version(mxnet_version): pytest.skip("Skipping because this version uses MMS") @pytest.fixture() def skip_if_not_mms_version(mxnet_version): - if parse_version(MXNetModel._LOWEST_MMS_VERSION) > parse_version(mxnet_version): + if not _is_mms_version(mxnet_version): pytest.skip("Skipping because this version does not use MMS") -def _get_full_image_uri(version, repo=IMAGE_REPO_NAME, processor="cpu", py_version="py2"): - return FULL_IMAGE_URI.format(repo, version, processor, py_version) - - -def _get_full_image_uri_with_ei(version, repo=IMAGE_REPO_NAME, processor="cpu", py_version="py2"): - return FULL_IMAGE_URI.format("{}-eia".format(repo), version, processor, py_version) - - -def _create_train_job(version): +def _get_train_args(job_name): return { - "image": _get_full_image_uri(version), + "image": IMAGE, "input_mode": "File", "input_config": [ { @@ -124,7 +117,7 @@ def _create_train_job(version): } ], "role": ROLE, - "job_name": JOB_NAME, + "job_name": job_name, "output_config": {"S3OutputPath": "s3://{}/".format(BUCKET_NAME)}, "resource_config": { "InstanceType": "ml.c4.4xlarge", @@ -135,9 +128,9 @@ def _create_train_job(version): "sagemaker_program": json.dumps("dummy_script.py"), "sagemaker_enable_cloudwatch_metrics": "false", "sagemaker_container_log_level": str(logging.INFO), - "sagemaker_job_name": json.dumps(JOB_NAME), + "sagemaker_job_name": json.dumps(job_name), "sagemaker_submit_directory": json.dumps( - "s3://{}/{}/source/sourcedir.tar.gz".format(BUCKET_NAME, JOB_NAME) + "s3://{}/{}/source/sourcedir.tar.gz".format(BUCKET_NAME, job_name) ), "sagemaker_region": '"us-west-2"', }, @@ -153,6 +146,20 @@ def _create_train_job(version): } +def _get_environment(submit_directory, model_url, image_name): + return { + "Environment": { + "SAGEMAKER_SUBMIT_DIRECTORY": submit_directory, + "SAGEMAKER_PROGRAM": "dummy_script.py", + "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false", + "SAGEMAKER_REGION": "us-west-2", + "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", + }, + "Image": image_name, + "ModelDataUrl": model_url, + } + + def _create_compilation_job(input_shape, output_location): return { "input_model_config": { @@ -175,16 +182,17 @@ def _neo_inference_image(mxnet_version): @patch("sagemaker.utils.create_tar_file", MagicMock()) -def test_create_model(sagemaker_session, mxnet_version): +def test_create_model(sagemaker_session, mxnet_version, mxnet_py_version): container_log_level = '"logging.INFO"' source_dir = "s3://mybucket/source" mx = MXNet( entry_point=SCRIPT_PATH, + framework_version=mxnet_version, + py_version=mxnet_py_version, role=ROLE, sagemaker_session=sagemaker_session, train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - framework_version=mxnet_version, container_log_level=container_log_level, base_job_name="job", source_dir=source_dir, @@ -196,7 +204,7 @@ def test_create_model(sagemaker_session, mxnet_version): assert model.sagemaker_session == sagemaker_session assert model.framework_version == mxnet_version - assert model.py_version == mx.py_version + assert model.py_version == mxnet_py_version assert model.entry_point == SCRIPT_PATH assert model.role == ROLE assert model.name == job_name @@ -206,12 +214,14 @@ def test_create_model(sagemaker_session, mxnet_version): assert model.vpc_config is None -def test_create_model_with_optional_params(sagemaker_session): +def test_create_model_with_optional_params(sagemaker_session, mxnet_version, mxnet_py_version): container_log_level = '"logging.INFO"' source_dir = "s3://mybucket/source" enable_cloudwatch_metrics = "true" mx = MXNet( entry_point=SCRIPT_PATH, + framework_version=mxnet_version, + py_version=mxnet_py_version, role=ROLE, sagemaker_session=sagemaker_session, train_instance_count=INSTANCE_COUNT, @@ -251,6 +261,8 @@ def test_create_model_with_custom_image(sagemaker_session): custom_image = "mxnet:2.0" mx = MXNet( entry_point=SCRIPT_PATH, + framework_version="2.0", + py_version="py3", role=ROLE, sagemaker_session=sagemaker_session, train_instance_count=INSTANCE_COUNT, @@ -274,18 +286,30 @@ def test_create_model_with_custom_image(sagemaker_session): assert model.source_dir == source_dir -@patch("sagemaker.utils.create_tar_file", MagicMock()) +@patch("sagemaker.utils.create_tar_file") +@patch("sagemaker.utils.repack_model") @patch("time.strftime", return_value=TIMESTAMP) -def test_mxnet(strftime, sagemaker_session, mxnet_version, skip_if_mms_version): +@patch("sagemaker.mxnet.model.create_image_uri", return_value=IMAGE) +@patch("sagemaker.estimator.create_image_uri", return_value=IMAGE) +def test_mxnet( + train_image_uri, + model_image_uri, + strftime, + repack_model, + create_tar_file, + sagemaker_session, + mxnet_version, + mxnet_py_version, +): mx = MXNet( entry_point=SCRIPT_PATH, + framework_version=mxnet_version, + py_version=mxnet_py_version, role=ROLE, sagemaker_session=sagemaker_session, train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - framework_version=mxnet_version, ) - inputs = "s3://mybucket/train" mx.fit(inputs=inputs, experiment_config=EXPERIMENT_CONFIG) @@ -295,97 +319,41 @@ def test_mxnet(strftime, sagemaker_session, mxnet_version, skip_if_mms_version): boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] assert boto_call_names == ["resource"] - expected_train_args = _create_train_job(mxnet_version) + actual_train_args = sagemaker_session.method_calls[0][2] + job_name = actual_train_args["job_name"] + expected_train_args = _get_train_args(job_name) expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs expected_train_args["experiment_config"] = EXPERIMENT_CONFIG - actual_train_args = sagemaker_session.method_calls[0][2] assert actual_train_args == expected_train_args model = mx.create_model() - expected_image_base = "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:{}-gpu-py2" - environment = { - "Environment": { - "SAGEMAKER_SUBMIT_DIRECTORY": "s3://mybucket/sagemaker-mxnet-{}/source/sourcedir.tar.gz".format( - TIMESTAMP - ), - "SAGEMAKER_PROGRAM": "dummy_script.py", - "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false", - "SAGEMAKER_REGION": "us-west-2", - "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", - }, - "Image": expected_image_base.format(mxnet_version), - "ModelDataUrl": "s3://m/m.tar.gz", - } - assert environment == model.prepare_container_def(GPU) + actual_environment = model.prepare_container_def(GPU) + submit_directory = actual_environment["Environment"]["SAGEMAKER_SUBMIT_DIRECTORY"] + model_url = actual_environment["ModelDataUrl"] + expected_environment = _get_environment(submit_directory, model_url, IMAGE) + assert actual_environment == expected_environment assert "cpu" in model.prepare_container_def(CPU)["Image"] predictor = mx.deploy(1, GPU) assert isinstance(predictor, MXNetPredictor) + assert _is_mms_version(mxnet_version) ^ (create_tar_file.called and not repack_model.called) -@patch("sagemaker.utils.repack_model") +@patch("sagemaker.utils.create_tar_file", MagicMock()) @patch("time.strftime", return_value=TIMESTAMP) -def test_mxnet_mms_version( - strftime, repack_model, sagemaker_session, mxnet_version, skip_if_not_mms_version +def test_mxnet_neo( + strftime, sagemaker_session, mxnet_version, mxnet_py_version, skip_if_mms_version ): mx = MXNet( entry_point=SCRIPT_PATH, - role=ROLE, - sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, - train_instance_type=INSTANCE_TYPE, framework_version=mxnet_version, - ) - - inputs = "s3://mybucket/train" - - mx.fit(inputs=inputs) - - sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] - assert sagemaker_call_names == ["train", "logs_for_job"] - boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] - - expected_train_args = _create_train_job(mxnet_version) - expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs - - actual_train_args = sagemaker_session.method_calls[0][2] - assert actual_train_args == expected_train_args - - model = mx.create_model() - - expected_image_base = _get_full_image_uri(mxnet_version, IMAGE_REPO_SERVING_NAME, "gpu") - - environment = { - "Environment": { - "SAGEMAKER_SUBMIT_DIRECTORY": "s3://mybucket/sagemaker-mxnet-2017-11-06-14:14:15.672/model.tar.gz", - "SAGEMAKER_PROGRAM": "dummy_script.py", - "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false", - "SAGEMAKER_REGION": "us-west-2", - "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", - }, - "Image": expected_image_base.format(mxnet_version), - "ModelDataUrl": "s3://mybucket/sagemaker-mxnet-2017-11-06-14:14:15.672/model.tar.gz", - } - assert environment == model.prepare_container_def(GPU) - - assert "cpu" in model.prepare_container_def(CPU)["Image"] - predictor = mx.deploy(1, GPU) - assert isinstance(predictor, MXNetPredictor) - - -@patch("sagemaker.utils.create_tar_file", MagicMock()) -@patch("time.strftime", return_value=TIMESTAMP) -def test_mxnet_neo(strftime, sagemaker_session, mxnet_version, skip_if_mms_version): - mx = MXNet( - entry_point=SCRIPT_PATH, + py_version=mxnet_py_version, role=ROLE, sagemaker_session=sagemaker_session, train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - framework_version=mxnet_version, ) inputs = "s3://mybucket/train" @@ -426,22 +394,30 @@ def test_mxnet_neo(strftime, sagemaker_session, mxnet_version, skip_if_mms_versi @patch("sagemaker.utils.create_tar_file", MagicMock()) -def test_model(sagemaker_session): +def test_model(sagemaker_session, mxnet_version, mxnet_py_version, skip_if_mms_version): model = MXNetModel( - MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, sagemaker_session=sagemaker_session + MODEL_DATA, + role=ROLE, + entry_point=SCRIPT_PATH, + framework_version=mxnet_version, + py_version=mxnet_py_version, + sagemaker_session=sagemaker_session, ) predictor = model.deploy(1, GPU) assert isinstance(predictor, MXNetPredictor) @patch("sagemaker.utils.repack_model") -def test_model_mms_version(repack_model, sagemaker_session): +def test_model_mms_version( + repack_model, sagemaker_session, mxnet_version, mxnet_py_version, skip_if_not_mms_version +): model_kms_key = "kms-key" model = MXNetModel( MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, - framework_version=MXNetModel._LOWEST_MMS_VERSION, + framework_version=mxnet_version, + py_version=mxnet_py_version, sagemaker_session=sagemaker_session, name="test-mxnet-model", model_kms_key=model_kms_key, @@ -467,45 +443,33 @@ def test_model_mms_version(repack_model, sagemaker_session): assert isinstance(predictor, MXNetPredictor) -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) -def test_model_image_accelerator(sagemaker_session): - model = MXNetModel( - MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, sagemaker_session=sagemaker_session - ) - container_def = model.prepare_container_def(INSTANCE_TYPE, accelerator_type=ACCELERATOR_TYPE) - assert container_def["Image"] == _get_full_image_uri_with_ei(defaults.MXNET_VERSION) - - -@patch("sagemaker.utils.repack_model", MagicMock()) -def test_model_image_accelerator_mms_version(sagemaker_session): +@patch("sagemaker.fw_utils.tar_and_upload_dir") +@patch("sagemaker.utils.repack_model") +@patch("sagemaker.mxnet.model.create_image_uri", return_value=IMAGE) +def test_model_image_accelerator( + create_image_uri, + repack_model, + tar_and_upload, + sagemaker_session, + mxnet_version, + mxnet_py_version, +): model = MXNetModel( MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, - framework_version=MXNetModel._LOWEST_MMS_VERSION, + framework_version=mxnet_version, + py_version=mxnet_py_version, sagemaker_session=sagemaker_session, ) container_def = model.prepare_container_def(INSTANCE_TYPE, accelerator_type=ACCELERATOR_TYPE) - assert container_def["Image"] == _get_full_image_uri_with_ei( - MXNetModel._LOWEST_MMS_VERSION, IMAGE_REPO_SERVING_NAME - ) + assert container_def["Image"] == IMAGE + assert _is_mms_version(mxnet_version) ^ (tar_and_upload.called and not repack_model.called) -def test_train_image_default(sagemaker_session): - mx = MXNet( - entry_point=SCRIPT_PATH, - role=ROLE, - sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, - train_instance_type=INSTANCE_TYPE, - ) - - assert _get_full_image_uri(defaults.MXNET_VERSION) in mx.train_image() - - -def test_attach(sagemaker_session, mxnet_version): - training_image = "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-py2-cpu:{}-cpu-py2".format( - mxnet_version +def test_attach(sagemaker_session, mxnet_version, mxnet_py_version): + training_image = "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-{0}-cpu:{1}-cpu-{0}".format( + mxnet_py_version, mxnet_version ) returned_job_description = { "AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image}, @@ -538,7 +502,7 @@ def test_attach(sagemaker_session, mxnet_version): estimator = MXNet.attach(training_job_name="neo", sagemaker_session=sagemaker_session) assert estimator.latest_training_job.job_name == "neo" - assert estimator.py_version == "py2" + assert estimator.py_version == mxnet_py_version assert estimator.framework_version == mxnet_version assert estimator.role == "arn:aws:iam::366:role/SageMakerRole" assert estimator.train_instance_count == 1 @@ -679,12 +643,13 @@ def test_attach_custom_image(sagemaker_session): def test_estimator_script_mode_launch_parameter_server(warning, sagemaker_session): mx = MXNet( entry_point=SCRIPT_PATH, + framework_version="1.3.0", + py_version="py2", role=ROLE, sagemaker_session=sagemaker_session, train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, distributions=LAUNCH_PS_DISTRIBUTIONS_DICT, - framework_version="1.3.0", ) assert mx.hyperparameters().get(MXNet.LAUNCH_PS_ENV_NAME) == "true" warning.assert_called_with("distributions", "distribution") @@ -693,12 +658,13 @@ def test_estimator_script_mode_launch_parameter_server(warning, sagemaker_sessio def test_estimator_script_mode_dont_launch_parameter_server(sagemaker_session): mx = MXNet( entry_point=SCRIPT_PATH, + framework_version="1.3.0", + py_version="py2", role=ROLE, sagemaker_session=sagemaker_session, train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, distributions={"parameter_server": {"enabled": False}}, - framework_version="1.3.0", ) assert mx.hyperparameters().get(MXNet.LAUNCH_PS_ENV_NAME) == "false" @@ -707,12 +673,13 @@ def test_estimator_wrong_version_launch_parameter_server(sagemaker_session): with pytest.raises(ValueError) as e: MXNet( entry_point=SCRIPT_PATH, + framework_version="1.2.1", + py_version="py2", role=ROLE, sagemaker_session=sagemaker_session, train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, distributions=LAUNCH_PS_DISTRIBUTIONS_DICT, - framework_version="1.2.1", ) assert "The distributions option is valid for only versions 1.3 and higher" in str(e) @@ -721,11 +688,12 @@ def test_estimator_wrong_version_launch_parameter_server(sagemaker_session): def test_estimator_py2_warning(warning, sagemaker_session): estimator = MXNet( entry_point=SCRIPT_PATH, + framework_version="1.2.1", + py_version="py2", role=ROLE, sagemaker_session=sagemaker_session, train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - py_version="py2", ) assert estimator.py_version == "py2" @@ -738,41 +706,14 @@ def test_model_py2_warning(warning, sagemaker_session): MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, - sagemaker_session=sagemaker_session, + framework_version="1.2.1", py_version="py2", + sagemaker_session=sagemaker_session, ) assert model.py_version == "py2" warning.assert_called_with(model.__framework_name__, defaults.LATEST_PY2_VERSION) -@patch("sagemaker.mxnet.estimator.empty_framework_version_warning") -def test_empty_framework_version(warning, sagemaker_session): - mx = MXNet( - entry_point=SCRIPT_PATH, - role=ROLE, - sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, - train_instance_type=INSTANCE_TYPE, - framework_version=None, - ) - - assert mx.framework_version == defaults.MXNET_VERSION - warning.assert_called_with(defaults.MXNET_VERSION, mx.LATEST_VERSION) - - -@patch("sagemaker.mxnet.model.empty_framework_version_warning") -def test_model_empty_framework_version(warning, sagemaker_session): - model = MXNetModel( - MODEL_DATA, - role=ROLE, - entry_point=SCRIPT_PATH, - sagemaker_session=sagemaker_session, - framework_version=None, - ) - assert model.framework_version == defaults.MXNET_VERSION - warning.assert_called_with(defaults.MXNET_VERSION, defaults.LATEST_VERSION) - - def test_create_model_with_custom_hosting_image(sagemaker_session): container_log_level = '"logging.INFO"' source_dir = "s3://mybucket/source" @@ -780,6 +721,8 @@ def test_create_model_with_custom_hosting_image(sagemaker_session): custom_hosting_image = "mxnet_hosting:2.0" mx = MXNet( entry_point=SCRIPT_PATH, + framework_version="2.0", + py_version="py3", role=ROLE, sagemaker_session=sagemaker_session, train_instance_count=INSTANCE_COUNT, @@ -796,9 +739,11 @@ def test_create_model_with_custom_hosting_image(sagemaker_session): assert model.image == custom_hosting_image -def test_mx_enable_sm_metrics(sagemaker_session): +def test_mx_enable_sm_metrics(sagemaker_session, mxnet_version, mxnet_py_version): mx = MXNet( entry_point=SCRIPT_PATH, + framework_version=mxnet_version, + py_version=mxnet_py_version, role=ROLE, sagemaker_session=sagemaker_session, train_instance_count=INSTANCE_COUNT, @@ -808,9 +753,11 @@ def test_mx_enable_sm_metrics(sagemaker_session): assert mx.enable_sagemaker_metrics -def test_mx_disable_sm_metrics(sagemaker_session): +def test_mx_disable_sm_metrics(sagemaker_session, mxnet_version, mxnet_py_version): mx = MXNet( entry_point=SCRIPT_PATH, + framework_version=mxnet_version, + py_version=mxnet_py_version, role=ROLE, sagemaker_session=sagemaker_session, train_instance_count=INSTANCE_COUNT, @@ -820,36 +767,30 @@ def test_mx_disable_sm_metrics(sagemaker_session): assert not mx.enable_sagemaker_metrics -def test_mx_disable_sm_metrics_if_pt_ver_is_less_than_1_6(sagemaker_session): - for fw_version in ["1.1", "1.2", "1.3", "1.4", "1.5"]: - mx = MXNet( - entry_point=SCRIPT_PATH, - role=ROLE, - sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, - train_instance_type=INSTANCE_TYPE, - framework_version=fw_version, - ) - assert mx.enable_sagemaker_metrics is None - - -def test_mx_enable_sm_metrics_if_fw_ver_is_at_least_1_6(sagemaker_session): - for fw_version in ["1.6", "1.7", "2.0", "2.1"]: - mx = MXNet( - entry_point=SCRIPT_PATH, - role=ROLE, - sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, - train_instance_type=INSTANCE_TYPE, - framework_version=fw_version, - ) +def test_mx_enable_sm_metrics_for_version(sagemaker_session, mxnet_version, mxnet_py_version): + mx = MXNet( + entry_point=SCRIPT_PATH, + framework_version=mxnet_version, + py_version=mxnet_py_version, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + ) + version = tuple(int(s) for s in mxnet_version.split(".")) + lowest_version = (1, 6, 0)[: len(version)] + if version >= lowest_version: assert mx.enable_sagemaker_metrics + else: + assert mx.enable_sagemaker_metrics is None -def test_custom_image_estimator_deploy(sagemaker_session): +def test_custom_image_estimator_deploy(sagemaker_session, mxnet_version, mxnet_py_version): custom_image = "mycustomimage:latest" mx = MXNet( entry_point=SCRIPT_PATH, + framework_version=mxnet_version, + py_version=mxnet_py_version, role=ROLE, sagemaker_session=sagemaker_session, train_instance_count=INSTANCE_COUNT, diff --git a/tests/unit/test_tuner.py b/tests/unit/test_tuner.py index 490ae96145..999d69339e 100644 --- a/tests/unit/test_tuner.py +++ b/tests/unit/test_tuner.py @@ -257,8 +257,9 @@ def test_s3_input_mode(sagemaker_session, tuner): script_path = os.path.join(DATA_DIR, "mxnet_mnist", "failure_script.py") mxnet = MXNet( entry_point=script_path, - role=ROLE, framework_version=FRAMEWORK_VERSION, + py_version=PY_VERSION, + role=ROLE, train_instance_count=TRAIN_INSTANCE_COUNT, train_instance_type=TRAIN_INSTANCE_TYPE, sagemaker_session=sagemaker_session, @@ -423,8 +424,9 @@ def _create_multi_estimator_tuner(sagemaker_session): mxnet_script_path = os.path.join(DATA_DIR, "mxnet_mnist", "failure_script.py") mxnet = MXNet( entry_point=mxnet_script_path, - role=ROLE, framework_version=FRAMEWORK_VERSION, + py_version=PY_VERSION, + role=ROLE, train_instance_count=TRAIN_INSTANCE_COUNT, train_instance_type=TRAIN_INSTANCE_TYPE, sagemaker_session=sagemaker_session, @@ -664,8 +666,9 @@ def test_analytics(tuner): def test_serialize_categorical_ranges_for_frameworks(sagemaker_session, tuner): tuner.estimator = MXNet( entry_point=SCRIPT_NAME, - role=ROLE, framework_version=FRAMEWORK_VERSION, + py_version=PY_VERSION, + role=ROLE, train_instance_count=TRAIN_INSTANCE_COUNT, train_instance_type=TRAIN_INSTANCE_TYPE, sagemaker_session=sagemaker_session, @@ -915,8 +918,9 @@ def test_fit_no_inputs(tuner, sagemaker_session): script_path = os.path.join(DATA_DIR, "mxnet_mnist", "failure_script.py") tuner.estimator = MXNet( entry_point=script_path, - role=ROLE, framework_version=FRAMEWORK_VERSION, + py_version=PY_VERSION, + role=ROLE, train_instance_count=TRAIN_INSTANCE_COUNT, train_instance_type=TRAIN_INSTANCE_TYPE, sagemaker_session=sagemaker_session, diff --git a/tests/unit/tuner_test_utils.py b/tests/unit/tuner_test_utils.py index fb8d7acd11..084cf5edfc 100644 --- a/tests/unit/tuner_test_utils.py +++ b/tests/unit/tuner_test_utils.py @@ -37,6 +37,7 @@ SCRIPT_NAME = "my_script.py" FRAMEWORK_VERSION = "1.0.0" +PY_VERSION = "py3" INPUTS = "s3://mybucket/train" diff --git a/tox.ini b/tox.ini index 2c5a8f60c6..3495d1e444 100644 --- a/tox.ini +++ b/tox.ini @@ -66,6 +66,10 @@ commands = {env:IGNORE_COVERAGE:} coverage report --fail-under=86 extras = test +[testenv:py27] +setenv = + IGNORE_COVERAGE = true + [testenv:flake8] basepython = python3 skipdist = true