Skip to content

Commit 4ffa222

Browse files
authored
change: handle "train_*" renames in v2 migration tool (#1684)
1 parent 9fc8a46 commit 4ffa222

File tree

5 files changed

+205
-3
lines changed

5 files changed

+205
-3
lines changed

src/sagemaker/cli/compatibility/v2/ast_transformer.py

+4
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,10 @@
3030
modifiers.airflow.ModelConfigImageURIRenamer(),
3131
modifiers.renamed_params.DistributionParameterRenamer(),
3232
modifiers.renamed_params.S3SessionRenamer(),
33+
modifiers.renamed_params.EstimatorCreateModelImageURIRenamer(),
34+
modifiers.renamed_params.SessionCreateModelImageURIRenamer(),
35+
modifiers.renamed_params.SessionCreateEndpointImageURIRenamer(),
36+
modifiers.training_params.TrainPrefixRemover(),
3337
]
3438

3539
IMPORT_MODIFIERS = [modifiers.tfs.TensorFlowServingImportRenamer()]

src/sagemaker/cli/compatibility/v2/modifiers/__init__.py

+1
Original file line number | Diff line number | Diff line change
@@ -21,4 +21,5 @@
2121
renamed_params,
2222
tf_legacy_mode,
2323
tfs,
24+
training_params,
2425
)

src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py

+1-3
Original file line number | Diff line number | Diff line change
@@ -10,9 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Classes to modify Predictor code to be compatible
14-
with version 2.0 and later of the SageMaker Python SDK.
15-
"""
13+
"""Classes to handle renames for version 2.0 and later of the SageMaker Python SDK."""
1614
from __future__ import absolute_import
1715

1816
import ast
src/sagemaker/cli/compatibility/v2/modifiers/training_params.py

Original file line number | Diff line number | Diff line change
@@ -0,0 +1,97 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Classes to handle training renames for version 2.0 and later of the SageMaker Python SDK."""
14+
from __future__ import absolute_import
15+
16+
from sagemaker.cli.compatibility.v2.modifiers import matching
17+
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
18+
19+
# Maps each estimator class name to the module namespaces it can be referenced
# through. Passed to ``matching.matches_any`` to recognise estimator calls.
ESTIMATORS = {
    "AlgorithmEstimator": ("sagemaker", "sagemaker.algorithm"),
    "AmazonAlgorithmEstimatorBase": ("sagemaker.amazon.amazon_estimator",),
    "Chainer": ("sagemaker.chainer", "sagemaker.chainer.estimator"),
    "Estimator": ("sagemaker.estimator",),
    "EstimatorBase": ("sagemaker.estimator",),
    "FactorizationMachines": ("sagemaker", "sagemaker.amazon.factorization_machines"),
    "Framework": ("sagemaker.estimator",),
    "IPInsights": ("sagemaker", "sagemaker.amazon.ipinsights"),
    "KMeans": ("sagemaker", "sagemaker.amazon.kmeans"),
    "KNN": ("sagemaker", "sagemaker.amazon.knn"),
    "LDA": ("sagemaker", "sagemaker.amazon.lda"),
    "LinearLearner": ("sagemaker", "sagemaker.amazon.linear_learner"),
    "MXNet": ("sagemaker.mxnet", "sagemaker.mxnet.estimator"),
    "NTM": ("sagemaker", "sagemaker.amazon.ntm"),
    "Object2Vec": ("sagemaker", "sagemaker.amazon.object2vec"),
    "PCA": ("sagemaker", "sagemaker.amazon.pca"),
    "PyTorch": ("sagemaker.pytorch", "sagemaker.pytorch.estimator"),
    "RandomCutForest": ("sagemaker", "sagemaker.amazon.randomcutforest"),
    "RLEstimator": ("sagemaker.rl", "sagemaker.rl.estimator"),
    "SKLearn": ("sagemaker.sklearn", "sagemaker.sklearn.estimator"),
    "TensorFlow": ("sagemaker.tensorflow", "sagemaker.tensorflow.estimator"),
    "XGBoost": ("sagemaker.xgboost", "sagemaker.xgboost.estimator"),
}
43+
44+
# Estimator keyword arguments whose "train_" text is removed by the modifier.
PARAMS = (
    "train_instance_count",
    "train_instance_type",
    "train_max_run",
    "train_max_run_wait",
    "train_use_spot_instances",
    "train_volume_size",
    "train_volume_kms_key",
)
53+
54+
55+
class TrainPrefixRemover(Modifier):
    """Modifier that strips the redundant 'train' prefix from estimator parameters."""

    def node_should_be_modified(self, node):
        """Decides whether ``node`` is an estimator constructor that needs renaming.

        The call must match one of the classes in ``ESTIMATORS`` and pass at
        least one of the following keyword arguments (the entries of ``PARAMS``):

        - ``train_instance_count``
        - ``train_instance_type``
        - ``train_max_run``
        - ``train_max_run_wait``
        - ``train_use_spot_instances``
        - ``train_volume_size``
        - ``train_volume_kms_key``

        Args:
            node (ast.Call): a node that represents a function call. For more,
                see https://docs.python.org/3/library/ast.html#abstract-grammar.

        Returns:
            bool: If the ``ast.Call`` matches the relevant function calls and
                contains a parameter to be renamed.
        """
        is_estimator_call = matching.matches_any(node, ESTIMATORS)
        return is_estimator_call and self._has_train_parameter(node)

    def _has_train_parameter(self, node):
        """Checks if at least one of the node's keywords is listed in ``PARAMS``."""
        return any(keyword.arg in PARAMS for keyword in node.keywords)

    def modify_node(self, node):
        """Renames, in place, every keyword of ``node`` that is listed in ``PARAMS``.

        Keywords not listed in ``PARAMS`` are left untouched.

        Args:
            node (ast.Call): a node that represents an estimator constructor.
        """
        to_rename = [keyword for keyword in node.keywords if keyword.arg in PARAMS]
        for keyword in to_rename:
            keyword.arg = keyword.arg.replace("train_", "")
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,102 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import itertools
16+
17+
import pasta
18+
19+
from sagemaker.cli.compatibility.v2.modifiers import training_params
20+
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call
21+
22+
# Test fixture: estimator class name -> namespaces through which the tests
# reference that class when building call expressions.
ESTIMATORS_TO_NAMESPACES = {
    "AlgorithmEstimator": ("sagemaker", "sagemaker.algorithm"),
    "AmazonAlgorithmEstimatorBase": ("sagemaker.amazon.amazon_estimator",),
    "Chainer": ("sagemaker.chainer", "sagemaker.chainer.estimator"),
    "Estimator": ("sagemaker.estimator",),
    "EstimatorBase": ("sagemaker.estimator",),
    "FactorizationMachines": ("sagemaker", "sagemaker.amazon.factorization_machines"),
    "Framework": ("sagemaker.estimator",),
    "IPInsights": ("sagemaker", "sagemaker.amazon.ipinsights"),
    "KMeans": ("sagemaker", "sagemaker.amazon.kmeans"),
    "KNN": ("sagemaker", "sagemaker.amazon.knn"),
    "LDA": ("sagemaker", "sagemaker.amazon.lda"),
    "LinearLearner": ("sagemaker", "sagemaker.amazon.linear_learner"),
    "MXNet": ("sagemaker.mxnet", "sagemaker.mxnet.estimator"),
    "NTM": ("sagemaker", "sagemaker.amazon.ntm"),
    "Object2Vec": ("sagemaker", "sagemaker.amazon.object2vec"),
    "PCA": ("sagemaker", "sagemaker.amazon.pca"),
    "PyTorch": ("sagemaker.pytorch", "sagemaker.pytorch.estimator"),
    "RandomCutForest": ("sagemaker", "sagemaker.amazon.randomcutforest"),
    "RLEstimator": ("sagemaker.rl", "sagemaker.rl.estimator"),
    "SKLearn": ("sagemaker.sklearn", "sagemaker.sklearn.estimator"),
    "TensorFlow": ("sagemaker.tensorflow", "sagemaker.tensorflow.estimator"),
    "XGBoost": ("sagemaker.xgboost", "sagemaker.xgboost.estimator"),
}
46+
47+
# Test fixture: one ``name=value`` source snippet per "train_"-prefixed keyword.
PARAMS_WITH_VALUES = (
    "train_instance_count=1",
    "train_instance_type='ml.c4.xlarge'",
    "train_max_run=8 * 60 * 60",
    "train_max_run_wait=1 * 60 * 60",
    "train_use_spot_instances=True",
    "train_volume_size=30",
    "train_volume_kms_key='key'",
)
56+
57+
58+
def _estimators():
    """Yield each estimator's bare class name, then its namespace-qualified names."""
    for class_name, modules in ESTIMATORS_TO_NAMESPACES.items():
        yield class_name

        for module in modules:
            yield "{}.{}".format(module, class_name)
64+
65+
66+
def test_node_should_be_modified():
67+
modifier = training_params.TrainPrefixRemover()
68+
69+
for estimator in _estimators():
70+
for param in PARAMS_WITH_VALUES:
71+
call = ast_call("{}({})".format(estimator, param))
72+
assert modifier.node_should_be_modified(call)
73+
74+
75+
def test_node_should_be_modified_no_params():
    """An estimator call without any keyword arguments is not flagged."""
    remover = training_params.TrainPrefixRemover()

    for estimator in _estimators():
        node = ast_call("{}()".format(estimator))
        assert not remover.node_should_be_modified(node)
81+
82+
83+
def test_node_should_be_modified_random_function_call():
    """A call to something that is not an estimator is not flagged."""
    remover = training_params.TrainPrefixRemover()
    node = ast_call("Session()")
    assert not remover.node_should_be_modified(node)
86+
87+
88+
def test_modify_node():
    """Each non-empty combination of train_* keywords is rewritten without the prefix."""
    remover = training_params.TrainPrefixRemover()

    for params in _parameter_combinations():
        source = "Estimator({})".format(params)
        node = ast_call(source)
        remover.modify_node(node)

        # The expected output is the input source with every "train_" removed.
        assert source.replace("train_", "") == pasta.dump(node)
97+
98+
99+
def _parameter_combinations():
    """Yield every non-empty combination of PARAMS_WITH_VALUES as a ", "-joined string.

    Combinations come in increasing size, each size in ``itertools.combinations`` order.
    """
    sizes = range(1, len(PARAMS_WITH_VALUES) + 1)
    subsets = itertools.chain.from_iterable(
        itertools.combinations(PARAMS_WITH_VALUES, size) for size in sizes
    )
    for subset in subsets:
        yield ", ".join(subset)

0 commit comments

Comments
 (0)