Skip to content

Commit 845f771

Browse files
authored
Merge branch 'zwei' into fix-pip-install-cmd
2 parents 1c5d3a2 + e117c76 commit 845f771

36 files changed

+145
-270
lines changed

src/sagemaker/chainer/estimator.py

-1
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,6 @@ def create_model(
216216
role or self.role,
217217
entry_point or self._model_entry_point(),
218218
source_dir=(source_dir or self._model_source_dir()),
219-
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
220219
container_log_level=self.container_log_level,
221220
code_location=self.code_location,
222221
py_version=self.py_version,

src/sagemaker/deserializers.py

+35
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@
2222

2323
import numpy as np
2424

25+
from sagemaker.utils import DeferredError
26+
27+
try:
28+
import pandas
29+
except ImportError as e:
30+
pandas = DeferredError(e)
31+
2532

2633
class BaseDeserializer(abc.ABC):
2734
"""Abstract base class for creation of new deserializers.
@@ -208,3 +215,31 @@ def deserialize(self, stream, content_type):
208215
return json.load(codecs.getreader("utf-8")(stream))
209216
finally:
210217
stream.close()
218+
219+
220+
class PandasDeserializer(BaseDeserializer):
221+
"""Deserialize CSV or JSON data from an inference endpoint into a pandas dataframe."""
222+
223+
ACCEPT = "text/csv"
224+
225+
def deserialize(self, stream, content_type):
226+
"""Deserialize CSV or JSON data from an inference endpoint into a pandas
227+
dataframe.
228+
229+
If the data is JSON, the data should be formatted in the 'columns' orient.
230+
See https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_json.html
231+
232+
Args:
233+
stream (botocore.response.StreamingBody): Data to be deserialized.
234+
content_type (str): The MIME type of the data.
235+
236+
Returns:
237+
pandas.DataFrame: The data deserialized into a pandas DataFrame.
238+
"""
239+
if content_type == "text/csv":
240+
return pandas.read_csv(stream)
241+
242+
if content_type == "application/json":
243+
return pandas.read_json(stream)
244+
245+
raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type))

src/sagemaker/estimator.py

-16
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import logging
1818
import os
1919
import uuid
20-
import warnings
2120
from abc import ABCMeta
2221
from abc import abstractmethod
2322

@@ -46,7 +45,6 @@
4645
from sagemaker.model import (
4746
SCRIPT_PARAM_NAME,
4847
DIR_PARAM_NAME,
49-
CLOUDWATCH_METRICS_PARAM_NAME,
5048
CONTAINER_LOG_LEVEL_PARAM_NAME,
5149
JOB_NAME_PARAM_NAME,
5250
SAGEMAKER_REGION_PARAM_NAME,
@@ -1433,7 +1431,6 @@ def __init__(
14331431
entry_point,
14341432
source_dir=None,
14351433
hyperparameters=None,
1436-
enable_cloudwatch_metrics=False,
14371434
container_log_level=logging.INFO,
14381435
code_location=None,
14391436
image_uri=None,
@@ -1491,9 +1488,6 @@ def __init__(
14911488
SageMaker. For convenience, this accepts other types for keys
14921489
and values, but ``str()`` will be called to convert them before
14931490
training.
1494-
enable_cloudwatch_metrics (bool): [DEPRECATED] Now there are
1495-
cloudwatch metrics emitted by all SageMaker training jobs. This
1496-
will be ignored for now and removed in a further release.
14971491
container_log_level (int): Log level to use within the container
14981492
(default: logging.INFO). Valid values are defined in the Python
14991493
logging module.
@@ -1624,12 +1618,6 @@ def __init__(
16241618
self.dependencies = dependencies or []
16251619
self.uploaded_code = None
16261620

1627-
if enable_cloudwatch_metrics:
1628-
warnings.warn(
1629-
"enable_cloudwatch_metrics is now deprecated and will be removed in the future.",
1630-
DeprecationWarning,
1631-
)
1632-
self.enable_cloudwatch_metrics = False
16331621
self.container_log_level = container_log_level
16341622
self.code_location = code_location
16351623
self.image_uri = image_uri
@@ -1687,7 +1675,6 @@ def _prepare_for_training(self, job_name=None):
16871675
# Modify hyperparameters in-place to point to the right code directory and script URIs
16881676
self._hyperparameters[DIR_PARAM_NAME] = code_dir
16891677
self._hyperparameters[SCRIPT_PARAM_NAME] = script
1690-
self._hyperparameters[CLOUDWATCH_METRICS_PARAM_NAME] = self.enable_cloudwatch_metrics
16911678
self._hyperparameters[CONTAINER_LOG_LEVEL_PARAM_NAME] = self.container_log_level
16921679
self._hyperparameters[JOB_NAME_PARAM_NAME] = self._current_job_name
16931680
self._hyperparameters[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.boto_region_name
@@ -1798,9 +1785,6 @@ class constructor
17981785
init_params["hyperparameters"].get(SCRIPT_PARAM_NAME)
17991786
)
18001787
init_params["source_dir"] = json.loads(init_params["hyperparameters"].get(DIR_PARAM_NAME))
1801-
init_params["enable_cloudwatch_metrics"] = json.loads(
1802-
init_params["hyperparameters"].get(CLOUDWATCH_METRICS_PARAM_NAME)
1803-
)
18041788
init_params["container_log_level"] = json.loads(
18051789
init_params["hyperparameters"].get(CONTAINER_LOG_LEVEL_PARAM_NAME)
18061790
)

src/sagemaker/model.py

-7
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,6 @@ def delete_model(self):
599599

600600
SCRIPT_PARAM_NAME = "sagemaker_program"
601601
DIR_PARAM_NAME = "sagemaker_submit_directory"
602-
CLOUDWATCH_METRICS_PARAM_NAME = "sagemaker_enable_cloudwatch_metrics"
603602
CONTAINER_LOG_LEVEL_PARAM_NAME = "sagemaker_container_log_level"
604603
JOB_NAME_PARAM_NAME = "sagemaker_job_name"
605604
MODEL_SERVER_WORKERS_PARAM_NAME = "sagemaker_model_server_workers"
@@ -624,7 +623,6 @@ def __init__(
624623
predictor_cls=None,
625624
env=None,
626625
name=None,
627-
enable_cloudwatch_metrics=False,
628626
container_log_level=logging.INFO,
629627
code_location=None,
630628
sagemaker_session=None,
@@ -682,9 +680,6 @@ def __init__(
682680
when hosted in SageMaker (default: None).
683681
name (str): The model name. If None, a default model name will be
684682
selected on each ``deploy``.
685-
enable_cloudwatch_metrics (bool): Whether training and hosting
686-
containers will generate CloudWatch metrics under the
687-
AWS/SageMakerContainer namespace (default: False).
688683
container_log_level (int): Log level to use within the container
689684
(default: logging.INFO). Valid values are defined in the Python
690685
logging module.
@@ -792,7 +787,6 @@ def __init__(
792787
self.source_dir = source_dir
793788
self.dependencies = dependencies or []
794789
self.git_config = git_config
795-
self.enable_cloudwatch_metrics = enable_cloudwatch_metrics
796790
self.container_log_level = container_log_level
797791
if code_location:
798792
self.bucket, self.key_prefix = fw_utils.parse_s3_url(code_location)
@@ -890,7 +884,6 @@ def _framework_env_vars(self):
890884
return {
891885
SCRIPT_PARAM_NAME.upper(): script_name,
892886
DIR_PARAM_NAME.upper(): dir_name,
893-
CLOUDWATCH_METRICS_PARAM_NAME.upper(): str(self.enable_cloudwatch_metrics).lower(),
894887
CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): str(self.container_log_level),
895888
SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_region_name,
896889
}

src/sagemaker/mxnet/estimator.py

-1
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,6 @@ def create_model(
225225
framework_version=self.framework_version,
226226
py_version=self.py_version,
227227
source_dir=(source_dir or self._model_source_dir()),
228-
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
229228
container_log_level=self.container_log_level,
230229
code_location=self.code_location,
231230
model_server_workers=model_server_workers,

src/sagemaker/predictor.py

-93
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@
1313
"""Placeholder docstring"""
1414
from __future__ import print_function, absolute_import
1515

16-
from sagemaker.deserializers import BaseDeserializer
1716
from sagemaker.model_monitor import DataCaptureConfig
18-
from sagemaker.serializers import BaseSerializer
1917
from sagemaker.session import production_variant, Session
2018
from sagemaker.utils import name_from_base
2119

@@ -64,11 +62,6 @@ def __init__(
6462
accept (str): The invocation's "Accept", overriding any accept from
6563
the deserializer (default: None).
6664
"""
67-
if serializer is not None and not isinstance(serializer, BaseSerializer):
68-
serializer = LegacySerializer(serializer)
69-
if deserializer is not None and not isinstance(deserializer, BaseDeserializer):
70-
deserializer = LegacyDeserializer(deserializer)
71-
7265
self.endpoint_name = endpoint_name
7366
self.sagemaker_session = sagemaker_session or Session()
7467
self.serializer = serializer
@@ -115,8 +108,6 @@ def _handle_response(self, response):
115108
"""
116109
response_body = response["Body"]
117110
if self.deserializer is not None:
118-
if not isinstance(self.deserializer, BaseDeserializer):
119-
self.deserializer = LegacyDeserializer(self.deserializer)
120111
# It's the deserializer's responsibility to close the stream
121112
return self.deserializer.deserialize(response_body, response["ContentType"])
122113
data = response_body.read()
@@ -149,8 +140,6 @@ def _create_request_args(self, data, initial_args=None, target_model=None, targe
149140
args["TargetVariant"] = target_variant
150141

151142
if self.serializer is not None:
152-
if not isinstance(self.serializer, BaseSerializer):
153-
self.serializer = LegacySerializer(self.serializer)
154143
data = self.serializer.serialize(data)
155144

156145
args["Body"] = data
@@ -403,85 +392,3 @@ def _get_model_names(self):
403392
)
404393
production_variants = endpoint_config["ProductionVariants"]
405394
return [d["ModelName"] for d in production_variants]
406-
407-
408-
class LegacySerializer(BaseSerializer):
409-
"""Wrapper that makes legacy serializers forward compatibile."""
410-
411-
def __init__(self, serializer):
412-
"""Initialize a ``LegacySerializer``.
413-
414-
Args:
415-
serializer (callable): A legacy serializer.
416-
"""
417-
self.serializer = serializer
418-
self.content_type = getattr(serializer, "content_type", None)
419-
420-
def __call__(self, *args, **kwargs):
421-
"""Wraps the call method of the legacy serializer.
422-
423-
Args:
424-
data (object): Data to be serialized.
425-
426-
Returns:
427-
object: Serialized data used for a request.
428-
"""
429-
return self.serializer(*args, **kwargs)
430-
431-
def serialize(self, data):
432-
"""Wraps the call method of the legacy serializer.
433-
434-
Args:
435-
data (object): Data to be serialized.
436-
437-
Returns:
438-
object: Serialized data used for a request.
439-
"""
440-
return self.serializer(data)
441-
442-
@property
443-
def CONTENT_TYPE(self):
444-
"""The MIME type of the data sent to the inference endpoint."""
445-
return self.content_type
446-
447-
448-
class LegacyDeserializer(BaseDeserializer):
449-
"""Wrapper that makes legacy deserializers forward compatibile."""
450-
451-
def __init__(self, deserializer):
452-
"""Initialize a ``LegacyDeserializer``.
453-
454-
Args:
455-
deserializer (callable): A legacy deserializer.
456-
"""
457-
self.deserializer = deserializer
458-
self.accept = getattr(deserializer, "accept", None)
459-
460-
def __call__(self, *args, **kwargs):
461-
"""Wraps the call method of the legacy deserializer.
462-
463-
Args:
464-
data (object): Data to be deserialized.
465-
content_type (str): The MIME type of the data.
466-
467-
Returns:
468-
object: The data deserialized into an object.
469-
"""
470-
return self.deserializer(*args, **kwargs)
471-
472-
def deserialize(self, data, content_type):
473-
"""Wraps the call method of the legacy deserializer.
474-
475-
Args:
476-
data (object): Data to be deserialized.
477-
content_type (str): The MIME type of the data.
478-
479-
Returns:
480-
object: The data deserialized into an object.
481-
"""
482-
return self.deserializer(data, content_type)
483-
484-
@property
485-
def ACCEPT(self):
486-
"""The content type that is expected from the inference endpoint."""
487-
return self.accept

src/sagemaker/pytorch/estimator.py

-1
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,6 @@ def create_model(
179179
framework_version=self.framework_version,
180180
py_version=self.py_version,
181181
source_dir=(source_dir or self._model_source_dir()),
182-
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
183182
container_log_level=self.container_log_level,
184183
code_location=self.code_location,
185184
model_server_workers=model_server_workers,

src/sagemaker/rl/estimator.py

-1
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,6 @@ def create_model(
244244
source_dir=source_dir,
245245
code_location=self.code_location,
246246
dependencies=dependencies,
247-
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
248247
)
249248
extended_args.update(base_args)
250249

src/sagemaker/serializers.py

+29
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@
2020

2121
import numpy as np
2222

23+
from sagemaker.utils import DeferredError
24+
25+
try:
26+
import scipy
27+
except ImportError as e:
28+
scipy = DeferredError(e)
29+
2330

2431
class BaseSerializer(abc.ABC):
2532
"""Abstract base class for creation of new serializers.
@@ -183,3 +190,25 @@ def serialize(self, data):
183190
return json.dumps(data.tolist())
184191

185192
return json.dumps(data)
193+
194+
195+
class SparseMatrixSerializer(BaseSerializer):
196+
"""Serialize a sparse matrix to a buffer using the .npz format."""
197+
198+
CONTENT_TYPE = "application/x-npz"
199+
200+
def serialize(self, data):
201+
"""Serialize a sparse matrix to a buffer using the .npz format.
202+
203+
Sparse matrices can be in the ``csc``, ``csr``, ``bsr``, ``dia`` or
204+
``coo`` formats.
205+
206+
Args:
207+
data (scipy.sparse.spmatrix): The sparse matrix to serialize.
208+
209+
Returns:
210+
io.BytesIO: A buffer containing the serialized sparse matrix.
211+
"""
212+
buffer = io.BytesIO()
213+
scipy.sparse.save_npz(buffer, data)
214+
return buffer.getvalue()

src/sagemaker/sklearn/estimator.py

-1
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,6 @@ def create_model(
201201
role,
202202
entry_point or self._model_entry_point(),
203203
source_dir=(source_dir or self._model_source_dir()),
204-
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
205204
container_log_level=self.container_log_level,
206205
code_location=self.code_location,
207206
py_version=self.py_version,

src/sagemaker/workflow/airflow.py

-3
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,6 @@ def prepare_framework(estimator, s3_operations):
5959
]
6060
estimator._hyperparameters[sagemaker.model.DIR_PARAM_NAME] = code_dir
6161
estimator._hyperparameters[sagemaker.model.SCRIPT_PARAM_NAME] = script
62-
estimator._hyperparameters[
63-
sagemaker.model.CLOUDWATCH_METRICS_PARAM_NAME
64-
] = estimator.enable_cloudwatch_metrics
6562
estimator._hyperparameters[
6663
sagemaker.model.CONTAINER_LOG_LEVEL_PARAM_NAME
6764
] = estimator.container_log_level

src/sagemaker/xgboost/estimator.py

-1
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,6 @@ def create_model(
163163
entry_point or self._model_entry_point(),
164164
framework_version=self.framework_version,
165165
source_dir=(source_dir or self._model_source_dir()),
166-
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
167166
container_log_level=self.container_log_level,
168167
code_location=self.code_location,
169168
py_version=self.py_version,

tests/component/test_mxnet_estimator.py

-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ def test_deploy(sagemaker_session, tf_version):
7575
ROLE,
7676
{
7777
"Environment": {
78-
"SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false",
7978
"SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
8079
"SAGEMAKER_SUBMIT_DIRECTORY": SOURCE_DIR,
8180
"SAGEMAKER_REGION": REGION,

tests/component/test_tf_estimator.py

-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ def test_deploy(sagemaker_session, tf_version):
7474
ROLE,
7575
{
7676
"Environment": {
77-
"SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false",
7877
"SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
7978
"SAGEMAKER_SUBMIT_DIRECTORY": SOURCE_DIR,
8079
"SAGEMAKER_REQUIREMENTS": "",

0 commit comments

Comments
 (0)