From 6e19e89b01a455532d77b099225da9392dcfb0ee Mon Sep 17 00:00:00 2001 From: Chuyang Deng Date: Fri, 8 May 2020 16:01:58 -0700 Subject: [PATCH 01/14] fix: add v2 warning messages --- src/sagemaker/__init__.py | 7 +++++++ src/sagemaker/amazon/amazon_estimator.py | 5 +++++ src/sagemaker/amazon/kmeans.py | 11 +++++++++++ src/sagemaker/amazon/randomcutforest.py | 11 +++++++++++ src/sagemaker/estimator.py | 7 +++++++ src/sagemaker/fw_utils.py | 25 +++++++++++++++++++++++- src/sagemaker/inputs.py | 7 +++++++ src/sagemaker/model.py | 4 ++++ src/sagemaker/mxnet/estimator.py | 2 ++ src/sagemaker/s3.py | 21 ++++++++++++++++++++ src/sagemaker/session.py | 15 ++++++++++++++ src/sagemaker/tensorflow/estimator.py | 5 ++++- tests/unit/test_amazon_estimator.py | 9 +++++++++ tests/unit/test_inputs.py | 7 ++++++- tests/unit/test_kmeans.py | 8 +++++++- tests/unit/test_mxnet.py | 9 +++++++-- tests/unit/test_randomcutforest.py | 7 ++++++- tests/unit/test_s3.py | 6 +++++- 18 files changed, 158 insertions(+), 8 deletions(-) diff --git a/src/sagemaker/__init__.py b/src/sagemaker/__init__.py index 6b35f5abb9..8714690762 100644 --- a/src/sagemaker/__init__.py +++ b/src/sagemaker/__init__.py @@ -13,6 +13,7 @@ """Placeholder docstring""" from __future__ import absolute_import +import logging import importlib_metadata from sagemaker import estimator, parameter, tuner # noqa: F401 @@ -61,3 +62,9 @@ from sagemaker.automl.candidate_estimator import CandidateEstimator, CandidateStep # noqa: F401 __version__ = importlib_metadata.version("sagemaker") + +logging.getLogger("sagemaker").warning( + "SageMaker Python SDK v2 will no longer support Python 2. 
" + "Please see https://github.com/aws/sagemaker-python-sdk/issues/1459 " + "for more information" +) diff --git a/src/sagemaker/amazon/amazon_estimator.py b/src/sagemaker/amazon/amazon_estimator.py index 17f17f1a13..bf76d0a94d 100644 --- a/src/sagemaker/amazon/amazon_estimator.py +++ b/src/sagemaker/amazon/amazon_estimator.py @@ -616,6 +616,11 @@ def get_image_uri(region_name, repo_name, repo_version=1): repo_name: repo_version: """ + logger.warning( + "'get_image_uri' method will be deprecated in favor of 'ImageURIProvider' class " + "in SageMaker Python SDK v2." + ) + if repo_name == "xgboost": if not _is_latest_xgboost_version(repo_version): logging.warning( diff --git a/src/sagemaker/amazon/kmeans.py b/src/sagemaker/amazon/kmeans.py index d6b4ddda20..35fa236378 100644 --- a/src/sagemaker/amazon/kmeans.py +++ b/src/sagemaker/amazon/kmeans.py @@ -13,6 +13,8 @@ """Placeholder docstring""" from __future__ import absolute_import +import logging + from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa @@ -23,6 +25,9 @@ from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +logger = logging.getLogger("sagemaker") + + class KMeans(AmazonAlgorithmEstimatorBase): """Placeholder docstring""" @@ -154,6 +159,12 @@ def __init__( self.center_factor = center_factor self.eval_metrics = eval_metrics + if eval_metrics is not None: + logger.warning( + "Parameter 'eval_metrics' hyperparameter will be deprecated for 1P estimators " + "in SageMaker Python SDK v2." + ) + def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs): """Return a :class:`~sagemaker.amazon.kmeans.KMeansModel` referencing the latest s3 model data produced by this Estimator. 
diff --git a/src/sagemaker/amazon/randomcutforest.py b/src/sagemaker/amazon/randomcutforest.py index 8e188c95ae..f6a4e3c5c2 100644 --- a/src/sagemaker/amazon/randomcutforest.py +++ b/src/sagemaker/amazon/randomcutforest.py @@ -13,6 +13,8 @@ """Placeholder docstring""" from __future__ import absolute_import +import logging + from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa @@ -23,6 +25,9 @@ from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +logger = logging.getLogger("sagemaker") + + class RandomCutForest(AmazonAlgorithmEstimatorBase): """Placeholder docstring""" @@ -119,6 +124,12 @@ def __init__( self.num_trees = num_trees self.eval_metrics = eval_metrics + if eval_metrics is not None: + logger.warning( + "Parameter 'eval_metrics' hyperparameter will be deprecated for 1P estimators " + "in SageMaker Python SDK v2." + ) + def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs): """Return a :class:`~sagemaker.amazon.RandomCutForestModel` referencing the latest s3 model data produced by this Estimator. diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index ddd4b6cc60..292bb40cbe 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -1273,6 +1273,9 @@ def __init__( https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries (default: ``None``). """ + warnings.warn( + "Parameter 'image_name' will be renamed to 'image_uri' in SageMaker Python SDK v2." 
+ ) self.image_name = image_name self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {} super(Estimator, self).__init__( @@ -1635,6 +1638,10 @@ def __init__( self.container_log_level = container_log_level self.code_location = code_location self.image_name = image_name + if image_name is not None: + warnings.warn( + "Parameter 'image_name' will be renamed to 'image_uri' in SageMaker Python SDK v2." + ) self.uploaded_code = None diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 645a28f8f1..855fde8fb8 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -34,7 +34,10 @@ instantiated with positional or keyword arguments. """ -EMPTY_FRAMEWORK_VERSION_WARNING = "No framework_version specified, defaulting to version {}." +EMPTY_FRAMEWORK_VERSION_WARNING = ( + "No framework_version specified, defaulting to version {}. " + "framework_version will be required in SageMaker Python SDK v2." +) LATER_FRAMEWORK_VERSION_WARNING = ( "This is not the latest supported version. " "If you would like to use version {latest}, " @@ -52,6 +55,10 @@ "fully leverage all GPU cores; the parameter server will be configured to run " "only one worker per host regardless of the number of GPUs." ) +PARAMETER_V2_RENAME_WARNING = ( + "Parameter {v1_parameter_name} will be renamed to {v2_parameter_name} " + "in SageMaker Python SDK v2." +) EMPTY_FRAMEWORK_VERSION_ERROR = ( @@ -253,6 +260,11 @@ def create_image_uri( Returns: str: The appropriate image URI based on the given parameters. """ + logger.warning( + "'create_image_uri' will be deprecated in favor of 'ImageURIProvider' class " + "in SageMaker Python SDK v2." 
+ ) + optimized_families = optimized_families or [] if py_version and py_version not in VALID_PY_VERSIONS: @@ -647,6 +659,17 @@ def python_deprecation_warning(framework, latest_supported_version): ) +def parameter_v2_rename_warning(v1_parameter_name, v2_parameter_name): + """ + Args: + v1_parameter_name: + v2_parameter_name: + """ + return PARAMETER_V2_RENAME_WARNING.format( + v1_parameter_name=v1_parameter_name, v2_parameter_name=v2_parameter_name + ) + + def _region_supports_debugger(region_name): """Returns boolean indicating whether the region supports Amazon SageMaker Debugger. diff --git a/src/sagemaker/inputs.py b/src/sagemaker/inputs.py index c6b0659c11..63dccb1555 100644 --- a/src/sagemaker/inputs.py +++ b/src/sagemaker/inputs.py @@ -13,9 +13,13 @@ """Amazon SageMaker channel configurations for S3 data sources and file system data sources""" from __future__ import absolute_import, print_function +import logging + FILE_SYSTEM_TYPES = ["FSxLustre", "EFS"] FILE_SYSTEM_ACCESS_MODES = ["ro", "rw"] +logger = logging.getLogger("sagemaker") + class s3_input(object): """Amazon SageMaker channel configurations for S3 data sources. @@ -76,6 +80,9 @@ def __init__( this channel. See the SageMaker API documentation for more info: https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html """ + logger.warning( + "'s3_input' class will be renamed to 'TrainingInput' in SageMaker Python SDK v2." + ) self.config = { "DataSource": {"S3DataSource": {"S3DataType": s3_data_type, "S3Uri": s3_data}} diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 2e583e9bcb..d10396769e 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -108,6 +108,10 @@ def __init__( model_kms_key (str): KMS key ARN used to encrypt the repacked model archive file if the model is repacked """ + LOGGER.warning( + "Parameter 'image' will be renamed to 'image_uri' in SageMaker Python SDK v2." 
+ ) + + self.model_data = model_data self.image = image self.role = role diff --git a/src/sagemaker/mxnet/estimator.py b/src/sagemaker/mxnet/estimator.py index 28d1408ec4..1ab3fe2e27 100644 --- a/src/sagemaker/mxnet/estimator.py +++ b/src/sagemaker/mxnet/estimator.py @@ -21,6 +21,7 @@ framework_version_from_tag, empty_framework_version_warning, python_deprecation_warning, + parameter_v2_rename_warning, is_version_equal_or_higher, warn_if_parameter_server_with_multi_gpu, ) @@ -128,6 +129,7 @@ def __init__( ) if distributions is not None: + logger.warning(parameter_v2_rename_warning("distributions", "distribution")) train_instance_type = kwargs.get("train_instance_type") warn_if_parameter_server_with_multi_gpu( training_instance_type=train_instance_type, distributions=distributions diff --git a/src/sagemaker/s3.py b/src/sagemaker/s3.py index d81710c412..88b4f1a410 100644 --- a/src/sagemaker/s3.py +++ b/src/sagemaker/s3.py @@ -13,11 +13,27 @@ """This module contains Enums and helper methods related to S3.""" from __future__ import print_function, absolute_import +import logging import os from six.moves.urllib.parse import urlparse from sagemaker.session import Session +logger = logging.getLogger("sagemaker") + +SESSION_V2_RENAME_MESSAGE = ( + "Parameter 'session' will be renamed to 'sagemaker_session' in SageMaker Python SDK v2." +) + + +def _session_v2_rename_warning(session): + """ + Args: + session (sagemaker.session.Session): + """ + if session is not None: + logger.warning(SESSION_V2_RENAME_MESSAGE) + def parse_s3_url(url): """Returns an (s3 bucket, key name/prefix) tuple from a url with an s3 @@ -54,6 +70,7 @@ def upload(local_path, desired_s3_uri, kms_key=None, session=None): The S3 uri of the uploaded file(s).
""" + _session_v2_rename_warning(session) sagemaker_session = session or Session() bucket, key_prefix = parse_s3_url(url=desired_s3_uri) if kms_key is not None: @@ -80,6 +97,7 @@ def upload_string_as_file_body(body, desired_s3_uri=None, kms_key=None, session= str: The S3 uri of the uploaded file(s). """ + _session_v2_rename_warning(session) sagemaker_session = session or Session() bucket, key = parse_s3_url(desired_s3_uri) @@ -107,6 +125,7 @@ def download(s3_uri, local_path, kms_key=None, session=None): using the default AWS configuration chain. """ + _session_v2_rename_warning(session) sagemaker_session = session or Session() bucket, key_prefix = parse_s3_url(url=s3_uri) if kms_key is not None: @@ -131,6 +150,7 @@ def read_file(s3_uri, session=None): str: The body of the file. """ + _session_v2_rename_warning(session) sagemaker_session = session or Session() bucket, key_prefix = parse_s3_url(url=s3_uri) @@ -149,6 +169,7 @@ def list(s3_uri, session=None): [str]: The list of S3 URIs in the given S3 base uri. """ + _session_v2_rename_warning(session) sagemaker_session = session or Session() bucket, key_prefix = parse_s3_url(url=s3_uri) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 0dc5bf3161..f03e40e97b 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -181,6 +181,11 @@ def upload_data(self, path, bucket=None, key_prefix="data", extra_args=None): ``s3://{bucket name}/{key_prefix}``. """ # Generate a tuple for each file that we want to upload of the form (local_path, s3_key). + LOGGER.warning( + "'upload_data' method will be deprecated in favor of 'S3Uploader' class " + "in SageMaker Python SDK v2." + ) + files = [] key_suffix = None if os.path.isdir(path): @@ -230,6 +235,11 @@ def upload_string_as_file_body(self, body, bucket, key, kms_key=None): str: The S3 URI of the uploaded file. The URI format is: ``s3://{bucket name}/{key}``.
""" + LOGGER.warning( + "'upload_string_as_file_body' method will be deprecated in favor of 'S3Uploader' class " + "in SageMaker Python SDK v2." + ) + if self.s3_resource is None: s3 = self.boto_session.resource("s3", region_name=self.boto_region_name) else: @@ -3311,6 +3321,11 @@ def get_execution_role(sagemaker_session=None): Returns: (str): The role ARN """ + LOGGER.warning( + "'get_execution_role' will be renamed to 'notebook_execution_role' " + "in SageMaker Python SDK v2." + ) + if not sagemaker_session: sagemaker_session = Session() arn = sagemaker_session.get_caller_identity_arn() diff --git a/src/sagemaker/tensorflow/estimator.py b/src/sagemaker/tensorflow/estimator.py index 47440ba21d..00b779b573 100644 --- a/src/sagemaker/tensorflow/estimator.py +++ b/src/sagemaker/tensorflow/estimator.py @@ -311,6 +311,7 @@ def __init__( ) if distributions is not None: + logger.warning(fw.parameter_v2_rename_warning("distributions", "distribution")) train_instance_type = kwargs.get("train_instance_type") fw.warn_if_parameter_server_with_multi_gpu( training_instance_type=train_instance_type, distributions=distributions @@ -385,7 +386,9 @@ def _validate_args( if (not self._script_mode_enabled()) and self._only_script_mode_supported(): logger.warning( - "Legacy mode is deprecated in versions 1.13 and higher. Using script mode instead." + "Legacy mode is deprecated in versions 1.13 and higher. Using script mode instead. " + "Legacy mode and its training parameters will be deprecated in " + "SageMaker Python SDK v2. Please use TF 1.13 or higher and script mode."
) self.script_mode = True diff --git a/tests/unit/test_amazon_estimator.py b/tests/unit/test_amazon_estimator.py index 1bad30dcb2..439e51067d 100644 --- a/tests/unit/test_amazon_estimator.py +++ b/tests/unit/test_amazon_estimator.py @@ -486,3 +486,12 @@ def test_is_latest_xgboost_version(): assert _is_latest_xgboost_version("0.90-1-cpu-py3") is False assert _is_latest_xgboost_version(XGBOOST_LATEST_VERSION) is True + + +def test_get_image_uri_warn(caplog): + warning_message = ( + "'get_image_uri' method will be deprecated in favor of 'ImageURIProvider' class " + "in SageMaker Python SDK v2." + ) + get_image_uri("us-west-2", "kmeans", "latest") + assert warning_message in caplog.text diff --git a/tests/unit/test_inputs.py b/tests/unit/test_inputs.py index cd68501396..a4ae8a0da7 100644 --- a/tests/unit/test_inputs.py +++ b/tests/unit/test_inputs.py @@ -18,7 +18,7 @@ from sagemaker.inputs import FileSystemInput -def test_s3_input_all_defaults(): +def test_s3_input_all_defaults(caplog): prefix = "pre" actual = s3_input(s3_data=prefix) expected = { @@ -32,6 +32,11 @@ def test_s3_input_all_defaults(): } assert actual.config == expected + warning_message = ( + "'s3_input' class will be renamed to 'TrainingInput' in SageMaker Python SDK v2." 
+ ) + assert warning_message in caplog.text + def test_s3_input_all_arguments(): prefix = "pre" diff --git a/tests/unit/test_kmeans.py b/tests/unit/test_kmeans.py index 555b78b451..0013e4147d 100644 --- a/tests/unit/test_kmeans.py +++ b/tests/unit/test_kmeans.py @@ -82,7 +82,7 @@ def test_init_required_named(sagemaker_session): assert kmeans.k == ALL_REQ_ARGS["k"] -def test_all_hyperparameters(sagemaker_session): +def test_all_hyperparameters(sagemaker_session, caplog): kmeans = KMeans( sagemaker_session=sagemaker_session, init_method="random", @@ -110,6 +110,12 @@ def test_all_hyperparameters(sagemaker_session): force_dense="True", ) + warning_message = ( + "Parameter 'eval_metrics' hyperparameter will be deprecated for 1P estimators " + "in SageMaker Python SDK v2." + ) + assert warning_message in caplog.text + def test_image(sagemaker_session): kmeans = KMeans(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index 2736bc6736..84f8136389 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -675,7 +675,8 @@ def test_attach_custom_image(sagemaker_session): assert estimator.train_image() == training_image -def test_estimator_script_mode_launch_parameter_server(sagemaker_session): +@patch("sagemaker.mxnet.estimator.parameter_v2_rename_warning") +def test_estimator_script_mode_launch_parameter_server(warning, sagemaker_session): mx = MXNet( entry_point=SCRIPT_PATH, role=ROLE, @@ -686,6 +687,7 @@ def test_estimator_script_mode_launch_parameter_server(sagemaker_session): framework_version="1.3.0", ) assert mx.hyperparameters().get(MXNet.LAUNCH_PS_ENV_NAME) == "true" + warning.assert_called_with("distributions", "distribution") def test_estimator_script_mode_dont_launch_parameter_server(sagemaker_session): @@ -844,7 +846,7 @@ def test_mx_enable_sm_metrics_if_fw_ver_is_at_least_1_6(sagemaker_session): assert mx.enable_sagemaker_metrics -def 
test_custom_image_estimator_deploy(sagemaker_session): +def test_custom_image_estimator_deploy(sagemaker_session, caplog): custom_image = "mycustomimage:latest" mx = MXNet( entry_point=SCRIPT_PATH, @@ -856,3 +858,6 @@ def test_custom_image_estimator_deploy(sagemaker_session): mx.fit(inputs="s3://mybucket/train", job_name="new_name") model = mx.create_model(image=custom_image) assert model.image == custom_image + + warning_message = "Parameter 'image' will be renamed to 'image_uri' in SageMaker Python SDK v2." + assert warning_message in caplog.text diff --git a/tests/unit/test_randomcutforest.py b/tests/unit/test_randomcutforest.py index d960e45f46..9ab9d5f603 100644 --- a/tests/unit/test_randomcutforest.py +++ b/tests/unit/test_randomcutforest.py @@ -89,7 +89,7 @@ def test_init_required_named(sagemaker_session): assert randomcutforest.train_instance_type == COMMON_TRAIN_ARGS["train_instance_type"] -def test_all_hyperparameters(sagemaker_session): +def test_all_hyperparameters(sagemaker_session, caplog): randomcutforest = RandomCutForest( sagemaker_session=sagemaker_session, num_trees=NUM_TREES, @@ -102,6 +102,11 @@ def test_all_hyperparameters(sagemaker_session): num_trees=str(NUM_TREES), eval_metrics='["accuracy", "precision_recall_fscore"]', ) + warning_message = ( + "Parameter 'eval_metrics' hyperparameter will be deprecated for 1P estimators " + "in SageMaker Python SDK v2." 
+ ) + assert warning_message in caplog.text def test_image(sagemaker_session): diff --git a/tests/unit/test_s3.py b/tests/unit/test_s3.py index c073417116..12238e183b 100644 --- a/tests/unit/test_s3.py +++ b/tests/unit/test_s3.py @@ -40,7 +40,7 @@ def sagemaker_session(): return session_mock -def test_upload(sagemaker_session): +def test_upload(sagemaker_session, caplog): desired_s3_uri = os.path.join("s3://", BUCKET_NAME, CURRENT_JOB_NAME, SOURCE_NAME) S3Uploader.upload( local_path="/path/to/app.jar", desired_s3_uri=desired_s3_uri, session=sagemaker_session @@ -51,6 +51,10 @@ def test_upload(sagemaker_session): key_prefix=os.path.join(CURRENT_JOB_NAME, SOURCE_NAME), extra_args=None, ) + warning_message = ( + "Parameter 'session' will be renamed to 'sagemaker_session' " "in SageMaker Python SDK v2." + ) + assert warning_message in caplog.text def test_upload_with_kms_key(sagemaker_session): From ed1ee9555b4d876fe504bced472304c55deb2eb2 Mon Sep 17 00:00:00 2001 From: Chuyang Deng Date: Tue, 12 May 2020 01:13:25 -0700 Subject: [PATCH 02/14] fix flake8 error --- tests/unit/test_amazon_estimator.py | 8 ++++---- tests/unit/test_fw_utils.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/unit/test_amazon_estimator.py b/tests/unit/test_amazon_estimator.py index 439e51067d..1c0a2fbd10 100644 --- a/tests/unit/test_amazon_estimator.py +++ b/tests/unit/test_amazon_estimator.py @@ -228,16 +228,16 @@ def test_fit_ndarray(time, sagemaker_session): labels = [99, 85, 87, 2] pca.fit(pca.record_set(np.array(train), np.array(labels))) mock_s3.Object.assert_any_call( - BUCKET_NAME, "key-prefix/PCA-2017-11-06-14:14:15.671/matrix_0.pbr".format(TIMESTAMP) + BUCKET_NAME, "key-prefix/PCA-2017-11-06-14:14:15.671/matrix_0.pbr" ) mock_s3.Object.assert_any_call( - BUCKET_NAME, "key-prefix/PCA-2017-11-06-14:14:15.671/matrix_1.pbr".format(TIMESTAMP) + BUCKET_NAME, "key-prefix/PCA-2017-11-06-14:14:15.671/matrix_1.pbr" ) mock_s3.Object.assert_any_call( - BUCKET_NAME, 
"key-prefix/PCA-2017-11-06-14:14:15.671/matrix_2.pbr".format(TIMESTAMP) + BUCKET_NAME, "key-prefix/PCA-2017-11-06-14:14:15.671/matrix_2.pbr" ) mock_s3.Object.assert_any_call( - BUCKET_NAME, "key-prefix/PCA-2017-11-06-14:14:15.671/.amazon.manifest".format(TIMESTAMP) + BUCKET_NAME, "key-prefix/PCA-2017-11-06-14:14:15.671/.amazon.manifest" ) assert mock_object.put.call_count == 4 diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index 9eee574cf2..1c6388c38b 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -735,7 +735,7 @@ def test_invalid_framework_accelerator(): def test_invalid_framework_accelerator_with_neo(): - error_message = "Neo does not support Amazon Elastic Inference.".format(MOCK_FRAMEWORK) + error_message = "Neo does not support Amazon Elastic Inference." # accelerator was chosen for unsupported framework with pytest.raises(ValueError) as error: fw_utils.create_image_uri( From f7f0ac6750649af9abb0e2ad2ec37b06cad12002 Mon Sep 17 00:00:00 2001 From: Lauren Yu <6631887+laurenyu@users.noreply.github.com> Date: Thu, 14 May 2020 13:10:03 -0700 Subject: [PATCH 03/14] change: create ASTTransformer class to handle migrating Python SDK code for v2 (#1492) As a start, this class ensures that the framework_version parameter is specified when framework classes are instantiated. 
--- tools/__init__.py | 13 ++ tools/compatibility/__init__.py | 13 ++ tools/compatibility/v2/__init__.py | 13 ++ tools/compatibility/v2/ast_transformer.py | 41 ++++++ tools/compatibility/v2/modifiers/__init__.py | 14 ++ .../v2/modifiers/framework_version.py | 123 ++++++++++++++++++ tools/compatibility/v2/modifiers/modifier.py | 35 +++++ 7 files changed, 252 insertions(+) create mode 100644 tools/__init__.py create mode 100644 tools/compatibility/__init__.py create mode 100644 tools/compatibility/v2/__init__.py create mode 100644 tools/compatibility/v2/ast_transformer.py create mode 100644 tools/compatibility/v2/modifiers/__init__.py create mode 100644 tools/compatibility/v2/modifiers/framework_version.py create mode 100644 tools/compatibility/v2/modifiers/modifier.py diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 0000000000..ec1e80a0b4 --- /dev/null +++ b/tools/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import diff --git a/tools/compatibility/__init__.py b/tools/compatibility/__init__.py new file mode 100644 index 0000000000..ec1e80a0b4 --- /dev/null +++ b/tools/compatibility/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import diff --git a/tools/compatibility/v2/__init__.py b/tools/compatibility/v2/__init__.py new file mode 100644 index 0000000000..ec1e80a0b4 --- /dev/null +++ b/tools/compatibility/v2/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import diff --git a/tools/compatibility/v2/ast_transformer.py b/tools/compatibility/v2/ast_transformer.py new file mode 100644 index 0000000000..87d7dddcb7 --- /dev/null +++ b/tools/compatibility/v2/ast_transformer.py @@ -0,0 +1,41 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific +# language governing permissions and limitations under the License. +"""An ast.NodeTransformer subclass for updating SageMaker Python SDK code.""" +from __future__ import absolute_import + +import ast + +from modifiers import framework_version + +FUNCTION_CALL_MODIFIERS = [framework_version.FrameworkVersionEnforcer()] + + +class ASTTransformer(ast.NodeTransformer): + """An ``ast.NodeTransformer`` subclass that walks the abstract syntax tree and + modifies nodes to upgrade the given SageMaker Python SDK code. + """ + + def visit_Call(self, node): + """Visits an ``ast.Call`` node and returns a modified node, if needed. + See https://docs.python.org/3/library/ast.html#ast.NodeTransformer. + + Args: + node (ast.Call): a node that represents a function call. + + Returns: + ast.Call: a node that represents a function call, which has + potentially been modified from the original input. + """ + for function_checker in FUNCTION_CALL_MODIFIERS: + function_checker.check_and_modify_node(node) + return self.generic_visit(node) diff --git a/tools/compatibility/v2/modifiers/__init__.py b/tools/compatibility/v2/modifiers/__init__.py new file mode 100644 index 0000000000..9fca9c35da --- /dev/null +++ b/tools/compatibility/v2/modifiers/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License.
+"""Classes for modifying AST nodes""" +from __future__ import absolute_import diff --git a/tools/compatibility/v2/modifiers/framework_version.py b/tools/compatibility/v2/modifiers/framework_version.py new file mode 100644 index 0000000000..fa771c061f --- /dev/null +++ b/tools/compatibility/v2/modifiers/framework_version.py @@ -0,0 +1,123 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""A class to ensure that ``framework_version`` is defined when constructing framework classes.""" +from __future__ import absolute_import + +import ast + +from modifiers.modifier import Modifier + +FRAMEWORK_DEFAULTS = { + "Chainer": "4.1.0", + "MXNet": "1.2.0", + "PyTorch": "0.4.0", + "SKLearn": "0.20.0", + "TensorFlow": "1.11.0", +} + +FRAMEWORKS = list(FRAMEWORK_DEFAULTS.keys()) +# TODO: check for sagemaker.tensorflow.serving.Model +FRAMEWORK_CLASSES = FRAMEWORKS + ["{}Model".format(fw) for fw in FRAMEWORKS] +FRAMEWORK_MODULES = [fw.lower() for fw in FRAMEWORKS] + + +class FrameworkVersionEnforcer(Modifier): + def node_should_be_modified(self, node): + """Check if the ast.Call node instantiates a framework estimator or model, + but doesn't specify the framework_version parameter. + + This looks for the following formats: + + - ``TensorFlow`` + - ``sagemaker.tensorflow.TensorFlow`` + + where "TensorFlow" can be Chainer, MXNet, PyTorch, SKLearn, or TensorFlow. + + Args: + node (ast.Call): a node that represents a function call. 
For more, + see https://docs.python.org/3/library/ast.html#abstract-grammar. + + Returns: + bool: If the ``ast.Call`` is instantiating a framework class that + should specify ``framework_version``, but doesn't. + """ + if self._is_framework_constructor(node): + return not self._fw_version_in_keywords(node) + + return False + + def _is_framework_constructor(self, node): + """Check if the ``ast.Call`` node represents a call of the form + ``{Framework}`` or ``sagemaker.{framework}.{Framework}``. + """ + if isinstance(node.func, ast.Name): + if node.func.id in FRAMEWORK_CLASSES: + return True + + if ( + isinstance(node.func, ast.Attribute) + and node.func.attr in FRAMEWORK_CLASSES + and isinstance(node.func.value, ast.Attribute) + and node.func.value.attr in FRAMEWORK_MODULES + and isinstance(node.func.value.value, ast.Name) + and node.func.value.value.id == "sagemaker" + ): + return True + + return False + + def _fw_version_in_keywords(self, node): + """Check if the ``ast.Call`` node's keywords contain ``framework_version``.""" + for kw in node.keywords: + if kw.arg == "framework_version" and kw.value: + return True + return False + + def modify_node(self, node): + """Modify the ``ast.Call`` node's keywords to include ``framework_version``. + + The ``framework_version`` value is determined by the framework: + + - Chainer: "4.1.0" + - MXNet: "1.2.0" + - PyTorch: "0.4.0" + - SKLearn: "0.20.0" + - TensorFlow: "1.11.0" + + Args: + node (ast.Call): a node that represents the constructor of a framework class. + """ + framework = self._framework_name_from_node(node) + node.keywords.append( + ast.keyword(arg="framework_version", value=ast.Str(s=FRAMEWORK_DEFAULTS[framework])) + ) + + def _framework_name_from_node(self, node): + """Retrieve the framework name based on the function call. + + Args: + node (ast.Call): a node that represents the constructor of a framework class. + This can represent either ``{Framework}`` or ``sagemaker.{framework}.{Framework}``. + + Returns: + str: the (capitalized) framework name.
+ """ + if isinstance(node.func, ast.Name): + framework = node.func.id + elif isinstance(node.func, ast.Attribute): + framework = node.func.attr + + if framework.endswith("Model"): + framework = framework[: framework.find("Model")] + + return framework diff --git a/tools/compatibility/v2/modifiers/modifier.py b/tools/compatibility/v2/modifiers/modifier.py new file mode 100644 index 0000000000..c1d53dfc85 --- /dev/null +++ b/tools/compatibility/v2/modifiers/modifier.py @@ -0,0 +1,35 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Abstract class for modifying AST nodes.""" +from __future__ import absolute_import + +from abc import abstractmethod + + +class Modifier(object): + """Abstract class to take in an AST node, check if it needs modification, + and potentially modify the node. 
+ """ + + def check_and_modify_node(self, node): + """Check an AST node, and modify it if applicable.""" + if self.node_should_be_modified(node): + self.modify_node(node) + + @abstractmethod + def node_should_be_modified(self, node): + """Check if an AST node should be modified.""" + + @abstractmethod + def modify_node(self, node): + """Modify an AST node.""" From 81ad62b4e8ff3c1559b54790cb69388139c1df87 Mon Sep 17 00:00:00 2001 From: Lauren Yu <6631887+laurenyu@users.noreply.github.com> Date: Fri, 15 May 2020 10:13:03 -0700 Subject: [PATCH 04/14] change: add class to read Python scripts and update code for v2 (#1497) --- tools/compatibility/v2/files.py | 86 +++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 tools/compatibility/v2/files.py diff --git a/tools/compatibility/v2/files.py b/tools/compatibility/v2/files.py new file mode 100644 index 0000000000..f8f885798a --- /dev/null +++ b/tools/compatibility/v2/files.py @@ -0,0 +1,86 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Classes for updating code in files.""" +from __future__ import absolute_import + +import os +import logging + +import pasta + +from ast_transformer import ASTTransformer + +LOGGER = logging.getLogger(__name__) + + +class PyFileUpdater(object): + """A class for updating Python (``*.py``) files.""" + + def __init__(self, input_path, output_path): + """Creates a ``PyFileUpdater`` for updating a Python file so that + it is compatible with v2 of the SageMaker Python SDK. + + Args: + input_path (str): Location of the input file. + output_path (str): Desired location for the output file. + If the directories don't already exist, then they are created. + If a file exists at ``output_path``, then it is overwritten. + """ + self.input_path = input_path + self.output_path = output_path + + def update(self): + """Reads the input Python file, updates the code so that it is + compatible with v2 of the SageMaker Python SDK, and writes the + updated code to an output file. + """ + output = self._update_ast(self._read_input_file()) + self._write_output_file(output) + + def _update_ast(self, input_ast): + """Updates an abstract syntax tree (AST) so that it is compatible + with v2 of the SageMaker Python SDK. + + Args: + input_ast (ast.Module): AST to be updated for use with Python SDK v2. + + Returns: + ast.Module: Updated AST that is compatible with Python SDK v2. + """ + return ASTTransformer().visit(input_ast) + + def _read_input_file(self): + """Reads input file and parse as an abstract syntax tree (AST). + + Returns: + ast.Module: AST representation of the input file. + """ + with open(self.input_path) as input_file: + return pasta.parse(input_file.read()) + + def _write_output_file(self, output): + """Writes abstract syntax tree (AST) to output file. + Creates the directories for the output path, if needed. + + Args: + output (ast.Module): AST to save as the output file. 
+ """ + output_dir = os.path.dirname(self.output_path) + if output_dir and not os.path.exists(output_dir): + os.makedirs(output_dir) + + if os.path.exists(self.output_path): + LOGGER.warning("Overwriting file {}".format(self.output_path)) + + with open(self.output_path, "w") as output_file: + output_file.write(pasta.dump(output)) From d197b7456ed9aee68c6c026f13c0ce731db03804 Mon Sep 17 00:00:00 2001 From: Lauren Yu <6631887+laurenyu@users.noreply.github.com> Date: Fri, 15 May 2020 14:48:55 -0700 Subject: [PATCH 05/14] infra: add tools/ dir to pylint check (#1499) --- tools/__init__.py | 1 + tools/compatibility/__init__.py | 1 + tools/compatibility/v2/__init__.py | 1 + tools/compatibility/v2/files.py | 2 +- .../v2/modifiers/framework_version.py | 33 +++++++++++-------- tox.ini | 2 +- 6 files changed, 25 insertions(+), 15 deletions(-) diff --git a/tools/__init__.py b/tools/__init__.py index ec1e80a0b4..96abea2567 100644 --- a/tools/__init__.py +++ b/tools/__init__.py @@ -10,4 +10,5 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +"""Tools to assist with using the SageMake Python SDK.""" from __future__ import absolute_import diff --git a/tools/compatibility/__init__.py b/tools/compatibility/__init__.py index ec1e80a0b4..e3a46fe406 100644 --- a/tools/compatibility/__init__.py +++ b/tools/compatibility/__init__.py @@ -10,4 +10,5 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
+"""Tools to assist with compatibility between SageMaker Python SDK versions.""" from __future__ import absolute_import diff --git a/tools/compatibility/v2/__init__.py b/tools/compatibility/v2/__init__.py index ec1e80a0b4..b44e22749e 100644 --- a/tools/compatibility/v2/__init__.py +++ b/tools/compatibility/v2/__init__.py @@ -10,4 +10,5 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +"""Tools to assist with upgrading to v2 of the SageMaker Python SDK.""" from __future__ import absolute_import diff --git a/tools/compatibility/v2/files.py b/tools/compatibility/v2/files.py index f8f885798a..055e30a1c5 100644 --- a/tools/compatibility/v2/files.py +++ b/tools/compatibility/v2/files.py @@ -80,7 +80,7 @@ def _write_output_file(self, output): os.makedirs(output_dir) if os.path.exists(self.output_path): - LOGGER.warning("Overwriting file {}".format(self.output_path)) + LOGGER.warning("Overwriting file %s", self.output_path) with open(self.output_path, "w") as output_file: output_file.write(pasta.dump(output)) diff --git a/tools/compatibility/v2/modifiers/framework_version.py b/tools/compatibility/v2/modifiers/framework_version.py index fa771c061f..2c2a440ba7 100644 --- a/tools/compatibility/v2/modifiers/framework_version.py +++ b/tools/compatibility/v2/modifiers/framework_version.py @@ -32,9 +32,13 @@ class FrameworkVersionEnforcer(Modifier): + """A class to ensure that ``framework_version`` is defined when + instantiating a framework estimator or model. + """ + def node_should_be_modified(self, node): - """Check if the ast.Call node instantiates a framework estimator or model, - but doesn't specify the framework_version parameter. + """Checks if the ast.Call node instantiates a framework estimator or model, + but doesn't specify the ``framework_version`` parameter. 
This looks for the following formats: @@ -57,34 +61,37 @@ def node_should_be_modified(self, node): return False def _is_framework_constructor(self, node): - """Check if the ``ast.Call`` node represents a call of the form + """Checks if the ``ast.Call`` node represents a call of the form or sagemaker... """ + # Check for call if isinstance(node.func, ast.Name): if node.func.id in FRAMEWORK_CLASSES: return True - if ( - isinstance(node.func, ast.Attribute) - and node.func.attr in FRAMEWORK_CLASSES - and isinstance(node.func.value, ast.Attribute) + # Check for sagemaker.. call + ends_with_framework_constructor = ( + isinstance(node.func, ast.Attribute) and node.func.attr in FRAMEWORK_CLASSES + ) + + is_in_framework_module = ( + isinstance(node.func.value, ast.Attribute) and node.func.value.attr in FRAMEWORK_MODULES and isinstance(node.func.value.value, ast.Name) and node.func.value.value.id == "sagemaker" - ): - return True + ) - return False + return ends_with_framework_constructor and is_in_framework_module def _fw_version_in_keywords(self, node): - """Check if the ``ast.Call`` node's keywords contain ``framework_version``.""" + """Checks if the ``ast.Call`` node's keywords contain ``framework_version``.""" for kw in node.keywords: if kw.arg == "framework_version" and kw.value: return True return False def modify_node(self, node): - """Modify the ``ast.Call`` node's keywords to include ``framework_version``. + """Modifies the ``ast.Call`` node's keywords to include ``framework_version``. The ``framework_version`` value is determined by the framework: @@ -103,7 +110,7 @@ def modify_node(self, node): ) def _framework_name_from_node(self, node): - """Retrieve the framework name based on the function call. + """Retrieves the framework name based on the function call. Args: node (ast.Call): a node that represents the constructor of a framework class. 
diff --git a/tox.ini b/tox.ini index 78cd0f1d4a..785ef17582 100644 --- a/tox.ini +++ b/tox.ini @@ -82,7 +82,7 @@ skip_install = true deps = pylint==2.3.1 commands = - python -m pylint --rcfile=.pylintrc -j 0 src/sagemaker + python -m pylint --rcfile=.pylintrc -j 0 src/sagemaker tools [testenv:twine] basepython = python3 From 88518e088ebf5883a15a2d91293a57c2ac924d14 Mon Sep 17 00:00:00 2001 From: Lauren Yu <6631887+laurenyu@users.noreply.github.com> Date: Mon, 18 May 2020 10:05:48 -0700 Subject: [PATCH 06/14] change: add CLI wrapper for v2 migration script (#1500) --- .../compatibility/v2/sagemaker_upgrade_v2.py | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 tools/compatibility/v2/sagemaker_upgrade_v2.py diff --git a/tools/compatibility/v2/sagemaker_upgrade_v2.py b/tools/compatibility/v2/sagemaker_upgrade_v2.py new file mode 100644 index 0000000000..04b5fad876 --- /dev/null +++ b/tools/compatibility/v2/sagemaker_upgrade_v2.py @@ -0,0 +1,43 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""A tool to upgrade SageMaker Python SDK code to be compatible with v2.""" +from __future__ import absolute_import + +import argparse + +import files + + +def _parse_and_validate_args(): + """Parses CLI arguments""" + parser = argparse.ArgumentParser( + description="A tool to convert files to be compatible with v2 of the SageMaker Python SDK." 
+ "\nSimple usage: sagemaker_upgrade_v2.py --in-file foo.py --out-file bar.py" + ) + parser.add_argument( + "--in-file", help="If converting a single file, the name of the file to convert" + ) + parser.add_argument( + "--out-file", + help="If converting a single file, the output file destination. If needed, " + "directories in the output file path are created. If the output file already exists, " + "it is overwritten.", + ) + + return parser.parse_args() + + +if __name__ == "__main__": + args = _parse_and_validate_args() + + files.PyFileUpdater(input_path=args.in_file, output_path=args.out_file).update() From 57b2a224af44c346942f171da420e3586f870c55 Mon Sep 17 00:00:00 2001 From: Lauren Yu <6631887+laurenyu@users.noreply.github.com> Date: Tue, 19 May 2020 11:06:45 -0700 Subject: [PATCH 07/14] change: add .ipynb file support for v2 migration script (#1508) --- tools/compatibility/v2/files.py | 104 +++++++++++++++++- .../compatibility/v2/sagemaker_upgrade_v2.py | 53 +++++++-- 2 files changed, 142 insertions(+), 15 deletions(-) diff --git a/tools/compatibility/v2/files.py b/tools/compatibility/v2/files.py index 055e30a1c5..b385274093 100644 --- a/tools/compatibility/v2/files.py +++ b/tools/compatibility/v2/files.py @@ -13,8 +13,10 @@ """Classes for updating code in files.""" from __future__ import absolute_import -import os +from abc import abstractmethod +import json import logging +import os import pasta @@ -23,11 +25,11 @@ LOGGER = logging.getLogger(__name__) -class PyFileUpdater(object): - """A class for updating Python (``*.py``) files.""" +class FileUpdater(object): + """An abstract class for updating files.""" def __init__(self, input_path, output_path): - """Creates a ``PyFileUpdater`` for updating a Python file so that + """Creates a ``FileUpdater`` for updating a file so that it is compatible with v2 of the SageMaker Python SDK. 
Args: @@ -39,6 +41,17 @@ def __init__(self, input_path, output_path): self.input_path = input_path self.output_path = output_path + @abstractmethod + def update(self): + """Reads the input file, updates the code so that it is + compatible with v2 of the SageMaker Python SDK, and writes the + updated code to an output file. + """ + + +class PyFileUpdater(FileUpdater): + """A class for updating Python (``*.py``) files.""" + def update(self): """Reads the input Python file, updates the code so that it is compatible with v2 of the SageMaker Python SDK, and writes the @@ -60,7 +73,7 @@ def _update_ast(self, input_ast): return ASTTransformer().visit(input_ast) def _read_input_file(self): - """Reads input file and parse as an abstract syntax tree (AST). + """Reads input file and parses it as an abstract syntax tree (AST). Returns: ast.Module: AST representation of the input file. @@ -84,3 +97,84 @@ def _write_output_file(self, output): with open(self.output_path, "w") as output_file: output_file.write(pasta.dump(output)) + + +class JupyterNotebookFileUpdater(FileUpdater): + """A class for updating Jupyter notebook (``*.ipynb``) files. + + For more on this file format, see + https://ipython.org/ipython-doc/dev/notebook/nbformat.html#nbformat. + """ + + def update(self): + """Reads the input Jupyter notebook file, updates the code so that it is + compatible with v2 of the SageMaker Python SDK, and writes the + updated code to an output file. + """ + nb_json = self._read_input_file() + for cell in nb_json["cells"]: + if cell["cell_type"] == "code": + updated_source = self._update_code_from_cell(cell) + cell["source"] = updated_source + + self._write_output_file(nb_json) + + def _update_code_from_cell(self, cell): + """Updates the code from a code cell so that it is + compatible with v2 of the SageMaker Python SDK. + + Args: + cell (dict): A dictionary representation of a code cell from + a Jupyter notebook. 
For more info, see + https://ipython.org/ipython-doc/dev/notebook/nbformat.html#code-cells. + + Returns: + list[str]: A list of strings containing the lines of updated code that + can be used for the "source" attribute of a Jupyter notebook code cell. + """ + code = "".join(cell["source"]) + updated_ast = ASTTransformer().visit(pasta.parse(code)) + updated_code = pasta.dump(updated_ast) + return self._code_str_to_source_list(updated_code) + + def _code_str_to_source_list(self, code): + """Converts a string of code into a list for a Jupyter notebook code cell. + + Args: + code (str): Code to be converted. + + Returns: + list[str]: A list of strings containing the lines of code that + can be used for the "source" attribute of a Jupyter notebook code cell. + Each element of the list (i.e. line of code) contains a + trailing newline character ("\n") except for the last element. + """ + source_list = ["{}\n".format(s) for s in code.split("\n")] + source_list[-1] = source_list[-1].rstrip("\n") + return source_list + + def _read_input_file(self): + """Reads input file and parses it as JSON. + + Returns: + dict: JSON representation of the input file. + """ + with open(self.input_path) as input_file: + return json.load(input_file) + + def _write_output_file(self, output): + """Writes JSON to output file. Creates the directories for the output path, if needed. + + Args: + output (dict): JSON to save as the output file. 
+ """ + output_dir = os.path.dirname(self.output_path) + if output_dir and not os.path.exists(output_dir): + os.makedirs(output_dir) + + if os.path.exists(self.output_path): + LOGGER.warning("Overwriting file %s", self.output_path) + + with open(self.output_path, "w") as output_file: + json.dump(output, output_file, indent=1) + output_file.write("\n") # json.dump does not write trailing newline diff --git a/tools/compatibility/v2/sagemaker_upgrade_v2.py b/tools/compatibility/v2/sagemaker_upgrade_v2.py index 04b5fad876..2238775e1a 100644 --- a/tools/compatibility/v2/sagemaker_upgrade_v2.py +++ b/tools/compatibility/v2/sagemaker_upgrade_v2.py @@ -14,30 +14,63 @@ from __future__ import absolute_import import argparse +import os import files +_EXT_TO_UPDATER_CLS = {".py": files.PyFileUpdater, ".ipynb": files.JupyterNotebookFileUpdater} -def _parse_and_validate_args(): + +def _update_file(input_file, output_file): + """Update a file to be compatible with v2 of the SageMaker Python SDK, + and write the updated source to the output file. + + Args: + input_file (str): The path to the file to be updated. + output_file (str): The output file destination. + + Raises: + ValueError: If the input and output filename extensions don't match, + or if the file extensions are neither ".py" nor ".ipynb". 
+ """ + input_file_ext = os.path.splitext(input_file)[1] + output_file_ext = os.path.splitext(output_file)[1] + + if input_file_ext != output_file_ext: + raise ValueError( + "Mismatched file extensions: input: {}, output: {}".format( + input_file_ext, output_file_ext + ) + ) + + if input_file_ext not in _EXT_TO_UPDATER_CLS: + raise ValueError("Unrecognized file extension: {}".format(input_file_ext)) + + updater_cls = _EXT_TO_UPDATER_CLS[input_file_ext] + updater_cls(input_path=input_file, output_path=output_file).update() + + +def _parse_args(): """Parses CLI arguments""" parser = argparse.ArgumentParser( - description="A tool to convert files to be compatible with v2 of the SageMaker Python SDK." - "\nSimple usage: sagemaker_upgrade_v2.py --in-file foo.py --out-file bar.py" + description="A tool to convert files to be compatible with v2 of the SageMaker Python SDK. " + "Simple usage: sagemaker_upgrade_v2.py --in-file foo.py --out-file bar.py" ) parser.add_argument( - "--in-file", help="If converting a single file, the name of the file to convert" + "--in-file", + help="If converting a single file, the file to convert. The file's extension " + "must be either '.py' or '.ipynb'.", ) parser.add_argument( "--out-file", - help="If converting a single file, the output file destination. If needed, " - "directories in the output file path are created. If the output file already exists, " - "it is overwritten.", + help="If converting a single file, the output file destination. The file's extension " + "must be either '.py' or '.ipynb'. If needed, directories in the output path are created. 
" + "If the output file already exists, it is overwritten.", ) return parser.parse_args() if __name__ == "__main__": - args = _parse_and_validate_args() - - files.PyFileUpdater(input_path=args.in_file, output_path=args.out_file).update() + args = _parse_args() + _update_file(args.in_file, args.out_file) From 07014776b0873ed172272228afa0a5d9000a068a Mon Sep 17 00:00:00 2001 From: Chuyang Date: Tue, 19 May 2020 16:34:53 -0700 Subject: [PATCH 08/14] Revert "update with aws zwei" --- tools/__init__.py | 14 -- tools/compatibility/__init__.py | 14 -- tools/compatibility/v2/__init__.py | 14 -- tools/compatibility/v2/ast_transformer.py | 41 ---- tools/compatibility/v2/files.py | 180 ------------------ tools/compatibility/v2/modifiers/__init__.py | 14 -- .../v2/modifiers/framework_version.py | 130 ------------- tools/compatibility/v2/modifiers/modifier.py | 35 ---- .../compatibility/v2/sagemaker_upgrade_v2.py | 76 -------- tox.ini | 2 +- 10 files changed, 1 insertion(+), 519 deletions(-) delete mode 100644 tools/__init__.py delete mode 100644 tools/compatibility/__init__.py delete mode 100644 tools/compatibility/v2/__init__.py delete mode 100644 tools/compatibility/v2/ast_transformer.py delete mode 100644 tools/compatibility/v2/files.py delete mode 100644 tools/compatibility/v2/modifiers/__init__.py delete mode 100644 tools/compatibility/v2/modifiers/framework_version.py delete mode 100644 tools/compatibility/v2/modifiers/modifier.py delete mode 100644 tools/compatibility/v2/sagemaker_upgrade_v2.py diff --git a/tools/__init__.py b/tools/__init__.py deleted file mode 100644 index 96abea2567..0000000000 --- a/tools/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. 
A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Tools to assist with using the SageMake Python SDK.""" -from __future__ import absolute_import diff --git a/tools/compatibility/__init__.py b/tools/compatibility/__init__.py deleted file mode 100644 index e3a46fe406..0000000000 --- a/tools/compatibility/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Tools to assist with compatibility between SageMaker Python SDK versions.""" -from __future__ import absolute_import diff --git a/tools/compatibility/v2/__init__.py b/tools/compatibility/v2/__init__.py deleted file mode 100644 index b44e22749e..0000000000 --- a/tools/compatibility/v2/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. 
This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Tools to assist with upgrading to v2 of the SageMaker Python SDK.""" -from __future__ import absolute_import diff --git a/tools/compatibility/v2/ast_transformer.py b/tools/compatibility/v2/ast_transformer.py deleted file mode 100644 index 87d7dddcb7..0000000000 --- a/tools/compatibility/v2/ast_transformer.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""An ast.NodeTransformer subclass for updating SageMaker Python SDK code.""" -from __future__ import absolute_import - -import ast - -from modifiers import framework_version - -FUNCTION_CALL_MODIFIERS = [framework_version.FrameworkVersionEnforcer()] - - -class ASTTransformer(ast.NodeTransformer): - """An ``ast.NodeTransformer`` subclass that walks the abstract syntax tree and - modifies nodes to upgrade the given SageMaker Python SDK code. - """ - - def visit_Call(self, node): - """Visits an ``ast.Call`` node and returns a modified node, if needed. - See https://docs.python.org/3/library/ast.html#ast.NodeTransformer. - - Args: - node (ast.Call): a node that represents a function call. - - Returns: - ast.Call: a node that represents a function call, which has - potentially been modified from the original input. 
- """ - for function_checker in FUNCTION_CALL_MODIFIERS: - function_checker.check_and_modify_node(node) - return node diff --git a/tools/compatibility/v2/files.py b/tools/compatibility/v2/files.py deleted file mode 100644 index b385274093..0000000000 --- a/tools/compatibility/v2/files.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Classes for updating code in files.""" -from __future__ import absolute_import - -from abc import abstractmethod -import json -import logging -import os - -import pasta - -from ast_transformer import ASTTransformer - -LOGGER = logging.getLogger(__name__) - - -class FileUpdater(object): - """An abstract class for updating files.""" - - def __init__(self, input_path, output_path): - """Creates a ``FileUpdater`` for updating a file so that - it is compatible with v2 of the SageMaker Python SDK. - - Args: - input_path (str): Location of the input file. - output_path (str): Desired location for the output file. - If the directories don't already exist, then they are created. - If a file exists at ``output_path``, then it is overwritten. - """ - self.input_path = input_path - self.output_path = output_path - - @abstractmethod - def update(self): - """Reads the input file, updates the code so that it is - compatible with v2 of the SageMaker Python SDK, and writes the - updated code to an output file. 
- """ - - -class PyFileUpdater(FileUpdater): - """A class for updating Python (``*.py``) files.""" - - def update(self): - """Reads the input Python file, updates the code so that it is - compatible with v2 of the SageMaker Python SDK, and writes the - updated code to an output file. - """ - output = self._update_ast(self._read_input_file()) - self._write_output_file(output) - - def _update_ast(self, input_ast): - """Updates an abstract syntax tree (AST) so that it is compatible - with v2 of the SageMaker Python SDK. - - Args: - input_ast (ast.Module): AST to be updated for use with Python SDK v2. - - Returns: - ast.Module: Updated AST that is compatible with Python SDK v2. - """ - return ASTTransformer().visit(input_ast) - - def _read_input_file(self): - """Reads input file and parses it as an abstract syntax tree (AST). - - Returns: - ast.Module: AST representation of the input file. - """ - with open(self.input_path) as input_file: - return pasta.parse(input_file.read()) - - def _write_output_file(self, output): - """Writes abstract syntax tree (AST) to output file. - Creates the directories for the output path, if needed. - - Args: - output (ast.Module): AST to save as the output file. - """ - output_dir = os.path.dirname(self.output_path) - if output_dir and not os.path.exists(output_dir): - os.makedirs(output_dir) - - if os.path.exists(self.output_path): - LOGGER.warning("Overwriting file %s", self.output_path) - - with open(self.output_path, "w") as output_file: - output_file.write(pasta.dump(output)) - - -class JupyterNotebookFileUpdater(FileUpdater): - """A class for updating Jupyter notebook (``*.ipynb``) files. - - For more on this file format, see - https://ipython.org/ipython-doc/dev/notebook/nbformat.html#nbformat. - """ - - def update(self): - """Reads the input Jupyter notebook file, updates the code so that it is - compatible with v2 of the SageMaker Python SDK, and writes the - updated code to an output file. 
- """ - nb_json = self._read_input_file() - for cell in nb_json["cells"]: - if cell["cell_type"] == "code": - updated_source = self._update_code_from_cell(cell) - cell["source"] = updated_source - - self._write_output_file(nb_json) - - def _update_code_from_cell(self, cell): - """Updates the code from a code cell so that it is - compatible with v2 of the SageMaker Python SDK. - - Args: - cell (dict): A dictionary representation of a code cell from - a Jupyter notebook. For more info, see - https://ipython.org/ipython-doc/dev/notebook/nbformat.html#code-cells. - - Returns: - list[str]: A list of strings containing the lines of updated code that - can be used for the "source" attribute of a Jupyter notebook code cell. - """ - code = "".join(cell["source"]) - updated_ast = ASTTransformer().visit(pasta.parse(code)) - updated_code = pasta.dump(updated_ast) - return self._code_str_to_source_list(updated_code) - - def _code_str_to_source_list(self, code): - """Converts a string of code into a list for a Jupyter notebook code cell. - - Args: - code (str): Code to be converted. - - Returns: - list[str]: A list of strings containing the lines of code that - can be used for the "source" attribute of a Jupyter notebook code cell. - Each element of the list (i.e. line of code) contains a - trailing newline character ("\n") except for the last element. - """ - source_list = ["{}\n".format(s) for s in code.split("\n")] - source_list[-1] = source_list[-1].rstrip("\n") - return source_list - - def _read_input_file(self): - """Reads input file and parses it as JSON. - - Returns: - dict: JSON representation of the input file. - """ - with open(self.input_path) as input_file: - return json.load(input_file) - - def _write_output_file(self, output): - """Writes JSON to output file. Creates the directories for the output path, if needed. - - Args: - output (dict): JSON to save as the output file. 
- """ - output_dir = os.path.dirname(self.output_path) - if output_dir and not os.path.exists(output_dir): - os.makedirs(output_dir) - - if os.path.exists(self.output_path): - LOGGER.warning("Overwriting file %s", self.output_path) - - with open(self.output_path, "w") as output_file: - json.dump(output, output_file, indent=1) - output_file.write("\n") # json.dump does not write trailing newline diff --git a/tools/compatibility/v2/modifiers/__init__.py b/tools/compatibility/v2/modifiers/__init__.py deleted file mode 100644 index 9fca9c35da..0000000000 --- a/tools/compatibility/v2/modifiers/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Classes for modifying AST nodes""" -from __future__ import absolute_import diff --git a/tools/compatibility/v2/modifiers/framework_version.py b/tools/compatibility/v2/modifiers/framework_version.py deleted file mode 100644 index 2c2a440ba7..0000000000 --- a/tools/compatibility/v2/modifiers/framework_version.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. 
This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""A class to ensure that ``framework_version`` is defined when constructing framework classes.""" -from __future__ import absolute_import - -import ast - -from modifiers.modifier import Modifier - -FRAMEWORK_DEFAULTS = { - "Chainer": "4.1.0", - "MXNet": "1.2.0", - "PyTorch": "0.4.0", - "SKLearn": "0.20.0", - "TensorFlow": "1.11.0", -} - -FRAMEWORKS = list(FRAMEWORK_DEFAULTS.keys()) -# TODO: check for sagemaker.tensorflow.serving.Model -FRAMEWORK_CLASSES = FRAMEWORKS + ["{}Model".format(fw) for fw in FRAMEWORKS] -FRAMEWORK_MODULES = [fw.lower() for fw in FRAMEWORKS] - - -class FrameworkVersionEnforcer(Modifier): - """A class to ensure that ``framework_version`` is defined when - instantiating a framework estimator or model. - """ - - def node_should_be_modified(self, node): - """Checks if the ast.Call node instantiates a framework estimator or model, - but doesn't specify the ``framework_version`` parameter. - - This looks for the following formats: - - - ``TensorFlow`` - - ``sagemaker.tensorflow.TensorFlow`` - - where "TensorFlow" can be Chainer, MXNet, PyTorch, SKLearn, or TensorFlow. - - Args: - node (ast.Call): a node that represents a function call. For more, - see https://docs.python.org/3/library/ast.html#abstract-grammar. - - Returns: - bool: If the ``ast.Call`` is instantiating a framework class that - should specify ``framework_version``, but doesn't. - """ - if self._is_framework_constructor(node): - return not self._fw_version_in_keywords(node) - - return False - - def _is_framework_constructor(self, node): - """Checks if the ``ast.Call`` node represents a call of the form - or sagemaker... 
- """ - # Check for call - if isinstance(node.func, ast.Name): - if node.func.id in FRAMEWORK_CLASSES: - return True - - # Check for sagemaker.. call - ends_with_framework_constructor = ( - isinstance(node.func, ast.Attribute) and node.func.attr in FRAMEWORK_CLASSES - ) - - is_in_framework_module = ( - isinstance(node.func.value, ast.Attribute) - and node.func.value.attr in FRAMEWORK_MODULES - and isinstance(node.func.value.value, ast.Name) - and node.func.value.value.id == "sagemaker" - ) - - return ends_with_framework_constructor and is_in_framework_module - - def _fw_version_in_keywords(self, node): - """Checks if the ``ast.Call`` node's keywords contain ``framework_version``.""" - for kw in node.keywords: - if kw.arg == "framework_version" and kw.value: - return True - return False - - def modify_node(self, node): - """Modifies the ``ast.Call`` node's keywords to include ``framework_version``. - - The ``framework_version`` value is determined by the framework: - - - Chainer: "4.1.0" - - MXNet: "1.2.0" - - PyTorch: "0.4.0" - - SKLearn: "0.20.0" - - TensorFlow: "1.11.0" - - Args: - node (ast.Call): a node that represents the constructor of a framework class. - """ - framework = self._framework_name_from_node(node) - node.keywords.append( - ast.keyword(arg="framework_version", value=ast.Str(s=FRAMEWORK_DEFAULTS[framework])) - ) - - def _framework_name_from_node(self, node): - """Retrieves the framework name based on the function call. - - Args: - node (ast.Call): a node that represents the constructor of a framework class. - This can represent either or sagemaker... - - Returns: - str: the (capitalized) framework name. 
- """ - if isinstance(node.func, ast.Name): - framework = node.func.id - elif isinstance(node.func, ast.Attribute): - framework = node.func.attr - - if framework.endswith("Model"): - framework = framework[: framework.find("Model")] - - return framework diff --git a/tools/compatibility/v2/modifiers/modifier.py b/tools/compatibility/v2/modifiers/modifier.py deleted file mode 100644 index c1d53dfc85..0000000000 --- a/tools/compatibility/v2/modifiers/modifier.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Abstract class for modifying AST nodes.""" -from __future__ import absolute_import - -from abc import abstractmethod - - -class Modifier(object): - """Abstract class to take in an AST node, check if it needs modification, - and potentially modify the node. - """ - - def check_and_modify_node(self, node): - """Check an AST node, and modify it if applicable.""" - if self.node_should_be_modified(node): - self.modify_node(node) - - @abstractmethod - def node_should_be_modified(self, node): - """Check if an AST node should be modified.""" - - @abstractmethod - def modify_node(self, node): - """Modify an AST node.""" diff --git a/tools/compatibility/v2/sagemaker_upgrade_v2.py b/tools/compatibility/v2/sagemaker_upgrade_v2.py deleted file mode 100644 index 2238775e1a..0000000000 --- a/tools/compatibility/v2/sagemaker_upgrade_v2.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2020 Amazon.com, Inc. or its affiliates. 
All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""A tool to upgrade SageMaker Python SDK code to be compatible with v2.""" -from __future__ import absolute_import - -import argparse -import os - -import files - -_EXT_TO_UPDATER_CLS = {".py": files.PyFileUpdater, ".ipynb": files.JupyterNotebookFileUpdater} - - -def _update_file(input_file, output_file): - """Update a file to be compatible with v2 of the SageMaker Python SDK, - and write the updated source to the output file. - - Args: - input_file (str): The path to the file to be updated. - output_file (str): The output file destination. - - Raises: - ValueError: If the input and output filename extensions don't match, - or if the file extensions are neither ".py" nor ".ipynb". - """ - input_file_ext = os.path.splitext(input_file)[1] - output_file_ext = os.path.splitext(output_file)[1] - - if input_file_ext != output_file_ext: - raise ValueError( - "Mismatched file extensions: input: {}, output: {}".format( - input_file_ext, output_file_ext - ) - ) - - if input_file_ext not in _EXT_TO_UPDATER_CLS: - raise ValueError("Unrecognized file extension: {}".format(input_file_ext)) - - updater_cls = _EXT_TO_UPDATER_CLS[input_file_ext] - updater_cls(input_path=input_file, output_path=output_file).update() - - -def _parse_args(): - """Parses CLI arguments""" - parser = argparse.ArgumentParser( - description="A tool to convert files to be compatible with v2 of the SageMaker Python SDK. 
" - "Simple usage: sagemaker_upgrade_v2.py --in-file foo.py --out-file bar.py" - ) - parser.add_argument( - "--in-file", - help="If converting a single file, the file to convert. The file's extension " - "must be either '.py' or '.ipynb'.", - ) - parser.add_argument( - "--out-file", - help="If converting a single file, the output file destination. The file's extension " - "must be either '.py' or '.ipynb'. If needed, directories in the output path are created. " - "If the output file already exists, it is overwritten.", - ) - - return parser.parse_args() - - -if __name__ == "__main__": - args = _parse_args() - _update_file(args.in_file, args.out_file) diff --git a/tox.ini b/tox.ini index 785ef17582..78cd0f1d4a 100644 --- a/tox.ini +++ b/tox.ini @@ -82,7 +82,7 @@ skip_install = true deps = pylint==2.3.1 commands = - python -m pylint --rcfile=.pylintrc -j 0 src/sagemaker tools + python -m pylint --rcfile=.pylintrc -j 0 src/sagemaker [testenv:twine] basepython = python3 From dbeaa953d8d31bbde860d5facf373327181b8be3 Mon Sep 17 00:00:00 2001 From: Chuyang Deng Date: Tue, 12 May 2020 01:13:25 -0700 Subject: [PATCH 09/14] fix flake8 error --- CHANGELOG.md | 23 +++ VERSION | 2 +- doc/using_pytorch.rst | 137 +++++++++--------- src/sagemaker/estimator.py | 14 +- src/sagemaker/model.py | 15 +- src/sagemaker/model_monitor/dataset_format.py | 2 +- src/sagemaker/mxnet/estimator.py | 9 +- src/sagemaker/processing.py | 6 +- src/sagemaker/pytorch/estimator.py | 9 +- src/sagemaker/rl/estimator.py | 9 +- src/sagemaker/session.py | 12 ++ src/sagemaker/sklearn/estimator.py | 9 +- src/sagemaker/tensorflow/estimator.py | 11 +- src/sagemaker/xgboost/estimator.py | 8 +- tests/integ/test_transformer.py | 12 +- tests/unit/test_amazon_estimator.py | 8 +- tests/unit/test_fw_utils.py | 2 +- 17 files changed, 167 insertions(+), 121 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f5bfbbb14..42d2f5d936 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,28 @@ # Changelog +## 
v1.58.3 (2020-05-19) + +### Bug Fixes and Other Changes + + * update DatasetFormat key name for sagemakerCaptureJson + +### Documentation Changes + + * update Processing job max_runtime_in_seconds docstring + +## v1.58.2.post0 (2020-05-18) + +### Documentation Changes + + * specify S3 source_dir needs to point to a tar file + * update PyTorch BYOM topic + +## v1.58.2 (2020-05-13) + +### Bug Fixes and Other Changes + + * address flake8 error + ## v1.58.1 (2020-05-11) ### Bug Fixes and Other Changes diff --git a/VERSION b/VERSION index 130f9238df..d8fe8455df 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.58.2.dev0 +1.58.4.dev0 diff --git a/doc/using_pytorch.rst b/doc/using_pytorch.rst index a27b2c6f5d..995e030067 100644 --- a/doc/using_pytorch.rst +++ b/doc/using_pytorch.rst @@ -90,7 +90,7 @@ Note that SageMaker doesn't support argparse actions. If you want to use, for ex you need to specify `type` as `bool` in your script and provide an explicit `True` or `False` value for this hyperparameter when instantiating PyTorch Estimator. -For more on training environment variables, please visit `SageMaker Containers `_. +For more on training environment variables, see the `SageMaker Training Toolkit `_. Save the Model -------------- @@ -115,7 +115,7 @@ to a certain filesystem path called ``model_dir``. This value is accessible thro with open(os.path.join(args.model_dir, 'model.pth'), 'wb') as f: torch.save(model.state_dict(), f) -After your training job is complete, SageMaker will compress and upload the serialized model to S3, and your model data +After your training job is complete, SageMaker compresses and uploads the serialized model to S3, and your model data will be available in the S3 ``output_path`` you specified when you created the PyTorch Estimator. If you are using Elastic Inference, you must convert your models to the TorchScript format and use ``torch.jit.save`` to save the model. 
@@ -566,12 +566,76 @@ The function should return a byte array of data serialized to content_type. The default implementation expects ``prediction`` to be a torch.Tensor and can serialize the result to JSON, CSV, or NPY. It accepts response content types of "application/json", "text/csv", and "application/x-npy". -Working with Existing Model Data and Training Jobs -================================================== -Attach to existing training jobs +Bring your own model +==================== + +You can deploy a PyTorch model that you trained outside of SageMaker by using the ``PyTorchModel`` class. +Typically, you save a PyTorch model as a file with extension ``.pt`` or ``.pth``. +To do this, you need to: + +* Write an inference script. +* Create the directory structure for your model files. +* Create the ``PyTorchModel`` object. + +Write an inference script +------------------------- + +You must create an inference script that implements (at least) the ``model_fn`` function that calls the loaded model to get a prediction. + +**Note**: If you use elastic inference with PyTorch, you can use the default ``model_fn`` implementation provided in the serving container. + +Optionally, you can also implement ``input_fn`` and ``output_fn`` to process input and output, +and ``predict_fn`` to customize how the model server gets predictions from the loaded model. +For information about how to write an inference script, see `Serve a PyTorch Model <#serve-a-pytorch-model>`_. +Save the inference script in the same folder where you saved your PyTorch model. +Pass the filename of the inference script as the ``entry_point`` parameter when you create the ``PyTorchModel`` object. + +Create the directory structure for your model files +--------------------------------------------------- + +You have to create a directory structure and place your model files in the correct location. +The ``PyTorchModel`` constructor packs the files into a ``tar.gz`` file and uploads it to S3. 
+ +The directory structure where you saved your PyTorch model should look something like the following: + +**Note:** This directory structure is for PyTorch versions 1.2 and higher. +For the directory structure for versions 1.1 and lower, +see `For versions 1.1 and lower <#for-versions-1.1-and-lower>`_. + +:: + + | my_model + | |--model.pth + | + | code + | |--inference.py + | |--requirements.txt + +Where ``requirements.txt`` is an optional file that specifies dependencies on third-party libraries. + +Create a ``PyTorchModel`` object -------------------------------- +Now call the :class:`sagemaker.pytorch.model.PyTorchModel` constructor to create a model object, and then call its ``deploy()`` method to deploy your model for inference. + +.. code:: python + + from sagemaker import get_execution_role + role = get_execution_role() + + pytorch_model = PyTorchModel(model_data='s3://my-bucket/my-path/model.tar.gz', role=role, + entry_point='inference.py') + + predictor = pytorch_model.deploy(instance_type='ml.c4.xlarge', initial_instance_count=1) + + +Now you can call the ``predict()`` method to get predictions from your deployed model. + +*********************************************** +Attach an estimator to an existing training job +*********************************************** + You can attach a PyTorch Estimator to an existing training job using the ``attach`` method. @@ -592,69 +656,6 @@ The ``attach`` method accepts the following arguments: - ``sagemaker_session:`` The Session used to interact with SageMaker -Deploy Endpoints from model data --------------------------------- - -In addition to attaching to existing training jobs, you can deploy models directly from model data in S3. -The following code sample shows how to do this, using the ``PyTorchModel`` class. - -.. 
code:: python - - pytorch_model = PyTorchModel(model_data='s3://bucket/model.tar.gz', role='SageMakerRole', - entry_point='transform_script.py') - - predictor = pytorch_model.deploy(instance_type='ml.c4.xlarge', initial_instance_count=1) - -The PyTorchModel constructor takes the following arguments: - -- ``model_dat:`` An S3 location of a SageMaker model data - .tar.gz file -- ``image:`` A Docker image URI -- ``role:`` An IAM role name or Arn for SageMaker to access AWS - resources on your behalf. -- ``predictor_cls:`` A function to - call to create a predictor. If not None, ``deploy`` will return the - result of invoking this function on the created endpoint name -- ``env:`` Environment variables to run with - ``image`` when hosted in SageMaker. -- ``name:`` The model name. If None, a default model name will be - selected on each ``deploy.`` -- ``entry_point:`` Path (absolute or relative) to the Python file - which should be executed as the entry point to model hosting. -- ``source_dir:`` Optional. Path (absolute or relative) to a - directory with any other training source code dependencies including - the entry point file. Structure within this directory will be - preserved when training on SageMaker. -- ``enable_cloudwatch_metrics:`` Optional. If true, training - and hosting containers will generate Cloudwatch metrics under the - AWS/SageMakerContainer namespace. -- ``container_log_level:`` Log level to use within the container. - Valid values are defined in the Python logging module. -- ``code_location:`` Optional. Name of the S3 bucket where your - custom code will be uploaded to. If not specified, will use the - SageMaker default bucket created by sagemaker.Session. -- ``sagemaker_session:`` The SageMaker Session - object, used for SageMaker interaction - -Your model data must be a .tar.gz file in S3. SageMaker Training Job model data is saved to .tar.gz files in S3, -however if you have local data you want to deploy, you can prepare the data yourself. 
- -Assuming you have a local directory containg your model data named "my_model" you can tar and gzip compress the file and -upload to S3 using the following commands: - -:: - - tar -czf model.tar.gz my_model - aws s3 cp model.tar.gz s3://my-bucket/my-path/model.tar.gz - -This uploads the contents of my_model to a gzip compressed tar file to S3 in the bucket "my-bucket", with the key -"my-path/model.tar.gz". - -To run this command, you'll need the AWS CLI tool installed. Please refer to our `FAQ`_ for more information on -installing this. - -.. _FAQ: ../../../README.rst#faq - ************************* PyTorch Training Examples ************************* diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 292bb40cbe..ab10e5f004 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -1481,12 +1481,14 @@ def __init__( >>> |----- test.py You can assign entry_point='src/train.py'. - source_dir (str): Path (absolute, relative, or an S3 URI) to a directory with - any other training source code dependencies aside from the entry - point file (default: None). Structure within this directory are - preserved when training on Amazon SageMaker. If 'git_config' is - provided, 'source_dir' should be a relative location to a - directory in the Git repo. .. admonition:: Example + source_dir (str): Path (absolute, relative or an S3 URI) to a directory + with any other training source code dependencies aside from the entry + point file (default: None). If ``source_dir`` is an S3 URI, it must + point to a tar.gz file. Structure within this directory are preserved + when training on Amazon SageMaker. If 'git_config' is provided, + 'source_dir' should be a relative location to a directory in the Git + repo. + .. 
admonition:: Example With the following GitHub repo directory structure: diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index d10396769e..cde228331c 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -663,13 +663,14 @@ def __init__( >>> |----- test.py You can assign entry_point='src/inference.py'. - source_dir (str): Path (absolute or relative) to a directory with - any other training source code dependencies aside from the entry - point file (default: None). Structure within this directory will - be preserved when training on SageMaker. If 'git_config' is - provided, 'source_dir' should be a relative location to a - directory in the Git repo. If the directory points to S3, no - code will be uploaded and the S3 location will be used instead. + source_dir (str): Path (absolute, relative or an S3 URI) to a directory + with any other training source code dependencies aside from the entry + point file (default: None). If ``source_dir`` is an S3 URI, it must + point to a tar.gz file. Structure within this directory are preserved + when training on Amazon SageMaker. If 'git_config' is provided, + 'source_dir' should be a relative location to a directory in the Git repo. + If the directory points to S3, no code will be uploaded and the S3 location + will be used instead. .. admonition:: Example With the following GitHub repo directory structure: diff --git a/src/sagemaker/model_monitor/dataset_format.py b/src/sagemaker/model_monitor/dataset_format.py index 0e400f9841..f4c9c0b967 100644 --- a/src/sagemaker/model_monitor/dataset_format.py +++ b/src/sagemaker/model_monitor/dataset_format.py @@ -58,4 +58,4 @@ def sagemaker_capture_json(): dict: JSON string containing DatasetFormat to be used by DefaultModelMonitor. 
""" - return {"sagemaker_capture_json": {}} + return {"sagemakerCaptureJson": {}} diff --git a/src/sagemaker/mxnet/estimator.py b/src/sagemaker/mxnet/estimator.py index 1ab3fe2e27..2b0956f90c 100644 --- a/src/sagemaker/mxnet/estimator.py +++ b/src/sagemaker/mxnet/estimator.py @@ -72,10 +72,11 @@ def __init__( entry_point (str): Path (absolute or relative) to the Python source file which should be executed as the entry point to training. This should be compatible with either Python 2.7 or Python 3.5. - source_dir (str): Path (absolute or relative) to a directory with - any other training source code dependencies aside from the entry - point file (default: None). Structure within this directory are - preserved when training on Amazon SageMaker. + source_dir (str): Path (absolute, relative or an S3 URI) to a directory + with any other training source code dependencies aside from the entry + point file (default: None). If ``source_dir`` is an S3 URI, it must + point to a tar.gz file. Structure within this directory are preserved + when training on Amazon SageMaker. hyperparameters (dict): Hyperparameters that will be used for training (default: None). The hyperparameters are made accessible as a dict[str, str] to the training code on diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 6d28ae1f4b..040bb51ebd 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -70,7 +70,8 @@ def __init__( output_kms_key (str): The KMS key ID for processing job outputs (default: None). max_runtime_in_seconds (int): Timeout in seconds (default: None). After this amount of time, Amazon SageMaker terminates the job, - regardless of its current status. + regardless of its current status. If `max_runtime_in_seconds` is not + specified, the default value is 24 hours. base_job_name (str): Prefix for processing job name. If not specified, the processor generates a default job name, based on the processing image name and current timestamp. 
@@ -309,7 +310,8 @@ def __init__( output_kms_key (str): The KMS key ID for processing job outputs (default: None). max_runtime_in_seconds (int): Timeout in seconds (default: None). After this amount of time, Amazon SageMaker terminates the job, - regardless of its current status. + regardless of its current status. If `max_runtime_in_seconds` is not + specified, the default value is 24 hours. base_job_name (str): Prefix for processing name. If not specified, the processor generates a default job name, based on the processing image name and current timestamp. diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 02efd3a3be..3c658d471a 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -68,10 +68,11 @@ def __init__( entry_point (str): Path (absolute or relative) to the Python source file which should be executed as the entry point to training. This should be compatible with either Python 2.7 or Python 3.5. - source_dir (str): Path (absolute or relative) to a directory with - any other training source code dependencies aside from the entry - point file (default: None). Structure within this directory are - preserved when training on Amazon SageMaker. + source_dir (str): Path (absolute, relative or an S3 URI) to a directory + with any other training source code dependencies aside from the entry + point file (default: None). If ``source_dir`` is an S3 URI, it must + point to a tar.gz file. Structure within this directory are preserved + when training on Amazon SageMaker. hyperparameters (dict): Hyperparameters that will be used for training (default: None). 
The hyperparameters are made accessible as a dict[str, str] to the training code on diff --git a/src/sagemaker/rl/estimator.py b/src/sagemaker/rl/estimator.py index ab6b956cc7..e66f45d154 100644 --- a/src/sagemaker/rl/estimator.py +++ b/src/sagemaker/rl/estimator.py @@ -109,10 +109,11 @@ def __init__( framework (sagemaker.rl.RLFramework): Framework (MXNet or TensorFlow) you want to be used as a toolkit backed for reinforcement learning training. - source_dir (str): Path (absolute or relative) to a directory with - any other training source code dependencies aside from the entry - point file (default: None). Structure within this directory is - preserved when training on Amazon SageMaker. + source_dir (str): Path (absolute, relative or an S3 URI) to a directory + with any other training source code dependencies aside from the entry + point file (default: None). If ``source_dir`` is an S3 URI, it must + point to a tar.gz file. Structure within this directory are preserved + when training on Amazon SageMaker. hyperparameters (dict): Hyperparameters that will be used for training (default: None). The hyperparameters are made accessible as a dict[str, str] to the training code on diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index f03e40e97b..d922866d2b 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2583,6 +2583,18 @@ def wait_for_tuning_job(self, job, poll=5): self._check_job_status(job, desc, "HyperParameterTuningJobStatus") return desc + def describe_transform_job(self, job_name): + """Calls the DescribeTransformJob API for the given job name + and returns the response. + + Args: + job_name (str): The name of the transform job to describe. + + Returns: + dict: A dictionary response with the transform job description. + """ + return self.sagemaker_client.describe_transform_job(TransformJobName=job_name) + def wait_for_transform_job(self, job, poll=5): """Wait for an Amazon SageMaker transform job to complete. 
diff --git a/src/sagemaker/sklearn/estimator.py b/src/sagemaker/sklearn/estimator.py index 2b38122adc..27e7d27d27 100644 --- a/src/sagemaker/sklearn/estimator.py +++ b/src/sagemaker/sklearn/estimator.py @@ -69,10 +69,11 @@ def __init__( framework_version (str): Scikit-learn version you want to use for executing your model training code. List of supported versions https://github.com/aws/sagemaker-python-sdk#sklearn-sagemaker-estimators - source_dir (str): Path (absolute or relative) to a directory with - any other training source code dependencies aside from the entry - point file (default: None). Structure within this directory are - preserved when training on Amazon SageMaker. + source_dir (str): Path (absolute, relative or an S3 URI) to a directory + with any other training source code dependencies aside from the entry + point file (default: None). If ``source_dir`` is an S3 URI, it must + point to a tar.gz file. Structure within this directory are preserved + when training on Amazon SageMaker. hyperparameters (dict): Hyperparameters that will be used for training (default: None). The hyperparameters are made accessible as a dict[str, str] to the training code on diff --git a/src/sagemaker/tensorflow/estimator.py b/src/sagemaker/tensorflow/estimator.py index 00b779b573..e12633b349 100644 --- a/src/sagemaker/tensorflow/estimator.py +++ b/src/sagemaker/tensorflow/estimator.py @@ -570,11 +570,12 @@ def create_model( should be executed as the entry point to training. If not specified and ``endpoint_type`` is 'tensorflow-serving', no entry point is used. If ``endpoint_type`` is also ``None``, then the training entry point is used. - source_dir (str): Path (absolute or relative) to a directory with any other serving - source code dependencies aside from the entry point file. If not specified and - ``endpoint_type`` is 'tensorflow-serving', no source_dir is used. If - ``endpoint_type`` is also ``None``, then the model source directory from training - is used. 
+ source_dir (str): Path (absolute, relative or an S3 URI) to a directory with any + other serving source code dependencies aside from the entry point file. If + ``source_dir`` is an S3 URI, it must point to a tar.gz file. If not specified + and ``endpoint_type`` is 'tensorflow-serving', no source_dir is used. If + ``endpoint_type`` is also ``None``, then the model source directory from + training is used. dependencies (list[str]): A list of paths to directories (absolute or relative) with any additional libraries that will be exported to the container. If not specified and ``endpoint_type`` is 'tensorflow-serving', ``dependencies`` is diff --git a/src/sagemaker/xgboost/estimator.py b/src/sagemaker/xgboost/estimator.py index ec24d426a2..c8c14cf870 100644 --- a/src/sagemaker/xgboost/estimator.py +++ b/src/sagemaker/xgboost/estimator.py @@ -75,9 +75,11 @@ def __init__( framework_version (str): XGBoost version you want to use for executing your model training code. List of supported versions https://github.com/aws/sagemaker-python-sdk#xgboost-sagemaker-estimators - source_dir (str): Path (absolute or relative) to a directory with any other training - source code dependencies aside from the entry point file (default: None). - Structure within this directory are preserved when training on Amazon SageMaker. + source_dir (str): Path (absolute, relative or an S3 URI) to a directory + with any other training source code dependencies aside from the entry + point file (default: None). If ``source_dir`` is an S3 URI, it must + point to a tar.gz file. Structure within this directory are preserved + when training on Amazon SageMaker. hyperparameters (dict): Hyperparameters that will be used for training (default: None). The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker. 
For convenience, this accepts other types for keys and values, but diff --git a/tests/integ/test_transformer.py b/tests/integ/test_transformer.py index 8eec086a11..19bf38b4b5 100644 --- a/tests/integ/test_transformer.py +++ b/tests/integ/test_transformer.py @@ -97,8 +97,8 @@ def test_transform_mxnet( ): transformer.wait() - job_desc = transformer.sagemaker_session.sagemaker_client.describe_transform_job( - TransformJobName=transformer.latest_transform_job.name + job_desc = transformer.sagemaker_session.describe_transform_job( + job_name=transformer.latest_transform_job.name ) assert kms_key_arn == job_desc["TransformResources"]["VolumeKmsKeyId"] assert output_filter == job_desc["DataProcessing"]["OutputFilter"] @@ -323,8 +323,8 @@ def test_stop_transform_job(mxnet_estimator, mxnet_transform_input, cpu_instance transformer.stop_transform_job() - desc = transformer.latest_transform_job.sagemaker_session.sagemaker_client.describe_transform_job( - TransformJobName=latest_transform_job_name + desc = transformer.latest_transform_job.sagemaker_session.describe_transform_job( + job_name=latest_transform_job_name ) assert desc["TransformJobStatus"] == "Stopped" @@ -393,9 +393,7 @@ def test_transform_tf_kms_network_isolation(sagemaker_session, cpu_instance_type ) assert model_desc["EnableNetworkIsolation"] - job_desc = sagemaker_session.sagemaker_client.describe_transform_job( - TransformJobName=job_name - ) + job_desc = sagemaker_session.describe_transform_job(job_name=job_name) assert job_desc["TransformOutput"]["S3OutputPath"] == output_path assert job_desc["TransformOutput"]["KmsKeyId"] == kms_key assert job_desc["TransformResources"]["VolumeKmsKeyId"] == kms_key diff --git a/tests/unit/test_amazon_estimator.py b/tests/unit/test_amazon_estimator.py index 439e51067d..1c0a2fbd10 100644 --- a/tests/unit/test_amazon_estimator.py +++ b/tests/unit/test_amazon_estimator.py @@ -228,16 +228,16 @@ def test_fit_ndarray(time, sagemaker_session): labels = [99, 85, 87, 2] 
pca.fit(pca.record_set(np.array(train), np.array(labels))) mock_s3.Object.assert_any_call( - BUCKET_NAME, "key-prefix/PCA-2017-11-06-14:14:15.671/matrix_0.pbr".format(TIMESTAMP) + BUCKET_NAME, "key-prefix/PCA-2017-11-06-14:14:15.671/matrix_0.pbr" ) mock_s3.Object.assert_any_call( - BUCKET_NAME, "key-prefix/PCA-2017-11-06-14:14:15.671/matrix_1.pbr".format(TIMESTAMP) + BUCKET_NAME, "key-prefix/PCA-2017-11-06-14:14:15.671/matrix_1.pbr" ) mock_s3.Object.assert_any_call( - BUCKET_NAME, "key-prefix/PCA-2017-11-06-14:14:15.671/matrix_2.pbr".format(TIMESTAMP) + BUCKET_NAME, "key-prefix/PCA-2017-11-06-14:14:15.671/matrix_2.pbr" ) mock_s3.Object.assert_any_call( - BUCKET_NAME, "key-prefix/PCA-2017-11-06-14:14:15.671/.amazon.manifest".format(TIMESTAMP) + BUCKET_NAME, "key-prefix/PCA-2017-11-06-14:14:15.671/.amazon.manifest" ) assert mock_object.put.call_count == 4 diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index 9eee574cf2..1c6388c38b 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -735,7 +735,7 @@ def test_invalid_framework_accelerator(): def test_invalid_framework_accelerator_with_neo(): - error_message = "Neo does not support Amazon Elastic Inference.".format(MOCK_FRAMEWORK) + error_message = "Neo does not support Amazon Elastic Inference." 
# accelerator was chosen for unsupported framework with pytest.raises(ValueError) as error: fw_utils.create_image_uri( From 187d4f03612f43638c299959a9dd4db0fbb5e039 Mon Sep 17 00:00:00 2001 From: Chuyang Deng Date: Tue, 19 May 2020 16:31:40 -0700 Subject: [PATCH 10/14] address comments --- src/sagemaker/__init__.py | 12 +++++++----- src/sagemaker/estimator.py | 9 +++------ src/sagemaker/fw_utils.py | 4 ++-- src/sagemaker/model.py | 4 +--- src/sagemaker/session.py | 6 ++---- tests/unit/test_mxnet.py | 5 +---- 6 files changed, 16 insertions(+), 24 deletions(-) diff --git a/src/sagemaker/__init__.py b/src/sagemaker/__init__.py index 8714690762..5c8736f0d9 100644 --- a/src/sagemaker/__init__.py +++ b/src/sagemaker/__init__.py @@ -15,6 +15,7 @@ import logging import importlib_metadata +import sys from sagemaker import estimator, parameter, tuner # noqa: F401 from sagemaker.amazon.kmeans import KMeans, KMeansModel, KMeansPredictor # noqa: F401 @@ -63,8 +64,9 @@ __version__ = importlib_metadata.version("sagemaker") -logging.getLogger("sagemaker").warning( - "SageMaker Python SDK v2 will no longer support Python 2. " - "Please see https://github.com/aws/sagemaker-python-sdk/issues/1459 " - "for more information" -) +if sys.version[0] == "2": + logging.getLogger("sagemaker").warning( + "SageMaker Python SDK v2 will no longer support Python 2. 
" + "Please see https://github.com/aws/sagemaker-python-sdk/issues/1459 " + "for more information" + ) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index ab10e5f004..a929563b8b 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -39,6 +39,7 @@ UploadedCode, validate_source_dir, _region_supports_debugger, + parameter_v2_rename_warning, ) from sagemaker.job import _Job from sagemaker.local import LocalSession @@ -1273,9 +1274,7 @@ def __init__( https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries (default: ``None``). """ - warnings.warn( - "Parameter 'image_name' will be renamed to 'image_uri' in SageMaker Python SDK v2." - ) + logging.warning(parameter_v2_rename_warning("image_name", "image_uri")) self.image_name = image_name self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {} super(Estimator, self).__init__( @@ -1641,9 +1640,7 @@ def __init__( self.code_location = code_location self.image_name = image_name if image_name is not None: - warnings.warn( - "Parameter 'image_name' will be renamed to 'image_uri' in SageMaker Python SDK v2." 
- ) + logging.warning(parameter_v2_rename_warning("image_name", "image_uri")) self.uploaded_code = None diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 855fde8fb8..9141ae8c72 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -662,8 +662,8 @@ def python_deprecation_warning(framework, latest_supported_version): def parameter_v2_rename_warning(v1_parameter_name, v2_parameter_name): """ Args: - v1_parameter_name: - v2_parameter_name: + v1_parameter_name: parameter name used in SageMaker Python SDK v1 + v2_parameter_name: parameter name used in SageMaker Python SDK v2 """ return PARAMETER_V2_RENAME_WARNING.format( v1_parameter_name=v1_parameter_name, v2_parameter_name=v2_parameter_name diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index cde228331c..e049b6d1f5 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -108,9 +108,7 @@ def __init__( model_kms_key (str): KMS key ARN used to encrypt the repacked model archive file if the model is repacked """ - LOGGER.warning( - "Parameter 'image' will be renamed to 'image_uri' in SageMaker Python SDK v2." - ) + LOGGER.warning(fw_utils.parameter_v2_rename_warning("image", "image_uri")) self.model_data = model_data self.image = image diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index d922866d2b..c58801a886 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -183,6 +183,7 @@ def upload_data(self, path, bucket=None, key_prefix="data", extra_args=None): # Generate a tuple for each file that we want to upload of the form (local_path, s3_key). LOGGER.warning( "'upload_data' method will be deprecated in favor of 'S3Uploader' class " + "(https://sagemaker.readthedocs.io/en/stable/s3.html#sagemaker.s3.S3Uploader) " "in SageMaker Python SDK v2." 
) @@ -237,6 +238,7 @@ def upload_string_as_file_body(self, body, bucket, key, kms_key=None): """ LOGGER.warning( "'upload_string_as_file_body' method will be deprecated in favor of 'S3Uploader' class " + "(https://sagemaker.readthedocs.io/en/stable/s3.html#sagemaker.s3.S3Uploader) " "in SageMaker Python SDK v2." ) @@ -3333,10 +3335,6 @@ def get_execution_role(sagemaker_session=None): Returns: (str): The role ARN """ - LOGGER.warning( - "'get_execution_role' will be renamed to 'notebook_execution_role' " - "in SageMaker Python SDK v2." - ) if not sagemaker_session: sagemaker_session = Session() diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index 84f8136389..e073462a3f 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -846,7 +846,7 @@ def test_mx_enable_sm_metrics_if_fw_ver_is_at_least_1_6(sagemaker_session): assert mx.enable_sagemaker_metrics -def test_custom_image_estimator_deploy(sagemaker_session, caplog): +def test_custom_image_estimator_deploy(sagemaker_session): custom_image = "mycustomimage:latest" mx = MXNet( entry_point=SCRIPT_PATH, @@ -858,6 +858,3 @@ def test_custom_image_estimator_deploy(sagemaker_session, caplog): mx.fit(inputs="s3://mybucket/train", job_name="new_name") model = mx.create_model(image=custom_image) assert model.image == custom_image - - warning_message = "Parameter 'image' will be renamed to 'image_uri' in SageMaker Python SDK v2." 
- assert warning_message in caplog.text From 6cf532242fd1481039de4f31eabbfa015c563580 Mon Sep 17 00:00:00 2001 From: Chuyang Deng Date: Tue, 19 May 2020 18:21:23 -0700 Subject: [PATCH 11/14] fix pylint error --- src/sagemaker/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/__init__.py b/src/sagemaker/__init__.py index 5c8736f0d9..122e10e927 100644 --- a/src/sagemaker/__init__.py +++ b/src/sagemaker/__init__.py @@ -14,8 +14,8 @@ from __future__ import absolute_import import logging -import importlib_metadata import sys +import importlib_metadata from sagemaker import estimator, parameter, tuner # noqa: F401 from sagemaker.amazon.kmeans import KMeans, KMeansModel, KMeansPredictor # noqa: F401 From bba35541f4a279dd1df7eba7bbfce5653db62bd3 Mon Sep 17 00:00:00 2001 From: Chuyang Deng Date: Wed, 20 May 2020 11:37:45 -0700 Subject: [PATCH 12/14] remove eval_metrics warning --- doc/using_tf.rst | 33 +++++++++++++++++++-- src/sagemaker/amazon/kmeans.py | 11 ------- src/sagemaker/amazon/randomcutforest.py | 11 ------- src/sagemaker/automl/automl.py | 2 +- src/sagemaker/s3.py | 4 ++- tests/unit/sagemaker/automl/test_auto_ml.py | 2 +- tests/unit/test_kmeans.py | 8 +---- tests/unit/test_randomcutforest.py | 7 +---- 8 files changed, 37 insertions(+), 41 deletions(-) diff --git a/doc/using_tf.rst b/doc/using_tf.rst index 3a8fcd5186..a222fb1160 100644 --- a/doc/using_tf.rst +++ b/doc/using_tf.rst @@ -133,6 +133,34 @@ In your training script the channels will be stored in environment variables ``S ``output_path``. +Use third-party libraries +------------------------- + +If there are other packages you want to use with your script, you can use a ``requirements.txt`` to install other dependencies at runtime. 
+ +For training, support for installing packages using ``requirements.txt`` varies by TensorFlow version as follows: + +- For TensorFlow 1.11 or newer using Script Mode without Horovod, TensorFlow 1.15.2 with Python 3.7 or newer, and TensorFlow 2.2 or newer: + - Include a ``requirements.txt`` file in the same directory as your training script. + - You must specify this directory using the ``source_dir`` argument when creating a TensorFlow estimator. +- For older versions of TensorFlow using Script Mode with Horovod: + - Write a shell script for your entry point that first calls ``pip install -r requirements.txt``, then runs your training script. + - For an example of using shell scripts, see `this example notebook <https://github.com/awslabs/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/tensorflow_script_mode_using_shell_commands/tensorflow_script_mode_using_shell_commands.ipynb>`__. +- For older versions of TensorFlow using Legacy Mode: + - Specify the path to your ``requirements.txt`` file using the ``requirements_file`` argument. + +For serving, support for installing packages using ``requirements.txt`` varies by TensorFlow version as follows: + +- For TensorFlow 1.11 or newer: + - Include a ``requirements.txt`` file in the ``code`` directory. +- For older versions of TensorFlow: + - Specify the path to your ``requirements.txt`` file using the ``SAGEMAKER_REQUIREMENTS`` environment variable. + +A ``requirements.txt`` file is a text file that contains a list of items that are installed by using ``pip install``. +You can also specify the version of an item to install. +For information about the format of a ``requirements.txt`` file, see `Requirements Files <https://pip.pypa.io/en/stable/user_guide/#requirements-files>`__ in the pip documentation. + + Create an Estimator =================== @@ -215,7 +243,7 @@ Calling ``fit`` starts a SageMaker training job. The training job will execute t - starts asynchronous training If the ``wait=False`` flag is passed to ``fit``, then it returns immediately. The training job continues running -asynchronously. Later, a Tensorflow estimator can be obtained by attaching to the existing training job. +asynchronously. 
Later, a TensorFlow estimator can be obtained by attaching to the existing training job. If the training job is not finished, it starts showing the standard output of training and wait until it completes. After attaching, the estimator can be deployed as usual. @@ -882,8 +910,7 @@ in the following code: You can also bring in external dependencies to help with your data processing. There are 2 ways to do this: -1. If you included ``requirements.txt`` in your ``source_dir`` or in - your dependencies, the container installs the Python dependencies at runtime using ``pip install -r``: +1. If your model archive contains ``code/requirements.txt``, the container will install the Python dependencies at runtime using ``pip install -r``. .. code:: diff --git a/src/sagemaker/amazon/kmeans.py b/src/sagemaker/amazon/kmeans.py index 35fa236378..d6b4ddda20 100644 --- a/src/sagemaker/amazon/kmeans.py +++ b/src/sagemaker/amazon/kmeans.py @@ -13,8 +13,6 @@ """Placeholder docstring""" from __future__ import absolute_import -import logging - from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa @@ -25,9 +23,6 @@ from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT -logger = logging.getLogger("sagemaker") - - class KMeans(AmazonAlgorithmEstimatorBase): """Placeholder docstring""" @@ -159,12 +154,6 @@ def __init__( self.center_factor = center_factor self.eval_metrics = eval_metrics - if eval_metrics is not None: - logger.warning( - "Parameter 'eval_metrics' hyperparameter will be deprecated for 1P estimators " - "in SageMaker Python SDK v2." - ) - def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs): """Return a :class:`~sagemaker.amazon.kmeans.KMeansModel` referencing the latest s3 model data produced by this Estimator. 
diff --git a/src/sagemaker/amazon/randomcutforest.py b/src/sagemaker/amazon/randomcutforest.py index f6a4e3c5c2..8e188c95ae 100644 --- a/src/sagemaker/amazon/randomcutforest.py +++ b/src/sagemaker/amazon/randomcutforest.py @@ -13,8 +13,6 @@ """Placeholder docstring""" from __future__ import absolute_import -import logging - from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa @@ -25,9 +23,6 @@ from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT -logger = logging.getLogger("sagemaker") - - class RandomCutForest(AmazonAlgorithmEstimatorBase): """Placeholder docstring""" @@ -124,12 +119,6 @@ def __init__( self.num_trees = num_trees self.eval_metrics = eval_metrics - if eval_metrics is not None: - logger.warning( - "Parameter 'eval_metrics' hyperparameter will be deprecated for 1P estimators " - "in SageMaker Python SDK v2." - ) - def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs): """Return a :class:`~sagemaker.amazon.RandomCutForestModel` referencing the latest s3 model data produced by this Estimator. diff --git a/src/sagemaker/automl/automl.py b/src/sagemaker/automl/automl.py index ed9805f75e..794e03aee1 100644 --- a/src/sagemaker/automl/automl.py +++ b/src/sagemaker/automl/automl.py @@ -39,7 +39,7 @@ def __init__( encrypt_inter_container_traffic=False, vpc_config=None, problem_type=None, - max_candidates=500, + max_candidates=None, max_runtime_per_training_job_in_seconds=None, total_job_runtime_in_seconds=None, job_objective=None, diff --git a/src/sagemaker/s3.py b/src/sagemaker/s3.py index 88b4f1a410..9d63b3c793 100644 --- a/src/sagemaker/s3.py +++ b/src/sagemaker/s3.py @@ -150,7 +150,9 @@ def read_file(s3_uri, session=None): str: The body of the file. 
""" - _session_v2_rename_warning(session) + if session is not None: + _session_v2_rename_warning(session) + sagemaker_session = session or Session() bucket, key_prefix = parse_s3_url(url=s3_uri) diff --git a/tests/unit/sagemaker/automl/test_auto_ml.py b/tests/unit/sagemaker/automl/test_auto_ml.py index c6b240e2b2..8ef0cd31da 100644 --- a/tests/unit/sagemaker/automl/test_auto_ml.py +++ b/tests/unit/sagemaker/automl/test_auto_ml.py @@ -32,7 +32,7 @@ DEFAULT_S3_INPUT_DATA = "s3://{}/data".format(BUCKET_NAME) DEFAULT_OUTPUT_PATH = "s3://{}/".format(BUCKET_NAME) LOCAL_DATA_PATH = "file://data" -DEFAULT_MAX_CANDIDATES = 500 +DEFAULT_MAX_CANDIDATES = None DEFAULT_JOB_NAME = "automl-{}".format(TIMESTAMP) JOB_NAME = "default-job-name" diff --git a/tests/unit/test_kmeans.py b/tests/unit/test_kmeans.py index 0013e4147d..555b78b451 100644 --- a/tests/unit/test_kmeans.py +++ b/tests/unit/test_kmeans.py @@ -82,7 +82,7 @@ def test_init_required_named(sagemaker_session): assert kmeans.k == ALL_REQ_ARGS["k"] -def test_all_hyperparameters(sagemaker_session, caplog): +def test_all_hyperparameters(sagemaker_session): kmeans = KMeans( sagemaker_session=sagemaker_session, init_method="random", @@ -110,12 +110,6 @@ def test_all_hyperparameters(sagemaker_session, caplog): force_dense="True", ) - warning_message = ( - "Parameter 'eval_metrics' hyperparameter will be deprecated for 1P estimators " - "in SageMaker Python SDK v2." 
- ) - assert warning_message in caplog.text - def test_image(sagemaker_session): kmeans = KMeans(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) diff --git a/tests/unit/test_randomcutforest.py b/tests/unit/test_randomcutforest.py index 9ab9d5f603..d960e45f46 100644 --- a/tests/unit/test_randomcutforest.py +++ b/tests/unit/test_randomcutforest.py @@ -89,7 +89,7 @@ def test_init_required_named(sagemaker_session): assert randomcutforest.train_instance_type == COMMON_TRAIN_ARGS["train_instance_type"] -def test_all_hyperparameters(sagemaker_session, caplog): +def test_all_hyperparameters(sagemaker_session): randomcutforest = RandomCutForest( sagemaker_session=sagemaker_session, num_trees=NUM_TREES, @@ -102,11 +102,6 @@ def test_all_hyperparameters(sagemaker_session, caplog): num_trees=str(NUM_TREES), eval_metrics='["accuracy", "precision_recall_fscore"]', ) - warning_message = ( - "Parameter 'eval_metrics' hyperparameter will be deprecated for 1P estimators " - "in SageMaker Python SDK v2." - ) - assert warning_message in caplog.text def test_image(sagemaker_session): From 4aa98e3258e6098c779dc77016b49ba53b77e7e1 Mon Sep 17 00:00:00 2001 From: Chuyang Deng Date: Wed, 20 May 2020 11:45:09 -0700 Subject: [PATCH 13/14] wrap session warning --- src/sagemaker/s3.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/s3.py b/src/sagemaker/s3.py index 9d63b3c793..14c0ad1451 100644 --- a/src/sagemaker/s3.py +++ b/src/sagemaker/s3.py @@ -70,7 +70,9 @@ def upload(local_path, desired_s3_uri, kms_key=None, session=None): The S3 uri of the uploaded file(s). """ - _session_v2_rename_warning(session) + if session is not None: + _session_v2_rename_warning(session) + sagemaker_session = session or Session() bucket, key_prefix = parse_s3_url(url=desired_s3_uri) if kms_key is not None: @@ -97,7 +99,9 @@ def upload_string_as_file_body(body, desired_s3_uri=None, kms_key=None, session= str: The S3 uri of the uploaded file(s). 
""" - _session_v2_rename_warning(session) + if session is not None: + _session_v2_rename_warning(session) + sagemaker_session = session or Session() bucket, key = parse_s3_url(desired_s3_uri) @@ -125,7 +129,9 @@ def download(s3_uri, local_path, kms_key=None, session=None): using the default AWS configuration chain. """ - _session_v2_rename_warning(session) + if session is not None: + _session_v2_rename_warning(session) + sagemaker_session = session or Session() bucket, key_prefix = parse_s3_url(url=s3_uri) if kms_key is not None: @@ -171,7 +177,9 @@ def list(s3_uri, session=None): [str]: The list of S3 URIs in the given S3 base uri. """ - _session_v2_rename_warning(session) + if session is not None: + _session_v2_rename_warning(session) + sagemaker_session = session or Session() bucket, key_prefix = parse_s3_url(url=s3_uri) From 17e7540102943d2e1b683b63126a46062eb6ea8c Mon Sep 17 00:00:00 2001 From: Chuyang Deng Date: Wed, 20 May 2020 12:14:50 -0700 Subject: [PATCH 14/14] fix black error --- src/sagemaker/s3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/s3.py b/src/sagemaker/s3.py index 14c0ad1451..316b1e002f 100644 --- a/src/sagemaker/s3.py +++ b/src/sagemaker/s3.py @@ -179,7 +179,7 @@ def list(s3_uri, session=None): """ if session is not None: _session_v2_rename_warning(session) - + sagemaker_session = session or Session() bucket, key_prefix = parse_s3_url(url=s3_uri)