Skip to content

change: handle "train_*" renames in v2 migration tool #1684

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 9, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/sagemaker/cli/compatibility/v2/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
modifiers.airflow.ModelConfigImageURIRenamer(),
modifiers.renamed_params.DistributionParameterRenamer(),
modifiers.renamed_params.S3SessionRenamer(),
modifiers.renamed_params.EstimatorCreateModelImageURIRenamer(),
modifiers.renamed_params.SessionCreateModelImageURIRenamer(),
modifiers.renamed_params.SessionCreateEndpointImageURIRenamer(),
modifiers.training_params.TrainPrefixRemover(),
]

IMPORT_MODIFIERS = [modifiers.tfs.TensorFlowServingImportRenamer()]
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/cli/compatibility/v2/modifiers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@
renamed_params,
tf_legacy_mode,
tfs,
training_params,
)
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Classes to modify Predictor code to be compatible
with version 2.0 and later of the SageMaker Python SDK.
"""
"""Classes to handle renames for version 2.0 and later of the SageMaker Python SDK."""
from __future__ import absolute_import

import ast
Expand Down
97 changes: 97 additions & 0 deletions src/sagemaker/cli/compatibility/v2/modifiers/training_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Classes to handle training renames for version 2.0 and later of the SageMaker Python SDK."""
from __future__ import absolute_import

from sagemaker.cli.compatibility.v2.modifiers import matching
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier

# Estimator class names mapped to the namespaces under which each class may be
# referenced. ``TrainPrefixRemover`` hands this table to ``matching.matches_any``
# to recognize estimator constructor calls.
ESTIMATORS = {
    "AlgorithmEstimator": ("sagemaker", "sagemaker.algorithm"),
    "AmazonAlgorithmEstimatorBase": ("sagemaker.amazon.amazon_estimator",),
    "Chainer": ("sagemaker.chainer", "sagemaker.chainer.estimator"),
    "Estimator": ("sagemaker.estimator",),
    "EstimatorBase": ("sagemaker.estimator",),
    "FactorizationMachines": ("sagemaker", "sagemaker.amazon.factorization_machines"),
    "Framework": ("sagemaker.estimator",),
    "IPInsights": ("sagemaker", "sagemaker.amazon.ipinsights"),
    "KMeans": ("sagemaker", "sagemaker.amazon.kmeans"),
    "KNN": ("sagemaker", "sagemaker.amazon.knn"),
    "LDA": ("sagemaker", "sagemaker.amazon.lda"),
    "LinearLearner": ("sagemaker", "sagemaker.amazon.linear_learner"),
    "MXNet": ("sagemaker.mxnet", "sagemaker.mxnet.estimator"),
    "NTM": ("sagemaker", "sagemaker.amazon.ntm"),
    "Object2Vec": ("sagemaker", "sagemaker.amazon.object2vec"),
    "PCA": ("sagemaker", "sagemaker.amazon.pca"),
    "PyTorch": ("sagemaker.pytorch", "sagemaker.pytorch.estimator"),
    "RandomCutForest": ("sagemaker", "sagemaker.amazon.randomcutforest"),
    "RLEstimator": ("sagemaker.rl", "sagemaker.rl.estimator"),
    "SKLearn": ("sagemaker.sklearn", "sagemaker.sklearn.estimator"),
    "TensorFlow": ("sagemaker.tensorflow", "sagemaker.tensorflow.estimator"),
    "XGBoost": ("sagemaker.xgboost", "sagemaker.xgboost.estimator"),
}

# Estimator constructor keywords whose ``train_`` prefix is removed by the
# migration tool.
#
# ``train_max_wait`` is the keyword the v1 estimators accept for the
# spot-instance wait timeout (renamed to ``max_wait`` in v2), so it must be
# listed for such calls to be migrated. ``train_max_run_wait`` is retained so
# that anything already relying on it keeps being renamed exactly as before.
PARAMS = (
    "train_instance_count",
    "train_instance_type",
    "train_max_run",
    "train_max_run_wait",
    "train_max_wait",
    "train_use_spot_instances",
    "train_volume_size",
    "train_volume_kms_key",
)


class TrainPrefixRemover(Modifier):
    """Modifier that strips the redundant 'train' prefix from estimator parameters."""

    def node_should_be_modified(self, node):
        """Decides whether ``node`` is an estimator constructor call needing a rename.

        A call qualifies when it matches one of the entries in ``ESTIMATORS``
        and passes at least one keyword argument named in ``PARAMS``
        (``train_instance_count``, ``train_instance_type``, and so on).

        Args:
            node (ast.Call): a node that represents a function call. For more,
                see https://docs.python.org/3/library/ast.html#abstract-grammar.

        Returns:
            bool: If the ``ast.Call`` matches a known estimator and carries a
                parameter to be renamed.
        """
        is_estimator_call = matching.matches_any(node, ESTIMATORS)
        # The keywords are inspected only once the call is known to be an estimator.
        return is_estimator_call and self._has_train_parameter(node)

    def _has_train_parameter(self, node):
        """Reports whether any keyword of ``node`` is one of the names in ``PARAMS``."""
        return any(keyword.arg in PARAMS for keyword in node.keywords)

    def modify_node(self, node):
        """Renames, in place, every ``PARAMS`` keyword of ``node`` without its 'train' prefix.

        Keywords that are not listed in ``PARAMS`` are left untouched.

        Args:
            node (ast.Call): a node that represents an estimator constructor.
        """
        for keyword in node.keywords:
            if keyword.arg not in PARAMS:
                continue
            keyword.arg = keyword.arg.replace("train_", "")
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import itertools

import pasta

from sagemaker.cli.compatibility.v2.modifiers import training_params
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call

# Expected estimator class names and the namespaces each may be qualified with.
# Kept as an explicit literal so the tests do not derive their expectations
# from the module under test.
ESTIMATORS_TO_NAMESPACES = {
    "AlgorithmEstimator": ("sagemaker", "sagemaker.algorithm"),
    "AmazonAlgorithmEstimatorBase": ("sagemaker.amazon.amazon_estimator",),
    "Chainer": ("sagemaker.chainer", "sagemaker.chainer.estimator"),
    "Estimator": ("sagemaker.estimator",),
    "EstimatorBase": ("sagemaker.estimator",),
    "FactorizationMachines": ("sagemaker", "sagemaker.amazon.factorization_machines"),
    "Framework": ("sagemaker.estimator",),
    "IPInsights": ("sagemaker", "sagemaker.amazon.ipinsights"),
    "KMeans": ("sagemaker", "sagemaker.amazon.kmeans"),
    "KNN": ("sagemaker", "sagemaker.amazon.knn"),
    "LDA": ("sagemaker", "sagemaker.amazon.lda"),
    "LinearLearner": ("sagemaker", "sagemaker.amazon.linear_learner"),
    "MXNet": ("sagemaker.mxnet", "sagemaker.mxnet.estimator"),
    "NTM": ("sagemaker", "sagemaker.amazon.ntm"),
    "Object2Vec": ("sagemaker", "sagemaker.amazon.object2vec"),
    "PCA": ("sagemaker", "sagemaker.amazon.pca"),
    "PyTorch": ("sagemaker.pytorch", "sagemaker.pytorch.estimator"),
    "RandomCutForest": ("sagemaker", "sagemaker.amazon.randomcutforest"),
    "RLEstimator": ("sagemaker.rl", "sagemaker.rl.estimator"),
    "SKLearn": ("sagemaker.sklearn", "sagemaker.sklearn.estimator"),
    "TensorFlow": ("sagemaker.tensorflow", "sagemaker.tensorflow.estimator"),
    "XGBoost": ("sagemaker.xgboost", "sagemaker.xgboost.estimator"),
}

# Source snippets of ``train_*`` keyword arguments, each with a sample value,
# that the tests splice into estimator constructor calls.
PARAMS_WITH_VALUES = (
    "train_instance_count=1",
    "train_instance_type='ml.c4.xlarge'",
    "train_max_run=8 * 60 * 60",
    "train_max_run_wait=1 * 60 * 60",
    "train_use_spot_instances=True",
    "train_volume_size=30",
    "train_volume_kms_key='key'",
)


def _estimators():
    """Yields each estimator's bare class name, then each namespace-qualified form of it."""
    for class_name, modules in ESTIMATORS_TO_NAMESPACES.items():
        yield class_name

        for module in modules:
            yield "{}.{}".format(module, class_name)


def test_node_should_be_modified():
    """Every estimator reference combined with any single ``train_*`` keyword is flagged."""
    remover = training_params.TrainPrefixRemover()

    for estimator, param in itertools.product(_estimators(), PARAMS_WITH_VALUES):
        source = "{}({})".format(estimator, param)
        assert remover.node_should_be_modified(ast_call(source))


def test_node_should_be_modified_no_params():
    """An estimator constructor called without arguments is not flagged."""
    remover = training_params.TrainPrefixRemover()

    for estimator in _estimators():
        bare_call = ast_call(estimator + "()")
        assert not remover.node_should_be_modified(bare_call)


def test_node_should_be_modified_random_function_call():
    """A call that is not an estimator constructor is not flagged."""
    remover = training_params.TrainPrefixRemover()
    session_call = ast_call("Session()")
    assert not remover.node_should_be_modified(session_call)


def test_modify_node():
    """For every non-empty combination of ``train_*`` keywords, each one loses its prefix."""
    remover = training_params.TrainPrefixRemover()

    for arguments in _parameter_combinations():
        original_source = "Estimator({})".format(arguments)
        call = ast_call(original_source)
        remover.modify_node(call)

        # The expected output is the input source with every prefix dropped.
        assert original_source.replace("train_", "") == pasta.dump(call)


def _parameter_combinations():
    """Yields every non-empty combination of ``PARAMS_WITH_VALUES`` as a comma-separated string.

    Combinations are produced shortest first, in ``itertools.combinations`` order.
    """
    sizes = range(1, len(PARAMS_WITH_VALUES) + 1)
    subsets = itertools.chain.from_iterable(
        itertools.combinations(PARAMS_WITH_VALUES, size) for size in sizes
    )
    for subset in subsets:
        yield ", ".join(subset)