Skip to content

Commit 2d05070

Browse files
author
Chuyang Deng
committed
update docstring and remove unused algorithm name
1 parent 3f08e2e commit 2d05070

File tree

2 files changed

+15
-45
lines changed

2 files changed

+15
-45
lines changed

src/sagemaker/cli/compatibility/v2/modifiers/image_uris.py

+1-23
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Classes to modify image_uris.retrieve() code to be compatible
14-
with version 2.0 and later of the SageMaker Python SDK.
15-
"""
13+
"""Classes to modify image uri retrieve methods for Python SDK v2.0 and later."""
1614
from __future__ import absolute_import
1715

1816
import ast
@@ -33,8 +31,6 @@
3331

3432
ALGORITHM_NAME_FROM_REPO = {
3533
"blazingtext": "blazingtext",
36-
"sagemaker-rl-mxnet": "coach-mxnet",
37-
"sagemaker-rl-tensorflow": ["coach-tensorflow", "ray-tensorflow"],
3834
"factorization-machine": "factorization-machines",
3935
"forecasting-deepar": "forecasting-deepar",
4036
"image-classification": "image-classification",
@@ -49,12 +45,8 @@
4945
"object-detection": "object-detection",
5046
"pca": "pca",
5147
"randomcutforest": "randomcutforest",
52-
"sagemaker-rl-ray-container": "ray-pytorch",
5348
"semantic-segmentation": "semantic-segmentation",
5449
"seq2seq": "seq2seq",
55-
"sagemaker-scikit-learn": "sklearn",
56-
"sagemaker-sparkml-serving": "sparkml-serving",
57-
"sagemaker-rl-vw-container": "vw",
5850
"sagemaker-xgboost": "xgboost",
5951
"xgboost-neo": "xgboost-neo",
6052
}
@@ -97,13 +89,6 @@ def modify_node(self, node):
9789
if kw.arg == "repo_name":
9890
arg = kw.value.s
9991
modified_arg = ALGORITHM_NAME_FROM_REPO[arg]
100-
if isinstance(modified_arg, list):
101-
logger.warning(
102-
"There are more than one value mapping to {}, {} will be used".format(
103-
arg, modified_arg[0]
104-
)
105-
)
106-
modified_arg = modified_arg[0]
10792
original_args[0] = ast.Str(modified_arg)
10893
elif kw.arg == "repo_region":
10994
original_args[1] = ast.Str(kw.value.s)
@@ -115,13 +100,6 @@ def modify_node(self, node):
115100
if len(node.args) > 1:
116101
arg = node.args[1].s
117102
modified_arg = ALGORITHM_NAME_FROM_REPO[arg]
118-
if isinstance(modified_arg, list):
119-
logger.warning(
120-
"There are more than one value mapping to {}, {} will be used".format(
121-
arg, modified_arg[0]
122-
)
123-
)
124-
modified_arg = modified_arg[0]
125103
original_args[0] = ast.Str(modified_arg)
126104
if len(node.args) > 2:
127105
original_args[2] = ast.Str(node.args[2].s)

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_algorithm_image_uris.py

+14-22
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@
2222
@pytest.fixture
2323
def methods():
2424
return (
25-
"get_image_uri('us-west-2', 'kmeans')",
26-
"sagemaker.get_image_uri(repo_region='us-west-2', repo_name='sagemaker-scikil-learn')",
27-
"sagemaker.amazon_estimator.get_image_uri('us-west-2', repo_name='sagemaker-scikil-learn')",
28-
"sagemaker.amazon.amazon_estimator.get_image_uri('us-west-2', 'sagemaker-scikil-learn', repo_version='1')",
25+
"get_image_uri('us-west-2', 'sagemaker-xgboost')",
26+
"sagemaker.get_image_uri(repo_region='us-west-2', repo_name='sagemaker-xgboost')",
27+
"sagemaker.amazon_estimator.get_image_uri('us-west-2', repo_name='sagemaker-xgboost')",
28+
"sagemaker.amazon.amazon_estimator.get_image_uri('us-west-2', 'sagemaker-xgboost', repo_version='1')",
2929
)
3030

3131

@@ -54,38 +54,30 @@ def test_methodnode_should_be_modified_random_call():
5454
def test_method_modify_node(methods, caplog):
5555
modifier = image_uris.ImageURIRetrieveRefactor()
5656

57-
method = "get_image_uri('us-west-2', 'sagemaker-scikit-learn')"
57+
method = "get_image_uri('us-west-2', 'sagemaker-xgboost')"
5858
node = ast_call(method)
5959
modifier.modify_node(node)
60-
assert "image_uris.retrieve('sklearn', 'us-west-2')" == pasta.dump(node)
60+
assert "image_uris.retrieve('xgboost', 'us-west-2')" == pasta.dump(node)
6161

62-
method = "amazon_estimator.get_image_uri('us-west-2', 'sagemaker-scikit-learn')"
62+
method = "amazon_estimator.get_image_uri('us-west-2', 'sagemaker-xgboost')"
6363
node = ast_call(method)
6464
modifier.modify_node(node)
65-
assert "image_uris.retrieve('sklearn', 'us-west-2')" == pasta.dump(node)
65+
assert "image_uris.retrieve('xgboost', 'us-west-2')" == pasta.dump(node)
6666

67-
method = "sagemaker.get_image_uri(repo_region='us-west-2', repo_name='sagemaker-scikit-learn')"
67+
method = "sagemaker.get_image_uri(repo_region='us-west-2', repo_name='sagemaker-xgboost')"
6868
node = ast_call(method)
6969
modifier.modify_node(node)
70-
assert "sagemaker.image_uris.retrieve('sklearn', 'us-west-2')" == pasta.dump(node)
70+
assert "sagemaker.image_uris.retrieve('xgboost', 'us-west-2')" == pasta.dump(node)
7171

72-
method = (
73-
"sagemaker.amazon_estimator.get_image_uri('us-west-2', repo_name='sagemaker-scikit-learn')"
74-
)
75-
node = ast_call(method)
76-
modifier.modify_node(node)
77-
assert "sagemaker.image_uris.retrieve('sklearn', 'us-west-2')" == pasta.dump(node)
78-
79-
method = "sagemaker.amazon.amazon_estimator.get_image_uri('us-west-2', 'sagemaker-scikit-learn', repo_version='1')"
72+
method = "sagemaker.amazon_estimator.get_image_uri('us-west-2', repo_name='sagemaker-xgboost')"
8073
node = ast_call(method)
8174
modifier.modify_node(node)
82-
assert "sagemaker.image_uris.retrieve('sklearn', 'us-west-2', '1')" == pasta.dump(node)
75+
assert "sagemaker.image_uris.retrieve('xgboost', 'us-west-2')" == pasta.dump(node)
8376

84-
method = "get_image_uri('us-west-2', 'sagemaker-rl-tensorflow')"
77+
method = "sagemaker.amazon.amazon_estimator.get_image_uri('us-west-2', 'sagemaker-xgboost', repo_version='1')"
8578
node = ast_call(method)
8679
modifier.modify_node(node)
87-
assert "image_uris.retrieve('coach-tensorflow', 'us-west-2')"
88-
assert "There are more than one value mapping to" in caplog.text
80+
assert "sagemaker.image_uris.retrieve('xgboost', 'us-west-2', '1')" == pasta.dump(node)
8981

9082

9183
def test_import_from_node_should_be_modified_image_uris_input(import_statements):

0 commit comments

Comments
 (0)