Skip to content

Commit 5d0acaa

Browse files
authored
Merge branch 'zwei' into remove-scipy
2 parents 87512d9 + 614fe7e commit 5d0acaa

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+899
-5040
lines changed

buildspec-unittests.yml

+9-2
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,12 @@ phases:
1818
- start_time=`date +%s`
1919
- AWS_ACCESS_KEY_ID= AWS_SECRET_ACCESS_KEY= AWS_SESSION_TOKEN=
2020
AWS_CONTAINER_CREDENTIALS_RELATIVE_URI= AWS_DEFAULT_REGION=
21-
tox -e py27,py36,py37 --parallel all -- tests/unit
22-
- ./ci-scripts/displaytime.sh 'py27,py36,py37 unit' $start_time
21+
tox -e py36,py37 --parallel all -- tests/unit
22+
- ./ci-scripts/displaytime.sh 'py36,py37 unit' $start_time
23+
24+
# Remove once https://github.com/aws/sagemaker-python-sdk/issues/1461 is addressed.
25+
- start_time=`date +%s`
26+
- AWS_ACCESS_KEY_ID= AWS_SECRET_ACCESS_KEY= AWS_SESSION_TOKEN=
27+
AWS_CONTAINER_CREDENTIALS_RELATIVE_URI= AWS_DEFAULT_REGION=
28+
IGNORE_COVERAGE=- tox -e py27 --parallel all -- tests/unit
29+
- ./ci-scripts/displaytime.sh 'py27 unit' $start_time

doc/sagemaker.tensorflow.rst

-16
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,6 @@ TensorFlow Estimator
1010
:undoc-members:
1111
:show-inheritance:
1212

13-
TensorFlow Model
14-
----------------
15-
16-
.. autoclass:: sagemaker.tensorflow.model.TensorFlowModel
17-
:members:
18-
:undoc-members:
19-
:show-inheritance:
20-
21-
TensorFlow Predictor
22-
--------------------
23-
24-
.. autoclass:: sagemaker.tensorflow.model.TensorFlowPredictor
25-
:members:
26-
:undoc-members:
27-
:show-inheritance:
28-
2913
TensorFlow Serving Model
3014
------------------------
3115

setup.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,11 @@ def read_version():
105105
],
106106
install_requires=required_packages,
107107
extras_require=extras,
108-
entry_points={"console_scripts": ["sagemaker=sagemaker.cli.main:main"]},
108+
entry_points={
109+
"console_scripts": [
110+
"sagemaker=sagemaker.cli.main:main",
111+
"sagemaker-upgrade-v2=sagemaker.cli.compatibility.v2.sagemaker_upgrade_v2:main",
112+
]
113+
},
109114
include_package_data=True, # TODO-reinvent-2019 [knakad]: Remove after rule_configs is in PyPI
110115
)

tools/compatibility/v2/ast_transformer.py renamed to src/sagemaker/cli/compatibility/v2/ast_transformer.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515

1616
import ast
1717

18-
from modifiers import framework_version
18+
from sagemaker.cli.compatibility.v2 import modifiers
1919

20-
FUNCTION_CALL_MODIFIERS = [framework_version.FrameworkVersionEnforcer()]
20+
FUNCTION_CALL_MODIFIERS = [
21+
modifiers.framework_version.FrameworkVersionEnforcer(),
22+
modifiers.tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader(),
23+
]
2124

2225

2326
class ASTTransformer(ast.NodeTransformer):
@@ -38,4 +41,6 @@ def visit_Call(self, node):
3841
"""
3942
for function_checker in FUNCTION_CALL_MODIFIERS:
4043
function_checker.check_and_modify_node(node)
44+
45+
ast.fix_missing_locations(node)
4146
return node

tools/compatibility/v2/files.py renamed to src/sagemaker/cli/compatibility/v2/files.py

+15-13
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import pasta
2222

23-
from ast_transformer import ASTTransformer
23+
from sagemaker.cli.compatibility.v2.ast_transformer import ASTTransformer
2424

2525
LOGGER = logging.getLogger(__name__)
2626

@@ -48,6 +48,18 @@ def update(self):
4848
updated code to an output file.
4949
"""
5050

51+
def _make_output_dirs_if_needed(self):
52+
"""Checks if the directory path for ``self.output_path`` exists,
53+
and creates the directories if not. This function also logs a warning if
54+
``self.output_path`` already exists.
55+
"""
56+
output_dir = os.path.dirname(self.output_path)
57+
if output_dir and not os.path.exists(output_dir):
58+
os.makedirs(output_dir)
59+
60+
if os.path.exists(self.output_path):
61+
LOGGER.warning("Overwriting file %s", self.output_path)
62+
5163

5264
class PyFileUpdater(FileUpdater):
5365
"""A class for updating Python (``*.py``) files."""
@@ -88,12 +100,7 @@ def _write_output_file(self, output):
88100
Args:
89101
output (ast.Module): AST to save as the output file.
90102
"""
91-
output_dir = os.path.dirname(self.output_path)
92-
if output_dir and not os.path.exists(output_dir):
93-
os.makedirs(output_dir)
94-
95-
if os.path.exists(self.output_path):
96-
LOGGER.warning("Overwriting file %s", self.output_path)
103+
self._make_output_dirs_if_needed()
97104

98105
with open(self.output_path, "w") as output_file:
99106
output_file.write(pasta.dump(output))
@@ -168,12 +175,7 @@ def _write_output_file(self, output):
168175
Args:
169176
output (dict): JSON to save as the output file.
170177
"""
171-
output_dir = os.path.dirname(self.output_path)
172-
if output_dir and not os.path.exists(output_dir):
173-
os.makedirs(output_dir)
174-
175-
if os.path.exists(self.output_path):
176-
LOGGER.warning("Overwriting file %s", self.output_path)
178+
self._make_output_dirs_if_needed()
177179

178180
with open(self.output_path, "w") as output_file:
179181
json.dump(output, output_file, indent=1)

tools/compatibility/v2/modifiers/__init__.py renamed to src/sagemaker/cli/compatibility/v2/modifiers/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,8 @@
1212
# language governing permissions and limitations under the License.
1313
"""Classes for modifying AST nodes"""
1414
from __future__ import absolute_import
15+
16+
from sagemaker.cli.compatibility.v2.modifiers import ( # noqa: F401 (imported but unused)
17+
framework_version,
18+
tf_legacy_mode,
19+
)

tools/compatibility/v2/modifiers/framework_version.py renamed to src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import ast
1717

18-
from modifiers.modifier import Modifier
18+
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
1919

2020
FRAMEWORK_DEFAULTS = {
2121
"Chainer": "4.1.0",
@@ -66,8 +66,7 @@ def _is_framework_constructor(self, node):
6666
"""
6767
# Check for <Framework> call
6868
if isinstance(node.func, ast.Name):
69-
if node.func.id in FRAMEWORK_CLASSES:
70-
return True
69+
return node.func.id in FRAMEWORK_CLASSES
7170

7271
# Check for sagemaker.<framework>.<Framework> call
7372
ends_with_framework_constructor = (
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Classes to modify TensorFlow legacy mode code to be compatible with SageMaker Python SDK v2."""
14+
# TODO: handle fit(run_tensorboard_locally=True)
15+
from __future__ import absolute_import
16+
17+
import ast
18+
19+
import six
20+
21+
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
22+
23+
24+
class TensorFlowLegacyModeConstructorUpgrader(Modifier):
25+
"""A class to turn legacy mode parameters into hyperparameters when
26+
instantiating a TensorFlow estimator.
27+
"""
28+
29+
LEGACY_MODE_PARAMETERS = (
30+
"checkpoint_path",
31+
"evaluation_steps",
32+
"requirements_file",
33+
"training_steps",
34+
)
35+
36+
def node_should_be_modified(self, node):
37+
"""Checks if the ``ast.Call`` node instantiates a TensorFlow estimator with legacy mode.
38+
39+
This looks for the following formats:
40+
41+
- ``TensorFlow``
42+
- ``sagemaker.tensorflow.TensorFlow``
43+
44+
Legacy mode is enabled if (1) ``script_mode`` is ``False``, ``None``, or not specified,
45+
and (2) if ``py_version`` is ``py2`` or not specified.
46+
47+
Args:
48+
node (ast.Call): a node that represents a function call. For more,
49+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
50+
51+
Returns:
52+
bool: If the ``ast.Call`` is instantiating a TensorFlow estimator with legacy mode.
53+
"""
54+
return self._is_tf_constructor(node) and self._is_legacy_mode(node)
55+
56+
def _is_tf_constructor(self, node):
57+
"""Checks if the ``ast.Call`` node represents a call of the form
58+
``TensorFlow`` or ``sagemaker.tensorflow.TensorFlow``.
59+
"""
60+
# Check for TensorFlow()
61+
if isinstance(node.func, ast.Name):
62+
return node.func.id == "TensorFlow"
63+
64+
# Check for sagemaker.tensorflow.TensorFlow()
65+
ends_with_tensorflow_constructor = (
66+
isinstance(node.func, ast.Attribute) and node.func.attr == "TensorFlow"
67+
)
68+
69+
is_in_tensorflow_module = (
70+
isinstance(node.func.value, ast.Attribute)
71+
and node.func.value.attr == "tensorflow"
72+
and isinstance(node.func.value.value, ast.Name)
73+
and node.func.value.value.id == "sagemaker"
74+
)
75+
76+
return ends_with_tensorflow_constructor and is_in_tensorflow_module
77+
78+
def _is_legacy_mode(self, node):
79+
"""Checks if the ``ast.Call`` node's keywords signal using legacy mode."""
80+
script_mode = False
81+
py_version = "py2"
82+
83+
for kw in node.keywords:
84+
if kw.arg == "script_mode":
85+
script_mode = bool(kw.value.value)
86+
if kw.arg == "py_version":
87+
py_version = kw.value.s
88+
89+
return not (py_version.startswith("py3") or script_mode)
90+
91+
def modify_node(self, node):
92+
"""Modifies the ``ast.Call`` node's keywords to turn TensorFlow legacy mode parameters
93+
into hyperparameters and set ``script_mode=False``.
94+
95+
The parameters that are converted into hyperparameters:
96+
97+
- ``training_steps``
98+
- ``evaluation_steps``
99+
- ``checkpoint_path``
100+
- ``requirements_file``
101+
102+
Args:
103+
node (ast.Call): a node that represents a TensorFlow constructor.
104+
"""
105+
base_hps = {}
106+
additional_hps = {}
107+
kw_to_remove = [] # remove keyword args after so that none are skipped during iteration
108+
109+
for kw in node.keywords:
110+
if kw.arg == "script_mode":
111+
# remove here because is set to False later regardless of current value
112+
kw_to_remove.append(kw)
113+
if kw.arg == "hyperparameters" and kw.value:
114+
base_hps = dict(zip(kw.value.keys, kw.value.values))
115+
kw_to_remove.append(kw)
116+
if kw.arg in self.LEGACY_MODE_PARAMETERS and kw.value:
117+
hp_key = self._hyperparameter_key_for_param(kw.arg)
118+
additional_hps[hp_key] = kw.value
119+
kw_to_remove.append(kw)
120+
121+
self._remove_keywords(node, kw_to_remove)
122+
self._add_updated_hyperparameters(node, base_hps, additional_hps)
123+
124+
node.keywords.append(ast.keyword(arg="script_mode", value=ast.NameConstant(value=False)))
125+
126+
def _hyperparameter_key_for_param(self, arg):
127+
"""Returns an ``ast.Str`` for a hyperparameter key replacing a legacy mode parameter."""
128+
name = "sagemaker_requirements" if arg == "requirements_file" else arg
129+
return ast.Str(s=name)
130+
131+
def _remove_keywords(self, node, keywords):
132+
"""Removes the keywords from the ``ast.Call`` node."""
133+
for kw in keywords:
134+
node.keywords.remove(kw)
135+
136+
def _add_updated_hyperparameters(self, node, base_hps, additional_hps):
137+
"""Combines and adds the hyperparameters to the ``ast.Call`` node's keywords."""
138+
base_hps.update(additional_hps)
139+
updated_hp_keyword = self._to_ast_keyword(base_hps)
140+
141+
if updated_hp_keyword:
142+
node.keywords.append(updated_hp_keyword)
143+
144+
def _to_ast_keyword(self, hps):
145+
"""Returns an ``ast.keyword`` for the ``hyperparameters`` kwarg if there are any."""
146+
if hps:
147+
keys, values = zip(*six.iteritems(hps))
148+
return ast.keyword(arg="hyperparameters", value=ast.Dict(keys=keys, values=values))
149+
150+
return None

tools/compatibility/v2/sagemaker_upgrade_v2.py renamed to src/sagemaker/cli/compatibility/v2/sagemaker_upgrade_v2.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616
import argparse
1717
import os
1818

19-
import files
19+
from sagemaker.cli.compatibility.v2 import files
2020

2121
_EXT_TO_UPDATER_CLS = {".py": files.PyFileUpdater, ".ipynb": files.JupyterNotebookFileUpdater}
2222

2323

2424
def _update_file(input_file, output_file):
25-
"""Update a file to be compatible with v2 of the SageMaker Python SDK,
25+
"""Updates a file to be compatible with v2 of the SageMaker Python SDK,
2626
and write the updated source to the output file.
2727
2828
Args:
@@ -51,10 +51,10 @@ def _update_file(input_file, output_file):
5151

5252

5353
def _parse_args():
54-
"""Parses CLI arguments"""
54+
"""Parses CLI arguments."""
5555
parser = argparse.ArgumentParser(
5656
description="A tool to convert files to be compatible with v2 of the SageMaker Python SDK. "
57-
"Simple usage: sagemaker_upgrade_v2.py --in-file foo.py --out-file bar.py"
57+
"Simple usage: sagemaker-upgrade-v2 --in-file foo.py --out-file bar.py"
5858
)
5959
parser.add_argument(
6060
"--in-file",
@@ -71,6 +71,7 @@ def _parse_args():
7171
return parser.parse_args()
7272

7373

74-
if __name__ == "__main__":
74+
def main():
75+
"""Parses the CLI arguments and executes the file update."""
7576
args = _parse_args()
7677
_update_file(args.in_file, args.out_file)

src/sagemaker/cli/tensorflow.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,12 @@ def create_model(self, model_url):
6868
Args:
6969
model_url:
7070
"""
71-
from sagemaker.tensorflow.model import TensorFlowModel
71+
from sagemaker.tensorflow.serving import Model
7272

73-
return TensorFlowModel(
73+
return Model(
7474
model_data=model_url,
7575
role=self.role_name,
7676
entry_point=self.script,
77-
py_version=self.python,
7877
name=self.endpoint_name,
7978
env=self.environment,
8079
)

src/sagemaker/tensorflow/__init__.py

+2-16
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Placeholder docstring"""
13+
"""Classes for using TensorFlow and TensorFlow Serving with Amazon SageMaker."""
1414
from __future__ import absolute_import
1515

16-
import sys
17-
import os
18-
19-
# Hack to use our local copy of tensorflow_serving.apis, which contains the protobuf-generated
20-
# classes for tensorflow serving. Currently tensorflow_serving_api can only be pip-installed for
21-
# python 2.
22-
sys.path.append(os.path.dirname(__file__))
23-
24-
from sagemaker.tensorflow.estimator import ( # noqa: E402, F401 # pylint: disable=wrong-import-position
25-
TensorFlow,
26-
)
27-
from sagemaker.tensorflow.model import ( # noqa: E402, F401 # pylint: disable=wrong-import-position
28-
TensorFlowModel,
29-
TensorFlowPredictor,
30-
)
16+
from sagemaker.tensorflow.estimator import TensorFlow # noqa: F401 (imported but unused)

0 commit comments

Comments
 (0)