diff --git a/src/sagemaker/cli/compatibility/v2/ast_transformer.py b/src/sagemaker/cli/compatibility/v2/ast_transformer.py index 199efe1a3e..f9e97a0635 100644 --- a/src/sagemaker/cli/compatibility/v2/ast_transformer.py +++ b/src/sagemaker/cli/compatibility/v2/ast_transformer.py @@ -30,6 +30,10 @@ modifiers.airflow.ModelConfigImageURIRenamer(), modifiers.renamed_params.DistributionParameterRenamer(), modifiers.renamed_params.S3SessionRenamer(), + modifiers.renamed_params.EstimatorCreateModelImageURIRenamer(), + modifiers.renamed_params.SessionCreateModelImageURIRenamer(), + modifiers.renamed_params.SessionCreateEndpointImageURIRenamer(), + modifiers.training_params.TrainPrefixRemover(), ] IMPORT_MODIFIERS = [modifiers.tfs.TensorFlowServingImportRenamer()] diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/__init__.py b/src/sagemaker/cli/compatibility/v2/modifiers/__init__.py index 519b8edf4c..da2f0460b2 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/__init__.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/__init__.py @@ -21,4 +21,5 @@ renamed_params, tf_legacy_mode, tfs, + training_params, ) diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py b/src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py index 7c978dd135..7570740d9f 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py @@ -10,9 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""Classes to modify Predictor code to be compatible -with version 2.0 and later of the SageMaker Python SDK. -""" +"""Classes to handle renames for version 2.0 and later of the SageMaker Python SDK.""" from __future__ import absolute_import import ast diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/training_params.py b/src/sagemaker/cli/compatibility/v2/modifiers/training_params.py new file mode 100644 index 0000000000..3368e694d4 --- /dev/null +++ b/src/sagemaker/cli/compatibility/v2/modifiers/training_params.py @@ -0,0 +1,97 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Classes to handle training renames for version 2.0 and later of the SageMaker Python SDK.""" +from __future__ import absolute_import + +from sagemaker.cli.compatibility.v2.modifiers import matching +from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier + +ESTIMATORS = { + "AlgorithmEstimator": ("sagemaker", "sagemaker.algorithm"), + "AmazonAlgorithmEstimatorBase": ("sagemaker.amazon.amazon_estimator",), + "Chainer": ("sagemaker.chainer", "sagemaker.chainer.estimator"), + "Estimator": ("sagemaker.estimator",), + "EstimatorBase": ("sagemaker.estimator",), + "FactorizationMachines": ("sagemaker", "sagemaker.amazon.factorization_machines"), + "Framework": ("sagemaker.estimator",), + "IPInsights": ("sagemaker", "sagemaker.amazon.ipinsights"), + "KMeans": ("sagemaker", "sagemaker.amazon.kmeans"), + "KNN": ("sagemaker", "sagemaker.amazon.knn"), + "LDA": ("sagemaker", "sagemaker.amazon.lda"), + "LinearLearner": ("sagemaker", "sagemaker.amazon.linear_learner"), + "MXNet": ("sagemaker.mxnet", "sagemaker.mxnet.estimator"), + "NTM": ("sagemaker", "sagemaker.amazon.ntm"), + "Object2Vec": ("sagemaker", "sagemaker.amazon.object2vec"), + "PCA": ("sagemaker", "sagemaker.amazon.pca"), + "PyTorch": ("sagemaker.pytorch", "sagemaker.pytorch.estimator"), + "RandomCutForest": ("sagemaker", "sagemaker.amazon.randomcutforest"), + "RLEstimator": ("sagemaker.rl", "sagemaker.rl.estimator"), + "SKLearn": ("sagemaker.sklearn", "sagemaker.sklearn.estimator"), + "TensorFlow": ("sagemaker.tensorflow", "sagemaker.tensorflow.estimator"), + "XGBoost": ("sagemaker.xgboost", "sagemaker.xgboost.estimator"), +} + +PARAMS = ( + "train_instance_count", + "train_instance_type", + "train_max_run", + "train_max_run_wait", + "train_use_spot_instances", + "train_volume_size", + "train_volume_kms_key", +) + + +class TrainPrefixRemover(Modifier): + """A class to remove the redundant 'train' prefix in estimator parameters.""" + + def node_should_be_modified(self, node): + """Checks if the node is an estimator constructor and contains any relevant parameters. + + This looks for the following parameters: + + - ``train_instance_count`` + - ``train_instance_type`` + - ``train_max_run`` + - ``train_max_run_wait`` + - ``train_use_spot_instances`` + - ``train_volume_kms_key`` + - ``train_volume_size`` + + Args: + node (ast.Call): a node that represents a function call. For more, + see https://docs.python.org/3/library/ast.html#abstract-grammar. + + Returns: + bool: If the ``ast.Call`` matches the relevant function calls and + contains the parameter to be renamed. + """ + return matching.matches_any(node, ESTIMATORS) and self._has_train_parameter(node) + + def _has_train_parameter(self, node): + """Checks if at least one of the node's keywords is prefixed with 'train'.""" + for kw in node.keywords: + if kw.arg in PARAMS: + return True + + return False + + def modify_node(self, node): + """Modifies the ``ast.Call`` node to remove the 'train' prefix from its keywords. + + Args: + node (ast.Call): a node that represents an estimator constructor. + """ + for kw in node.keywords: + if kw.arg in PARAMS: + kw.arg = kw.arg.replace("train_", "") diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_training_params.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_training_params.py new file mode 100644 index 0000000000..21db661684 --- /dev/null +++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_training_params.py @@ -0,0 +1,102 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import itertools + +import pasta + +from sagemaker.cli.compatibility.v2.modifiers import training_params +from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call + +ESTIMATORS_TO_NAMESPACES = { + "AlgorithmEstimator": ("sagemaker", "sagemaker.algorithm"), + "AmazonAlgorithmEstimatorBase": ("sagemaker.amazon.amazon_estimator",), + "Chainer": ("sagemaker.chainer", "sagemaker.chainer.estimator"), + "Estimator": ("sagemaker.estimator",), + "EstimatorBase": ("sagemaker.estimator",), + "FactorizationMachines": ("sagemaker", "sagemaker.amazon.factorization_machines"), + "Framework": ("sagemaker.estimator",), + "IPInsights": ("sagemaker", "sagemaker.amazon.ipinsights"), + "KMeans": ("sagemaker", "sagemaker.amazon.kmeans"), + "KNN": ("sagemaker", "sagemaker.amazon.knn"), + "LDA": ("sagemaker", "sagemaker.amazon.lda"), + "LinearLearner": ("sagemaker", "sagemaker.amazon.linear_learner"), + "MXNet": ("sagemaker.mxnet", "sagemaker.mxnet.estimator"), + "NTM": ("sagemaker", "sagemaker.amazon.ntm"), + "Object2Vec": ("sagemaker", "sagemaker.amazon.object2vec"), + "PCA": ("sagemaker", "sagemaker.amazon.pca"), + "PyTorch": ("sagemaker.pytorch", "sagemaker.pytorch.estimator"), + "RandomCutForest": ("sagemaker", "sagemaker.amazon.randomcutforest"), + "RLEstimator": ("sagemaker.rl", "sagemaker.rl.estimator"), + "SKLearn": ("sagemaker.sklearn", "sagemaker.sklearn.estimator"), + "TensorFlow": ("sagemaker.tensorflow", "sagemaker.tensorflow.estimator"), + "XGBoost": ("sagemaker.xgboost", "sagemaker.xgboost.estimator"), +} + +PARAMS_WITH_VALUES = ( + "train_instance_count=1", + "train_instance_type='ml.c4.xlarge'", + "train_max_run=8 * 60 * 60", + "train_max_run_wait=1 * 60 * 60", + "train_use_spot_instances=True", + "train_volume_size=30", + "train_volume_kms_key='key'", +) + + +def _estimators(): + for estimator, namespaces in ESTIMATORS_TO_NAMESPACES.items(): + yield estimator + + for namespace in namespaces: + yield ".".join((namespace, estimator)) + + +def test_node_should_be_modified(): + modifier = training_params.TrainPrefixRemover() + + for estimator in _estimators(): + for param in PARAMS_WITH_VALUES: + call = ast_call("{}({})".format(estimator, param)) + assert modifier.node_should_be_modified(call) + + +def test_node_should_be_modified_no_params(): + modifier = training_params.TrainPrefixRemover() + + for estimator in _estimators(): + call = ast_call("{}()".format(estimator)) + assert not modifier.node_should_be_modified(call) + + +def test_node_should_be_modified_random_function_call(): + modifier = training_params.TrainPrefixRemover() + assert not modifier.node_should_be_modified(ast_call("Session()")) + + +def test_modify_node(): + modifier = training_params.TrainPrefixRemover() + + for params in _parameter_combinations(): + node = ast_call("Estimator({})".format(params)) + modifier.modify_node(node) + + expected = "Estimator({})".format(params).replace("train_", "") + assert expected == pasta.dump(node) + + +def _parameter_combinations(): + for subset_length in range(1, len(PARAMS_WITH_VALUES) + 1): + for subset in itertools.combinations(PARAMS_WITH_VALUES, subset_length): + yield ", ".join(subset)