Skip to content

Commit 15a295f

Browse files
committed
breaking: move ShuffleConfig from sagemaker.session to sagemaker.inputs
1 parent 8b7be01 commit 15a295f

File tree

6 files changed

+165
-18
lines changed

6 files changed

+165
-18
lines changed

src/sagemaker/cli/compatibility/v2/ast_transformer.py

+2
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
modifiers.renamed_params.SessionCreateEndpointImageURIRenamer(),
3636
modifiers.training_params.TrainPrefixRemover(),
3737
modifiers.training_input.TrainingInputConstructorRefactor(),
38+
modifiers.training_input.ShuffleConfigModuleRenamer(),
3839
modifiers.serde.SerdeConstructorRenamer(),
3940
]
4041

@@ -51,6 +52,7 @@
5152
modifiers.predictors.PredictorImportFromRenamer(),
5253
modifiers.tfs.TensorFlowServingImportFromRenamer(),
5354
modifiers.training_input.TrainingInputImportFromRenamer(),
55+
modifiers.training_input.ShuffleConfigImportFromRenamer(),
5456
modifiers.serde.SerdeImportFromAmazonCommonRenamer(),
5557
modifiers.serde.SerdeImportFromPredictorRenamer(),
5658
]

src/sagemaker/cli/compatibility/v2/modifiers/training_input.py

+70
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,73 @@ def modify_node(self, node):
100100
if node.module == "sagemaker.session":
101101
node.module = "sagemaker.inputs"
102102
return node
103+
104+
105+
class ShuffleConfigModuleRenamer(Modifier):
106+
"""A class to change ``ShuffleConfig`` usage to use ``sagemaker.inputs.ShuffleConfig``."""
107+
108+
def node_should_be_modified(self, node):
109+
"""Checks if the ``ast.Call`` node instantiates a class of interest.
110+
111+
This looks for the following calls:
112+
113+
- ``sagemaker.session.ShuffleConfig``
114+
- ``session.ShuffleConfig``
115+
116+
Args:
117+
node (ast.Call): a node that represents a function call. For more,
118+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
119+
120+
Returns:
121+
bool: If the ``ast.Call`` instantiates a class of interest.
122+
"""
123+
if isinstance(node.func, ast.Name):
124+
return False
125+
126+
return matching.matches_name_or_namespaces(
127+
node, "ShuffleConfig", ("sagemaker.session", "session")
128+
)
129+
130+
def modify_node(self, node):
131+
"""Modifies the ``ast.Call`` node to call ``sagemaker.inputs.ShuffleConfig``.
132+
133+
Args:
134+
node (ast.Call): a node that represents a ``sagemaker.session.ShuffleConfig``
135+
constructor.
136+
137+
Returns:
138+
ast.Call: the original node, with its namespace changed to use the ``inputs`` module.
139+
"""
140+
_rename_namespace(node, "session")
141+
return node
142+
143+
144+
class ShuffleConfigImportFromRenamer(Modifier):
145+
"""A class to update import statements of ``ShuffleConfig``."""
146+
147+
def node_should_be_modified(self, node):
148+
"""Checks if the import statement imports ``sagemaker.session.ShuffleConfig``.
149+
150+
Args:
151+
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
152+
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
153+
154+
Returns:
155+
bool: If the import statement imports ``sagemaker.session.ShuffleConfig``.
156+
"""
157+
return node.module == "sagemaker.session" and any(
158+
name.name == "ShuffleConfig" for name in node.names
159+
)
160+
161+
def modify_node(self, node):
162+
"""Changes the ``ast.ImportFrom`` node's namespace to ``sagemaker.inputs``.
163+
164+
Args:
165+
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
166+
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
167+
168+
Returns:
169+
ast.ImportFrom: the original node, with its module modified to ``"sagemaker.inputs"``.
170+
"""
171+
node.module = "sagemaker.inputs"
172+
return node

src/sagemaker/inputs.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ def __init__(
7272
found in a specified AugmentedManifestFile.
7373
target_attribute_name (str): The name of the attribute will be predicted (classified)
7474
in a SageMaker AutoML job. It is required if the input is for SageMaker AutoML job.
75-
shuffle_config (ShuffleConfig): If specified this configuration enables shuffling on
76-
this channel. See the SageMaker API documentation for more info:
75+
shuffle_config (sagemaker.inputs.ShuffleConfig): If specified this configuration enables
76+
shuffling on this channel. See the SageMaker API documentation for more info:
7777
https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html
7878
"""
7979
self.config = {
@@ -102,6 +102,22 @@ def __init__(
102102
self.config["ShuffleConfig"] = {"Seed": shuffle_config.seed}
103103

104104

105+
class ShuffleConfig(object):
106+
"""For configuring channel shuffling using a seed.
107+
108+
For more detail, see the AWS documentation:
109+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html
110+
"""
111+
112+
def __init__(self, seed):
113+
"""Create a ShuffleConfig.
114+
115+
Args:
116+
seed (long): the long value used to seed the shuffled sequence.
117+
"""
118+
self.seed = seed
119+
120+
105121
class FileSystemInput(object):
106122
"""Amazon SageMaker channel configurations for file system data sources.
107123

src/sagemaker/session.py

-15
Original file line numberDiff line numberDiff line change
@@ -3362,21 +3362,6 @@ def get_execution_role(sagemaker_session=None):
33623362
raise ValueError(message.format(arn))
33633363

33643364

3365-
class ShuffleConfig(object):
3366-
"""
3367-
Used to configure channel shuffling using a seed. See SageMaker documentation for
3368-
more detail: https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html
3369-
"""
3370-
3371-
def __init__(self, seed):
3372-
"""
3373-
Create a ShuffleConfig.
3374-
Args:
3375-
seed (long): the long value used to seed the shuffled sequence.
3376-
"""
3377-
self.seed = seed
3378-
3379-
33803365
def _create_model_request(
33813366
name, role, container_def=None, tags=None
33823367
): # pylint: disable=redefined-outer-name
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pasta
16+
import pytest
17+
18+
from sagemaker.cli.compatibility.v2.modifiers import training_input
19+
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call, ast_import
20+
21+
22+
@pytest.fixture
23+
def constructors():
24+
return (
25+
"sagemaker.session.ShuffleConfig(seed)",
26+
"session.ShuffleConfig(seed)",
27+
)
28+
29+
30+
@pytest.fixture
31+
def modified_constructors(constructors):
32+
return [c.replace("session", "inputs") for c in constructors]
33+
34+
35+
def test_constructor_node_should_be_modified(constructors):
36+
modifier = training_input.ShuffleConfigModuleRenamer()
37+
for constructor in constructors:
38+
node = ast_call(constructor)
39+
assert modifier.node_should_be_modified(node)
40+
41+
42+
def test_constructor_node_should_be_modified_random_call():
43+
modifier = training_input.ShuffleConfigModuleRenamer()
44+
node = ast_call("FileSystemInput()")
45+
assert not modifier.node_should_be_modified(node)
46+
47+
48+
def test_constructor_modify_node(constructors, modified_constructors):
49+
modifier = training_input.ShuffleConfigModuleRenamer()
50+
51+
for before, expected in zip(constructors, modified_constructors):
52+
node = ast_call(before)
53+
modifier.modify_node(node)
54+
assert expected == pasta.dump(node)
55+
56+
57+
def test_import_from_node_should_be_modified_training_input():
58+
modifier = training_input.ShuffleConfigImportFromRenamer()
59+
node = ast_import("from sagemaker.session import ShuffleConfig")
60+
assert modifier.node_should_be_modified(node)
61+
62+
63+
def test_import_from_node_should_be_modified_random_import():
64+
modifier = training_input.ShuffleConfigImportFromRenamer()
65+
node = ast_import("from sagemaker.session import Session")
66+
assert not modifier.node_should_be_modified(node)
67+
68+
69+
def test_import_from_modify_node():
70+
modifier = training_input.ShuffleConfigImportFromRenamer()
71+
node = ast_import("from sagemaker.session import ShuffleConfig")
72+
73+
modifier.modify_node(node)
74+
assert "from sagemaker.inputs import ShuffleConfig" == pasta.dump(node)

tests/unit/test_estimator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@
2626
from sagemaker import TrainingInput, utils, vpc_utils
2727
from sagemaker.algorithm import AlgorithmEstimator
2828
from sagemaker.estimator import Estimator, EstimatorBase, Framework, _TrainingJob
29+
from sagemaker.inputs import ShuffleConfig
2930
from sagemaker.model import FrameworkModel
3031
from sagemaker.predictor import Predictor
31-
from sagemaker.session import ShuffleConfig
3232
from sagemaker.transformer import Transformer
3333

3434
MODEL_DATA = "s3://bucket/model.tar.gz"

0 commit comments

Comments
 (0)