Skip to content

breaking: move ShuffleConfig from sagemaker.session to sagemaker.inputs #1786

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/sagemaker/cli/compatibility/v2/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
modifiers.renamed_params.SessionCreateEndpointImageURIRenamer(),
modifiers.training_params.TrainPrefixRemover(),
modifiers.training_input.TrainingInputConstructorRefactor(),
modifiers.training_input.ShuffleConfigModuleRenamer(),
modifiers.serde.SerdeConstructorRenamer(),
]

Expand All @@ -51,6 +52,7 @@
modifiers.predictors.PredictorImportFromRenamer(),
modifiers.tfs.TensorFlowServingImportFromRenamer(),
modifiers.training_input.TrainingInputImportFromRenamer(),
modifiers.training_input.ShuffleConfigImportFromRenamer(),
modifiers.serde.SerdeImportFromAmazonCommonRenamer(),
modifiers.serde.SerdeImportFromPredictorRenamer(),
]
Expand Down
70 changes: 70 additions & 0 deletions src/sagemaker/cli/compatibility/v2/modifiers/training_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,73 @@ def modify_node(self, node):
if node.module == "sagemaker.session":
node.module = "sagemaker.inputs"
return node


class ShuffleConfigModuleRenamer(Modifier):
    """Rewrites namespaced ``ShuffleConfig`` calls to go through the ``inputs`` module."""

    def node_should_be_modified(self, node):
        """Decides whether the ``ast.Call`` node is a namespaced ``ShuffleConfig`` call.

        The calls of interest are:

        - ``sagemaker.session.ShuffleConfig``
        - ``session.ShuffleConfig``

        A call whose function is a bare name (an ``ast.Name``) is never selected here.

        Args:
            node (ast.Call): a node that represents a function call. For more,
                see https://docs.python.org/3/library/ast.html#abstract-grammar.

        Returns:
            bool: Whether the ``ast.Call`` is one of the namespaced calls listed above.
        """
        namespaces = ("sagemaker.session", "session")
        is_bare_name = isinstance(node.func, ast.Name)
        return not is_bare_name and matching.matches_name_or_namespaces(
            node, "ShuffleConfig", namespaces
        )

    def modify_node(self, node):
        """Points the ``ast.Call`` node at ``sagemaker.inputs.ShuffleConfig``.

        Args:
            node (ast.Call): a node that represents a ``sagemaker.session.ShuffleConfig``
                constructor.

        Returns:
            ast.Call: the node that was passed in, after ``_rename_namespace`` has been
                applied to its ``session`` namespace.
        """
        # The rename helper is called for its effect on ``node``; the same object is returned.
        _rename_namespace(node, "session")
        return node


class ShuffleConfigImportFromRenamer(Modifier):
    """Moves ``from sagemaker.session import ShuffleConfig`` imports to ``sagemaker.inputs``."""

    def node_should_be_modified(self, node):
        """Decides whether the statement imports ``ShuffleConfig`` from ``sagemaker.session``.

        Args:
            node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
                For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.

        Returns:
            bool: ``True`` only when the module is ``sagemaker.session`` and one of the
                imported names is ``ShuffleConfig``.
        """
        if node.module != "sagemaker.session":
            return False
        return any(alias.name == "ShuffleConfig" for alias in node.names)

    def modify_node(self, node):
        """Re-points the ``ast.ImportFrom`` node at the ``sagemaker.inputs`` module.

        Args:
            node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
                For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.

        Returns:
            ast.ImportFrom: the node that was passed in, with its module set to
                ``"sagemaker.inputs"``.
        """
        # NOTE(review): the module of the whole statement is replaced, so any other names
        # imported alongside ``ShuffleConfig`` are re-pointed as well -- confirm this is intended.
        node.module = "sagemaker.inputs"
        return node
20 changes: 18 additions & 2 deletions src/sagemaker/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def __init__(
found in a specified AugmentedManifestFile.
target_attribute_name (str): The name of the attribute will be predicted (classified)
in a SageMaker AutoML job. It is required if the input is for SageMaker AutoML job.
shuffle_config (ShuffleConfig): If specified this configuration enables shuffling on
this channel. See the SageMaker API documentation for more info:
shuffle_config (sagemaker.inputs.ShuffleConfig): If specified this configuration enables
shuffling on this channel. See the SageMaker API documentation for more info:
https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html
"""
self.config = {
Expand Down Expand Up @@ -102,6 +102,22 @@ def __init__(
self.config["ShuffleConfig"] = {"Seed": shuffle_config.seed}


class ShuffleConfig(object):
    """For configuring channel shuffling using a seed.

    For more detail, see the AWS documentation:
    https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html
    """

    def __init__(self, seed):
        """Create a ShuffleConfig.

        Args:
            seed (long): the long value used to seed the shuffled sequence.
        """
        self.seed = seed

    def __repr__(self):
        """Return a constructor-style representation, e.g. ``ShuffleConfig(seed=42)``."""
        # ``type(self).__name__`` keeps the output accurate for subclasses.
        return "{}(seed={!r})".format(type(self).__name__, self.seed)


class FileSystemInput(object):
"""Amazon SageMaker channel configurations for file system data sources.

Expand Down
15 changes: 0 additions & 15 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3362,21 +3362,6 @@ def get_execution_role(sagemaker_session=None):
raise ValueError(message.format(arn))


class ShuffleConfig(object):
    """Used to configure channel shuffling using a seed.

    See SageMaker documentation for more detail:
    https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html
    """

    def __init__(self, seed):
        """Create a ShuffleConfig.

        Args:
            seed (long): the long value used to seed the shuffled sequence.
        """
        self.seed = seed

    def __repr__(self):
        """Return a constructor-style representation, e.g. ``ShuffleConfig(seed=42)``."""
        # ``type(self).__name__`` keeps the output accurate for subclasses.
        return "{}(seed={!r})".format(type(self).__name__, self.seed)


def _create_model_request(
name, role, container_def=None, tags=None
): # pylint: disable=redefined-outer-name
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import pasta
import pytest

from sagemaker.cli.compatibility.v2.modifiers import training_input
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call, ast_import


@pytest.fixture
def constructors():
    """Source snippets of the namespaced ``ShuffleConfig`` calls exercised by the tests."""
    fully_qualified = "sagemaker.session.ShuffleConfig(seed)"
    module_qualified = "session.ShuffleConfig(seed)"
    return (fully_qualified, module_qualified)


@pytest.fixture
def modified_constructors(constructors):
    """Expected snippets: each entry of ``constructors`` with ``session`` -> ``inputs``."""
    expected = []
    for original in constructors:
        expected.append(original.replace("session", "inputs"))
    return expected


def test_constructor_node_should_be_modified(constructors):
    """Every namespaced ``ShuffleConfig`` constructor is flagged for modification."""
    renamer = training_input.ShuffleConfigModuleRenamer()
    for source in constructors:
        assert renamer.node_should_be_modified(ast_call(source))


def test_constructor_node_should_be_modified_random_call():
    """A call that is not a ``ShuffleConfig`` constructor is not flagged."""
    renamer = training_input.ShuffleConfigModuleRenamer()
    unrelated_call = ast_call("FileSystemInput()")
    should_modify = renamer.node_should_be_modified(unrelated_call)
    assert not should_modify


def test_constructor_modify_node(constructors, modified_constructors):
    """Each constructor snippet dumps as its ``inputs`` counterpart once modified."""
    renamer = training_input.ShuffleConfigModuleRenamer()

    pairs = zip(constructors, modified_constructors)
    for source, wanted in pairs:
        call_node = ast_call(source)
        renamer.modify_node(call_node)
        assert wanted == pasta.dump(call_node)


def test_import_from_node_should_be_modified_training_input():
    """Importing ``ShuffleConfig`` from ``sagemaker.session`` is flagged for modification."""
    import_node = ast_import("from sagemaker.session import ShuffleConfig")
    renamer = training_input.ShuffleConfigImportFromRenamer()
    assert renamer.node_should_be_modified(import_node)


def test_import_from_node_should_be_modified_random_import():
    """Importing a different name from ``sagemaker.session`` is not flagged."""
    import_node = ast_import("from sagemaker.session import Session")
    renamer = training_input.ShuffleConfigImportFromRenamer()
    should_modify = renamer.node_should_be_modified(import_node)
    assert not should_modify


def test_import_from_modify_node():
    """The modified import statement dumps with ``sagemaker.inputs`` as its module."""
    import_node = ast_import("from sagemaker.session import ShuffleConfig")
    renamer = training_input.ShuffleConfigImportFromRenamer()

    renamer.modify_node(import_node)
    rendered = pasta.dump(import_node)
    assert "from sagemaker.inputs import ShuffleConfig" == rendered
2 changes: 1 addition & 1 deletion tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
from sagemaker import TrainingInput, utils, vpc_utils
from sagemaker.algorithm import AlgorithmEstimator
from sagemaker.estimator import Estimator, EstimatorBase, Framework, _TrainingJob
from sagemaker.inputs import ShuffleConfig
from sagemaker.model import FrameworkModel
from sagemaker.predictor import Predictor
from sagemaker.session import ShuffleConfig
from sagemaker.transformer import Transformer

MODEL_DATA = "s3://bucket/model.tar.gz"
Expand Down