Skip to content

Commit c50295e

Browse files
authored
Merge branch 'zwei' into add-string-deserializer
2 parents 459995b + 9fd784e commit c50295e

File tree

3 files changed

+275
-10
lines changed

3 files changed

+275
-10
lines changed

src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py

+119-10
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,34 @@ def modify_node(self, node):
6868
keyword.arg = self.new_param_name
6969

7070

71+
class MethodParamRenamer(ParamRenamer):
72+
"""Abstract class to handle parameter renames for methods that belong to objects.
73+
74+
This differs from ``ParamRenamer`` in that a node for a standalone function call
75+
(i.e. where ``node.func`` is an ``ast.Name`` rather than an ``ast.Attribute``) is not modified.
76+
"""
77+
78+
def node_should_be_modified(self, node):
79+
"""Checks if the node matches any of the relevant functions and
80+
contains the parameter to be renamed.
81+
82+
This looks for a call of the form ``<object>.<method>``, and
83+
assumes the method cannot be called on its own.
84+
85+
Args:
86+
node (ast.Call): a node that represents a function call. For more,
87+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
88+
89+
Returns:
90+
bool: If the ``ast.Call`` matches the relevant function calls and
91+
contains the parameter to be renamed.
92+
"""
93+
if isinstance(node.func, ast.Name):
94+
return False
95+
96+
return super(MethodParamRenamer, self).node_should_be_modified(node)
97+
98+
7199
class DistributionParameterRenamer(ParamRenamer):
72100
"""A class to rename the ``distributions`` attribute to ``distrbution`` in
73101
MXNet and TensorFlow estimators.
@@ -100,7 +128,7 @@ def new_param_name(self):
100128
return "distribution"
101129

102130

103-
class S3SessionRenamer(ParamRenamer):
131+
class S3SessionRenamer(MethodParamRenamer):
104132
"""A class to rename the ``session`` attribute to ``sagemaker_session`` in
105133
``S3Uploader`` and ``S3Downloader``.
106134
@@ -139,15 +167,6 @@ def new_param_name(self):
139167
"""The new name for the SageMaker session argument."""
140168
return "sagemaker_session"
141169

142-
def node_should_be_modified(self, node):
143-
"""Checks if the node is one of the S3 utility functions and
144-
contains the ``session`` parameter.
145-
"""
146-
if isinstance(node.func, ast.Name):
147-
return False
148-
149-
return super(S3SessionRenamer, self).node_should_be_modified(node)
150-
151170

152171
class EstimatorImageURIRenamer(ParamRenamer):
153172
"""A class to rename the ``image_name`` attribute to ``image_uri`` in estimators."""
@@ -209,3 +228,93 @@ def old_param_name(self):
209228
def new_param_name(self):
210229
"""The new name for the image URI argument."""
211230
return "image_uri"
231+
232+
233+
class EstimatorCreateModelImageURIRenamer(MethodParamRenamer):
234+
"""A class to rename ``image`` to ``image_uri`` in estimator ``create_model()`` methods."""
235+
236+
@property
237+
def calls_to_modify(self):
238+
"""A mapping of ``create_model`` to common variable names for estimators."""
239+
return {
240+
"create_model": (
241+
"estimator",
242+
"chainer",
243+
"mxnet",
244+
"mx",
245+
"pytorch",
246+
"rl",
247+
"sklearn",
248+
"tensorflow",
249+
"tf",
250+
"xgboost",
251+
"xgb",
252+
)
253+
}
254+
255+
@property
256+
def old_param_name(self):
257+
"""The previous name for the image URI argument."""
258+
return "image"
259+
260+
@property
261+
def new_param_name(self):
262+
"""The new name for the the image URI argument."""
263+
return "image_uri"
264+
265+
266+
class SessionCreateModelImageURIRenamer(MethodParamRenamer):
267+
"""A class to rename ``primary_container_image`` to ``image_uri``.
268+
269+
This looks for the following calls:
270+
271+
- ``sagemaker_session.create_model_from_job()``
272+
- ``sess.create_model_from_job()``
273+
"""
274+
275+
@property
276+
def calls_to_modify(self):
277+
"""A mapping of ``create_model_from_job`` to common variable names for Session."""
278+
return {
279+
"create_model_from_job": ("sagemaker_session", "sess"),
280+
}
281+
282+
@property
283+
def old_param_name(self):
284+
"""The previous name for the image URI argument."""
285+
return "primary_container_image"
286+
287+
@property
288+
def new_param_name(self):
289+
"""The new name for the the image URI argument."""
290+
return "image_uri"
291+
292+
293+
class SessionCreateEndpointImageURIRenamer(MethodParamRenamer):
294+
"""A class to rename ``deployment_image`` to ``image_uri``.
295+
296+
This looks for the following calls:
297+
298+
- ``sagemaker_session.endpoint_from_job()``
299+
- ``sess.endpoint_from_job()``
300+
- ``sagemaker_session.endpoint_from_model_data()``
301+
- ``sess.endpoint_from_model_data()``
302+
"""
303+
304+
@property
305+
def calls_to_modify(self):
306+
"""A mapping of the ``endpoint_from_*`` functions to common variable names for Session."""
307+
return {
308+
"endpoint_from_job": ("sagemaker_session", "sess"),
309+
"endpoint_from_model_data": ("sagemaker_session", "sess"),
310+
}
311+
312+
@property
313+
def old_param_name(self):
314+
"""The previous name for the image URI argument."""
315+
return "deployment_image"
316+
317+
@property
318+
def new_param_name(self):
319+
"""The new name for the the image URI argument."""
320+
return "image_uri"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pasta
16+
17+
from sagemaker.cli.compatibility.v2.modifiers import renamed_params
18+
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call
19+
20+
ESTIMATORS = (
21+
"estimator",
22+
"chainer",
23+
"mxnet",
24+
"mx",
25+
"pytorch",
26+
"rl",
27+
"sklearn",
28+
"tensorflow",
29+
"tf",
30+
"xgboost",
31+
"xgb",
32+
)
33+
34+
35+
def test_node_should_be_modified():
36+
modifier = renamed_params.EstimatorCreateModelImageURIRenamer()
37+
38+
for estimator in ESTIMATORS:
39+
call = "{}.create_model(image='my-image:latest')".format(estimator)
40+
assert modifier.node_should_be_modified(ast_call(call))
41+
42+
43+
def test_node_should_be_modified_no_distribution():
44+
modifier = renamed_params.EstimatorCreateModelImageURIRenamer()
45+
46+
for estimator in ESTIMATORS:
47+
call = "{}.create_model()".format(estimator)
48+
assert not modifier.node_should_be_modified(ast_call(call))
49+
50+
51+
def test_node_should_be_modified_random_function_call():
52+
modifier = renamed_params.EstimatorCreateModelImageURIRenamer()
53+
assert not modifier.node_should_be_modified(ast_call("create_model()"))
54+
55+
56+
def test_modify_node():
57+
node = ast_call("estimator.create_model(image=my_image)")
58+
modifier = renamed_params.EstimatorCreateModelImageURIRenamer()
59+
modifier.modify_node(node)
60+
61+
expected = "estimator.create_model(image_uri=my_image)"
62+
assert expected == pasta.dump(node)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pasta
16+
17+
from sagemaker.cli.compatibility.v2.modifiers import renamed_params
18+
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call
19+
20+
CREATE_MODEL_TEMPLATES = (
21+
"sagemaker_session.create_model_from_job({})",
22+
"sess.create_model_from_job({})",
23+
)
24+
25+
CREATE_ENDPOINT_TEMPLATES = (
26+
"sagemaker_session.endpoint_from_job({})",
27+
"sagemaker_session.endpoint_from_model_data({})",
28+
"sess.endpoint_from_job({})",
29+
"sess.endpoint_from_model_data({})",
30+
)
31+
32+
33+
def test_create_model_node_should_be_modified():
34+
modifier = renamed_params.SessionCreateModelImageURIRenamer()
35+
36+
for template in CREATE_MODEL_TEMPLATES:
37+
call = ast_call(template.format("primary_container_image=my_image"))
38+
assert modifier.node_should_be_modified(call)
39+
40+
41+
def test_create_model_node_should_be_modified_no_image():
42+
modifier = renamed_params.SessionCreateModelImageURIRenamer()
43+
44+
for template in CREATE_MODEL_TEMPLATES:
45+
call = ast_call(template.format(""))
46+
assert not modifier.node_should_be_modified(call)
47+
48+
49+
def test_create_model_node_should_be_modified_random_function_call():
50+
modifier = renamed_params.SessionCreateModelImageURIRenamer()
51+
assert not modifier.node_should_be_modified(ast_call("create_model()"))
52+
53+
54+
def test_create_model_modify_node():
55+
modifier = renamed_params.SessionCreateModelImageURIRenamer()
56+
57+
for template in CREATE_MODEL_TEMPLATES:
58+
call = ast_call(template.format("primary_container_image=my_image"))
59+
modifier.modify_node(call)
60+
61+
expected = template.format("image_uri=my_image")
62+
assert expected == pasta.dump(call)
63+
64+
65+
def test_create_endpoint_node_should_be_modified():
66+
modifier = renamed_params.SessionCreateEndpointImageURIRenamer()
67+
68+
for template in CREATE_ENDPOINT_TEMPLATES:
69+
call = ast_call(template.format("deployment_image=my_image"))
70+
assert modifier.node_should_be_modified(call)
71+
72+
73+
def test_create_endpoint_node_should_be_modified_no_image():
74+
modifier = renamed_params.SessionCreateEndpointImageURIRenamer()
75+
76+
for template in CREATE_ENDPOINT_TEMPLATES:
77+
call = ast_call(template.format(""))
78+
assert not modifier.node_should_be_modified(call)
79+
80+
81+
def test_create_endpoint_node_should_be_modified_random_function_call():
82+
modifier = renamed_params.SessionCreateEndpointImageURIRenamer()
83+
assert not modifier.node_should_be_modified(ast_call("create_endpoint()"))
84+
85+
86+
def test_create_endpoint_modify_node():
87+
modifier = renamed_params.SessionCreateEndpointImageURIRenamer()
88+
89+
for template in CREATE_ENDPOINT_TEMPLATES:
90+
call = ast_call(template.format("deployment_image=my_image"))
91+
modifier.modify_node(call)
92+
93+
expected = template.format("image_uri=my_image")
94+
assert expected == pasta.dump(call)

0 commit comments

Comments
 (0)