Skip to content

Commit 5cc1083

Browse files
authored
Merge branch 'zwei' into v2-compability-tests
2 parents bec3f3c + de54a76 commit 5cc1083

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

70 files changed

+3195
-6149
lines changed

buildspec-unittests.yml

+9-2
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,12 @@ phases:
1818
- start_time=`date +%s`
1919
- AWS_ACCESS_KEY_ID= AWS_SECRET_ACCESS_KEY= AWS_SESSION_TOKEN=
2020
AWS_CONTAINER_CREDENTIALS_RELATIVE_URI= AWS_DEFAULT_REGION=
21-
tox -e py27,py36,py37 --parallel all -- tests/unit
22-
- ./ci-scripts/displaytime.sh 'py27,py36,py37 unit' $start_time
21+
tox -e py36,py37 --parallel all -- tests/unit
22+
- ./ci-scripts/displaytime.sh 'py36,py37 unit' $start_time
23+
24+
# Remove once https://github.com/aws/sagemaker-python-sdk/issues/1461 is addressed.
25+
- start_time=`date +%s`
26+
- AWS_ACCESS_KEY_ID= AWS_SECRET_ACCESS_KEY= AWS_SESSION_TOKEN=
27+
AWS_CONTAINER_CREDENTIALS_RELATIVE_URI= AWS_DEFAULT_REGION=
28+
IGNORE_COVERAGE=- tox -e py27 --parallel all -- tests/unit
29+
- ./ci-scripts/displaytime.sh 'py27 unit' $start_time

doc/conf.py

-29
Original file line numberDiff line numberDiff line change
@@ -14,36 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
import pkg_resources
17-
import sys
1817
from datetime import datetime
19-
from unittest.mock import MagicMock
20-
21-
22-
class Mock(MagicMock):
23-
@classmethod
24-
def __getattr__(cls, name):
25-
"""
26-
Args:
27-
name:
28-
"""
29-
if name == "__version__":
30-
return "1.4.0"
31-
else:
32-
return MagicMock()
33-
34-
35-
MOCK_MODULES = [
36-
"tensorflow",
37-
"tensorflow.core",
38-
"tensorflow.core.framework",
39-
"tensorflow.python",
40-
"tensorflow.python.framework",
41-
"tensorflow_serving",
42-
"tensorflow_serving.apis",
43-
"scipy",
44-
"scipy.sparse",
45-
]
46-
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
4718

4819
project = u"sagemaker"
4920
version = pkg_resources.require(project)[0].version

doc/frameworks/tensorflow/sagemaker.tensorflow.rst

+2-18
Original file line numberDiff line numberDiff line change
@@ -10,34 +10,18 @@ TensorFlow Estimator
1010
:undoc-members:
1111
:show-inheritance:
1212

13-
TensorFlow Model
14-
----------------
15-
16-
.. autoclass:: sagemaker.tensorflow.model.TensorFlowModel
17-
:members:
18-
:undoc-members:
19-
:show-inheritance:
20-
21-
TensorFlow Predictor
22-
--------------------
23-
24-
.. autoclass:: sagemaker.tensorflow.model.TensorFlowPredictor
25-
:members:
26-
:undoc-members:
27-
:show-inheritance:
28-
2913
TensorFlow Serving Model
3014
------------------------
3115

32-
.. autoclass:: sagemaker.tensorflow.serving.Model
16+
.. autoclass:: sagemaker.tensorflow.model.TensorFlowModel
3317
:members:
3418
:undoc-members:
3519
:show-inheritance:
3620

3721
TensorFlow Serving Predictor
3822
----------------------------
3923

40-
.. autoclass:: sagemaker.tensorflow.serving.Predictor
24+
.. autoclass:: sagemaker.tensorflow.model.TensorFlowPredictor
4125
:members:
4226
:undoc-members:
4327
:show-inheritance:

doc/requirements.txt

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
sphinx==2.2.2
22
numpy
3-
scipy
43
requests==2.20

setup.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ def read_version():
3535
# Declare minimal set for installation
3636
required_packages = [
3737
"boto3>=1.13.6",
38+
"google-pasta",
3839
"numpy>=1.9.0",
3940
"protobuf>=3.1",
40-
"scipy>=0.19.0",
4141
"protobuf3-to-dict>=0.1.5",
4242
"smdebug-rulesconfig==0.1.4",
4343
"importlib-metadata>=1.4.0",
@@ -52,7 +52,7 @@ def read_version():
5252
"docker-compose>=1.25.2",
5353
"PyYAML>=5.3, <6", # PyYAML version has to match docker-compose requirements
5454
],
55-
"tensorflow": ["tensorflow>=1.3.0"],
55+
"scipy": ["scipy>=0.19.0"],
5656
}
5757
# Meta dependency groups
5858
extras["all"] = [item for group in extras.values() for item in group]
@@ -105,6 +105,11 @@ def read_version():
105105
],
106106
install_requires=required_packages,
107107
extras_require=extras,
108-
entry_points={"console_scripts": ["sagemaker=sagemaker.cli.main:main"]},
108+
entry_points={
109+
"console_scripts": [
110+
"sagemaker=sagemaker.cli.main:main",
111+
"sagemaker-upgrade-v2=sagemaker.cli.compatibility.v2.sagemaker_upgrade_v2:main",
112+
]
113+
},
109114
include_package_data=True, # TODO-reinvent-2019 [knakad]: Remove after rule_configs is in PyPI
110115
)

src/sagemaker/amazon/common.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@
1414
from __future__ import absolute_import
1515

1616
import io
17+
import logging
1718
import struct
1819
import sys
1920

2021
import numpy as np
21-
from scipy.sparse import issparse
2222

2323
from sagemaker.amazon.record_pb2 import Record
24+
from sagemaker.utils import DeferredError
2425

2526

2627
class numpy_to_record_serializer(object):
@@ -171,8 +172,16 @@ def write_spmatrix_to_sparse_tensor(file, array, labels=None):
171172
array:
172173
labels:
173174
"""
174-
175-
if not issparse(array):
175+
try:
176+
import scipy
177+
except ImportError as e:
178+
logging.warning(
179+
"scipy failed to import. Sparse matrix functions will be impaired or broken."
180+
)
181+
# Any subsequent attempt to use scipy will raise the ImportError
182+
scipy = DeferredError(e)
183+
184+
if not scipy.sparse.issparse(array):
176185
raise TypeError("Array must be sparse")
177186

178187
# Validate shape of array and labels, resolve array and label types
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Tools to assist with compatibility between SageMaker Python SDK versions."""
14+
from __future__ import absolute_import
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Tools to assist with upgrading to v2 of the SageMaker Python SDK."""
14+
from __future__ import absolute_import
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""An ast.NodeTransformer subclass for updating SageMaker Python SDK code."""
14+
from __future__ import absolute_import
15+
16+
import ast
17+
18+
from sagemaker.cli.compatibility.v2 import modifiers
19+
20+
FUNCTION_CALL_MODIFIERS = [
21+
modifiers.framework_version.FrameworkVersionEnforcer(),
22+
modifiers.tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader(),
23+
modifiers.tf_legacy_mode.TensorBoardParameterRemover(),
24+
modifiers.deprecated_params.TensorFlowScriptModeParameterRemover(),
25+
modifiers.tfs.TensorFlowServingConstructorRenamer(),
26+
]
27+
28+
IMPORT_MODIFIERS = [modifiers.tfs.TensorFlowServingImportRenamer()]
29+
30+
IMPORT_FROM_MODIFIERS = [modifiers.tfs.TensorFlowServingImportFromRenamer()]
31+
32+
33+
class ASTTransformer(ast.NodeTransformer):
34+
"""An ``ast.NodeTransformer`` subclass that walks the abstract syntax tree and
35+
modifies nodes to upgrade the given SageMaker Python SDK code.
36+
"""
37+
38+
def visit_Call(self, node):
39+
"""Visits an ``ast.Call`` node and returns a modified node, if needed.
40+
See https://docs.python.org/3/library/ast.html#ast.NodeTransformer.
41+
42+
Args:
43+
node (ast.Call): a node that represents a function call.
44+
45+
Returns:
46+
ast.Call: a node that represents a function call, which has
47+
potentially been modified from the original input.
48+
"""
49+
for function_checker in FUNCTION_CALL_MODIFIERS:
50+
function_checker.check_and_modify_node(node)
51+
52+
ast.fix_missing_locations(node)
53+
return node
54+
55+
def visit_Import(self, node):
56+
"""Visits an ``ast.Import`` node and returns a modified node, if needed.
57+
See https://docs.python.org/3/library/ast.html#ast.NodeTransformer.
58+
59+
Args:
60+
node (ast.Import): a node that represents an import statement.
61+
62+
Returns:
63+
ast.Import: a node that represents an import statement, which has
64+
potentially been modified from the original input.
65+
"""
66+
for import_checker in IMPORT_MODIFIERS:
67+
import_checker.check_and_modify_node(node)
68+
69+
ast.fix_missing_locations(node)
70+
return node
71+
72+
def visit_ImportFrom(self, node):
73+
"""Visits an ``ast.ImportFrom`` node and returns a modified node, if needed.
74+
See https://docs.python.org/3/library/ast.html#ast.NodeTransformer.
75+
76+
Args:
77+
node (ast.ImportFrom): a node that represents an import statement.
78+
79+
Returns:
80+
ast.ImportFrom: a node that represents an import statement, which has
81+
potentially been modified from the original input.
82+
"""
83+
for import_checker in IMPORT_FROM_MODIFIERS:
84+
import_checker.check_and_modify_node(node)
85+
86+
ast.fix_missing_locations(node)
87+
return node

0 commit comments

Comments
 (0)