Skip to content

Commit 72fee9e

Browse files
authored
Merge branch 'zwei' into add-csv-deserializer
2 parents 9ef4e9a + 4ffa222 commit 72fee9e

File tree

10 files changed

+245
-41
lines changed

10 files changed

+245
-41
lines changed

src/sagemaker/cli/compatibility/v2/ast_transformer.py

+4
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,10 @@
3030
modifiers.airflow.ModelConfigImageURIRenamer(),
3131
modifiers.renamed_params.DistributionParameterRenamer(),
3232
modifiers.renamed_params.S3SessionRenamer(),
33+
modifiers.renamed_params.EstimatorCreateModelImageURIRenamer(),
34+
modifiers.renamed_params.SessionCreateModelImageURIRenamer(),
35+
modifiers.renamed_params.SessionCreateEndpointImageURIRenamer(),
36+
modifiers.training_params.TrainPrefixRemover(),
3337
]
3438

3539
IMPORT_MODIFIERS = [modifiers.tfs.TensorFlowServingImportRenamer()]

src/sagemaker/cli/compatibility/v2/modifiers/__init__.py

+1
Original file line number | Diff line number | Diff line change
@@ -21,4 +21,5 @@
2121
renamed_params,
2222
tf_legacy_mode,
2323
tfs,
24+
training_params,
2425
)

src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py

+1 -3
Original file line number | Diff line number | Diff line change
@@ -10,9 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Classes to modify Predictor code to be compatible
14-
with version 2.0 and later of the SageMaker Python SDK.
15-
"""
13+
"""Classes to handle renames for version 2.0 and later of the SageMaker Python SDK."""
1614
from __future__ import absolute_import
1715

1816
import ast

src/sagemaker/cli/compatibility/v2/modifiers/training_params.py

+97
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,97 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Classes to handle training renames for version 2.0 and later of the SageMaker Python SDK."""
14+
from __future__ import absolute_import
15+
16+
from sagemaker.cli.compatibility.v2.modifiers import matching
17+
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
18+
19+
ESTIMATORS = {
20+
"AlgorithmEstimator": ("sagemaker", "sagemaker.algorithm"),
21+
"AmazonAlgorithmEstimatorBase": ("sagemaker.amazon.amazon_estimator",),
22+
"Chainer": ("sagemaker.chainer", "sagemaker.chainer.estimator"),
23+
"Estimator": ("sagemaker.estimator",),
24+
"EstimatorBase": ("sagemaker.estimator",),
25+
"FactorizationMachines": ("sagemaker", "sagemaker.amazon.factorization_machines"),
26+
"Framework": ("sagemaker.estimator",),
27+
"IPInsights": ("sagemaker", "sagemaker.amazon.ipinsights"),
28+
"KMeans": ("sagemaker", "sagemaker.amazon.kmeans"),
29+
"KNN": ("sagemaker", "sagemaker.amazon.knn"),
30+
"LDA": ("sagemaker", "sagemaker.amazon.lda"),
31+
"LinearLearner": ("sagemaker", "sagemaker.amazon.linear_learner"),
32+
"MXNet": ("sagemaker.mxnet", "sagemaker.mxnet.estimator"),
33+
"NTM": ("sagemaker", "sagemaker.amazon.ntm"),
34+
"Object2Vec": ("sagemaker", "sagemaker.amazon.object2vec"),
35+
"PCA": ("sagemaker", "sagemaker.amazon.pca"),
36+
"PyTorch": ("sagemaker.pytorch", "sagemaker.pytorch.estimator"),
37+
"RandomCutForest": ("sagemaker", "sagemaker.amazon.randomcutforest"),
38+
"RLEstimator": ("sagemaker.rl", "sagemaker.rl.estimator"),
39+
"SKLearn": ("sagemaker.sklearn", "sagemaker.sklearn.estimator"),
40+
"TensorFlow": ("sagemaker.tensorflow", "sagemaker.tensorflow.estimator"),
41+
"XGBoost": ("sagemaker.xgboost", "sagemaker.xgboost.estimator"),
42+
}
43+
44+
PARAMS = (
45+
"train_instance_count",
46+
"train_instance_type",
47+
"train_max_run",
48+
"train_max_run_wait",
49+
"train_use_spot_instances",
50+
"train_volume_size",
51+
"train_volume_kms_key",
52+
)
53+
54+
55+
class TrainPrefixRemover(Modifier):
56+
"""A class to remove the redundant 'train' prefix in estimator parameters."""
57+
58+
def node_should_be_modified(self, node):
59+
"""Checks if the node is an estimator constructor and contains any relevant parameters.
60+
61+
This looks for the following parameters:
62+
63+
- ``train_instance_count``
64+
- ``train_instance_type``
65+
- ``train_max_run``
66+
- ``train_max_run_wait``
67+
- ``train_use_spot_instances``
68+
- ``train_volume_kms_key``
69+
- ``train_volume_size``
70+
71+
Args:
72+
node (ast.Call): a node that represents a function call. For more,
73+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
74+
75+
Returns:
76+
bool: If the ``ast.Call`` matches the relevant function calls and
77+
contains the parameter to be renamed.
78+
"""
79+
return matching.matches_any(node, ESTIMATORS) and self._has_train_parameter(node)
80+
81+
def _has_train_parameter(self, node):
82+
"""Checks if at least one of the node's keywords is prefixed with 'train'."""
83+
for kw in node.keywords:
84+
if kw.arg in PARAMS:
85+
return True
86+
87+
return False
88+
89+
def modify_node(self, node):
90+
"""Modifies the ``ast.Call`` node to remove the 'train' prefix from its keywords.
91+
92+
Args:
93+
node (ast.Call): a node that represents an estimator constructor.
94+
"""
95+
for kw in node.keywords:
96+
if kw.arg in PARAMS:
97+
kw.arg = kw.arg.replace("train_", "")

src/sagemaker/deserializers.py

+29
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,35 @@ def ACCEPT(self):
4343
"""The content type that is expected from the inference endpoint."""
4444

4545

46+
class StringDeserializer(BaseDeserializer):
47+
"""Deserialize data from an inference endpoint into a decoded string."""
48+
49+
ACCEPT = "application/json"
50+
51+
def __init__(self, encoding="UTF-8"):
52+
"""Initialize the string encoding.
53+
54+
Args:
55+
encoding (str): The string encoding to use (default: UTF-8).
56+
"""
57+
self.encoding = encoding
58+
59+
def deserialize(self, data, content_type):
60+
"""Deserialize data from an inference endpoint into a decoded string.
61+
62+
Args:
63+
data (object): Data to be deserialized.
64+
content_type (str): The MIME type of the data.
65+
66+
Returns:
67+
str: The data deserialized into a decoded string.
68+
"""
69+
try:
70+
return data.read().decode(self.encoding)
71+
finally:
72+
data.close()
73+
74+
4675
class BytesDeserializer(BaseDeserializer):
4776
"""Deserialize a stream of bytes into a bytes object."""
4877

src/sagemaker/predictor.py

-29
Original file line number | Diff line number | Diff line change
@@ -597,35 +597,6 @@ def _row_to_csv(obj):
597597
return ",".join(obj)
598598

599599

600-
class StringDeserializer(object):
601-
"""Return the response as a decoded string.
602-
603-
Args:
604-
encoding (str): The string encoding to use (default=utf-8).
605-
accept (str): The Accept header to send to the server (optional).
606-
"""
607-
608-
def __init__(self, encoding="utf-8", accept=None):
609-
"""
610-
Args:
611-
encoding:
612-
accept:
613-
"""
614-
self.encoding = encoding
615-
self.accept = accept
616-
617-
def __call__(self, stream, content_type):
618-
"""
619-
Args:
620-
stream:
621-
content_type:
622-
"""
623-
try:
624-
return stream.read().decode(self.encoding)
625-
finally:
626-
stream.close()
627-
628-
629600
class StreamDeserializer(object):
630601
"""Returns the tuple of the response stream and the content-type of the response.
631602
It is the receivers responsibility to close the stream when they're done

tests/integ/test_multidatamodel.py

+2 -1
Original file line number | Diff line number | Diff line change
@@ -24,9 +24,10 @@
2424

2525
from sagemaker import utils
2626
from sagemaker.amazon.randomcutforest import RandomCutForest
27+
from sagemaker.deserializers import StringDeserializer
2728
from sagemaker.multidatamodel import MultiDataModel
2829
from sagemaker.mxnet import MXNet
29-
from sagemaker.predictor import Predictor, StringDeserializer, npy_serializer
30+
from sagemaker.predictor import Predictor, npy_serializer
3031
from sagemaker.utils import sagemaker_timestamp, unique_name_from_base, get_ecr_image_uri_prefix
3132
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
3233
from tests.integ.retry import retries
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,102 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import itertools
16+
17+
import pasta
18+
19+
from sagemaker.cli.compatibility.v2.modifiers import training_params
20+
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call
21+
22+
ESTIMATORS_TO_NAMESPACES = {
23+
"AlgorithmEstimator": ("sagemaker", "sagemaker.algorithm"),
24+
"AmazonAlgorithmEstimatorBase": ("sagemaker.amazon.amazon_estimator",),
25+
"Chainer": ("sagemaker.chainer", "sagemaker.chainer.estimator"),
26+
"Estimator": ("sagemaker.estimator",),
27+
"EstimatorBase": ("sagemaker.estimator",),
28+
"FactorizationMachines": ("sagemaker", "sagemaker.amazon.factorization_machines"),
29+
"Framework": ("sagemaker.estimator",),
30+
"IPInsights": ("sagemaker", "sagemaker.amazon.ipinsights"),
31+
"KMeans": ("sagemaker", "sagemaker.amazon.kmeans"),
32+
"KNN": ("sagemaker", "sagemaker.amazon.knn"),
33+
"LDA": ("sagemaker", "sagemaker.amazon.lda"),
34+
"LinearLearner": ("sagemaker", "sagemaker.amazon.linear_learner"),
35+
"MXNet": ("sagemaker.mxnet", "sagemaker.mxnet.estimator"),
36+
"NTM": ("sagemaker", "sagemaker.amazon.ntm"),
37+
"Object2Vec": ("sagemaker", "sagemaker.amazon.object2vec"),
38+
"PCA": ("sagemaker", "sagemaker.amazon.pca"),
39+
"PyTorch": ("sagemaker.pytorch", "sagemaker.pytorch.estimator"),
40+
"RandomCutForest": ("sagemaker", "sagemaker.amazon.randomcutforest"),
41+
"RLEstimator": ("sagemaker.rl", "sagemaker.rl.estimator"),
42+
"SKLearn": ("sagemaker.sklearn", "sagemaker.sklearn.estimator"),
43+
"TensorFlow": ("sagemaker.tensorflow", "sagemaker.tensorflow.estimator"),
44+
"XGBoost": ("sagemaker.xgboost", "sagemaker.xgboost.estimator"),
45+
}
46+
47+
PARAMS_WITH_VALUES = (
48+
"train_instance_count=1",
49+
"train_instance_type='ml.c4.xlarge'",
50+
"train_max_run=8 * 60 * 60",
51+
"train_max_run_wait=1 * 60 * 60",
52+
"train_use_spot_instances=True",
53+
"train_volume_size=30",
54+
"train_volume_kms_key='key'",
55+
)
56+
57+
58+
def _estimators():
59+
for estimator, namespaces in ESTIMATORS_TO_NAMESPACES.items():
60+
yield estimator
61+
62+
for namespace in namespaces:
63+
yield ".".join((namespace, estimator))
64+
65+
66+
def test_node_should_be_modified():
67+
modifier = training_params.TrainPrefixRemover()
68+
69+
for estimator in _estimators():
70+
for param in PARAMS_WITH_VALUES:
71+
call = ast_call("{}({})".format(estimator, param))
72+
assert modifier.node_should_be_modified(call)
73+
74+
75+
def test_node_should_be_modified_no_params():
76+
modifier = training_params.TrainPrefixRemover()
77+
78+
for estimator in _estimators():
79+
call = ast_call("{}()".format(estimator))
80+
assert not modifier.node_should_be_modified(call)
81+
82+
83+
def test_node_should_be_modified_random_function_call():
84+
modifier = training_params.TrainPrefixRemover()
85+
assert not modifier.node_should_be_modified(ast_call("Session()"))
86+
87+
88+
def test_modify_node():
89+
modifier = training_params.TrainPrefixRemover()
90+
91+
for params in _parameter_combinations():
92+
node = ast_call("Estimator({})".format(params))
93+
modifier.modify_node(node)
94+
95+
expected = "Estimator({})".format(params).replace("train_", "")
96+
assert expected == pasta.dump(node)
97+
98+
99+
def _parameter_combinations():
100+
for subset_length in range(1, len(PARAMS_WITH_VALUES) + 1):
101+
for subset in itertools.combinations(PARAMS_WITH_VALUES, subset_length):
102+
yield ", ".join(subset)

tests/unit/sagemaker/test_deserializers.py

+9 -1
Original file line number | Diff line number | Diff line change
@@ -16,9 +16,17 @@
1616

1717
import pytest
1818

19-
from sagemaker.deserializers import BytesDeserializer, CSVDeserializer
19+
from sagemaker.deserializers import StringDeserializer, BytesDeserializer, CSVDeserializer
2020

2121

22+
def test_string_deserializer():
23+
deserializer = StringDeserializer()
24+
25+
result = deserializer.deserialize(io.BytesIO(b"[1, 2, 3]"), "application/json")
26+
27+
assert result == "[1, 2, 3]"
28+
29+
2230
def test_bytes_deserializer():
2331
deserializer = BytesDeserializer()
2432

tests/unit/test_predictor.py

-7
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,6 @@
2525
json_serializer,
2626
json_deserializer,
2727
csv_serializer,
28-
StringDeserializer,
2928
StreamDeserializer,
3029
numpy_deserializer,
3130
npy_serializer,
@@ -167,12 +166,6 @@ def test_json_deserializer_invalid_data():
167166
assert "column" in str(error)
168167

169168

170-
def test_string_deserializer():
171-
result = StringDeserializer()(io.BytesIO(b"[1, 2, 3]"), "application/json")
172-
173-
assert result == "[1, 2, 3]"
174-
175-
176169
def test_stream_deserializer():
177170
stream, content_type = StreamDeserializer()(io.BytesIO(b"[1, 2, 3]"), "application/json")
178171
result = stream.read()

0 commit comments

Comments
 (0)