Skip to content

Commit 01b005a

Browse files
committed
Merge branch 'zwei' into xgboost-uri
2 parents 8d6fbb2 + 9aa708e commit 01b005a

File tree

12 files changed

+153
-209
lines changed

12 files changed

+153
-209
lines changed

doc/frameworks/tensorflow/deploying_tensorflow_serving.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ your input data to CSV format:
240240
241241
# create a Predictor with JSON serialization
242242
243-
predictor = Predictor('endpoint-name', serializer=sagemaker.predictor.csv_serializer)
243+
predictor = Predictor('endpoint-name', serializer=sagemaker.serializers.CSVSerializer())
244244
245245
# CSV-formatted string input
246246
input = '1.0,2.0,5.0\n1.0,2.0,5.0\n1.0,2.0,5.0'
@@ -256,7 +256,7 @@ your input data to CSV format:
256256
]
257257
}
258258
259-
You can also use python arrays or numpy arrays as input and let the `csv_serializer` object
259+
You can also use python arrays or numpy arrays as input and let the ``CSVSerializer`` object
260260
convert them to CSV, but the client-side CSV conversion is more sophisticated than the
261261
CSV parsing on the Endpoint, so if you encounter conversion problems, try using one of the
262262
JSON options instead.

doc/frameworks/tensorflow/using_tf.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -710,7 +710,7 @@ your input data to CSV format:
710710
711711
# create a Predictor with JSON serialization
712712
713-
predictor = Predictor('endpoint-name', serializer=sagemaker.predictor.csv_serializer)
713+
predictor = Predictor('endpoint-name', serializer=sagemaker.serializers.CSVSerializer())
714714
715715
# CSV-formatted string input
716716
input = '1.0,2.0,5.0\n1.0,2.0,5.0\n1.0,2.0,5.0'
@@ -726,7 +726,7 @@ your input data to CSV format:
726726
]
727727
}
728728
729-
You can also use python arrays or numpy arrays as input and let the `csv_serializer` object
729+
You can also use python arrays or numpy arrays as input and let the ``CSVSerializer`` object
730730
convert them to CSV, but the client-side CSV conversion is more sophisticated than the
731731
CSV parsing on the Endpoint, so if you encounter conversion problems, try using one of the
732732
JSON options instead.

src/sagemaker/amazon/ipinsights.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1818
from sagemaker.amazon.validation import ge, le
1919
from sagemaker.deserializers import JSONDeserializer
20-
from sagemaker.predictor import Predictor, csv_serializer
20+
from sagemaker.predictor import Predictor
2121
from sagemaker.model import Model
22+
from sagemaker.serializers import CSVSerializer
2223
from sagemaker.session import Session
2324
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2425

@@ -198,7 +199,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
198199
super(IPInsightsPredictor, self).__init__(
199200
endpoint_name,
200201
sagemaker_session,
201-
serializer=csv_serializer,
202+
serializer=CSVSerializer(),
202203
deserializer=JSONDeserializer(),
203204
)
204205

src/sagemaker/predictor.py

-107
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,6 @@
1313
"""Placeholder docstring"""
1414
from __future__ import print_function, absolute_import
1515

16-
import csv
17-
from six import StringIO
18-
import numpy as np
19-
20-
from sagemaker.content_types import CONTENT_TYPE_CSV
2116
from sagemaker.deserializers import BaseDeserializer
2217
from sagemaker.model_monitor import DataCaptureConfig
2318
from sagemaker.serializers import BaseSerializer
@@ -490,105 +485,3 @@ def deserialize(self, data, content_type):
490485
def ACCEPT(self):
491486
"""The content type that is expected from the inference endpoint."""
492487
return self.accept
493-
494-
495-
class _CsvSerializer(object):
496-
"""Placeholder docstring"""
497-
498-
def __init__(self):
499-
"""Placeholder docstring"""
500-
self.content_type = CONTENT_TYPE_CSV
501-
502-
def __call__(self, data):
503-
"""Take data of various data formats and serialize them into CSV.
504-
505-
Args:
506-
data (object): Data to be serialized.
507-
508-
Returns:
509-
object: Sequence of bytes to be used for the request body.
510-
"""
511-
# For inputs which represent multiple "rows", the result should be newline-separated CSV
512-
# rows
513-
if _is_mutable_sequence_like(data) and len(data) > 0 and _is_sequence_like(data[0]):
514-
return "\n".join([_CsvSerializer._serialize_row(row) for row in data])
515-
return _CsvSerializer._serialize_row(data)
516-
517-
@staticmethod
518-
def _serialize_row(data):
519-
# Don't attempt to re-serialize a string
520-
"""
521-
Args:
522-
data:
523-
"""
524-
if isinstance(data, str):
525-
return data
526-
if isinstance(data, np.ndarray):
527-
data = np.ndarray.flatten(data)
528-
if hasattr(data, "__len__"):
529-
if len(data) == 0:
530-
raise ValueError("Cannot serialize empty array")
531-
return _csv_serialize_python_array(data)
532-
533-
# files and buffers
534-
if hasattr(data, "read"):
535-
return _csv_serialize_from_buffer(data)
536-
537-
raise ValueError("Unable to handle input format: ", type(data))
538-
539-
540-
def _csv_serialize_python_array(data):
541-
"""
542-
Args:
543-
data:
544-
"""
545-
return _csv_serialize_object(data)
546-
547-
548-
def _csv_serialize_from_buffer(buff):
549-
"""
550-
Args:
551-
buff:
552-
"""
553-
return buff.read()
554-
555-
556-
def _csv_serialize_object(data):
557-
"""
558-
Args:
559-
data:
560-
"""
561-
csv_buffer = StringIO()
562-
563-
csv_writer = csv.writer(csv_buffer, delimiter=",")
564-
csv_writer.writerow(data)
565-
return csv_buffer.getvalue().rstrip("\r\n")
566-
567-
568-
csv_serializer = _CsvSerializer()
569-
570-
571-
def _is_mutable_sequence_like(obj):
572-
"""
573-
Args:
574-
obj:
575-
"""
576-
return _is_sequence_like(obj) and hasattr(obj, "__setitem__")
577-
578-
579-
def _is_sequence_like(obj):
580-
"""
581-
Args:
582-
obj:
583-
"""
584-
return hasattr(obj, "__iter__") and hasattr(obj, "__getitem__")
585-
586-
587-
def _row_to_csv(obj):
588-
"""
589-
Args:
590-
obj:
591-
"""
592-
if isinstance(obj, str):
593-
return obj
594-
return ",".join(obj)

src/sagemaker/serializers.py

+57
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
import abc
17+
import csv
1718
import io
1819
import json
1920

@@ -44,6 +45,62 @@ def CONTENT_TYPE(self):
4445
"""The MIME type of the data sent to the inference endpoint."""
4546

4647

48+
class CSVSerializer(BaseSerializer):
49+
"""Searilize data of various formats to a CSV-formatted string."""
50+
51+
CONTENT_TYPE = "text/csv"
52+
53+
def serialize(self, data):
54+
"""Serialize data of various formats to a CSV-formatted string.
55+
56+
Args:
57+
data (object): Data to be serialized. Can be a NumPy array, list,
58+
file, or buffer.
59+
60+
Returns:
61+
str: The data serialized as a CSV-formatted string.
62+
"""
63+
if hasattr(data, "read"):
64+
return data.read()
65+
66+
is_mutable_sequence_like = self._is_sequence_like(data) and hasattr(data, "__setitem__")
67+
has_multiple_rows = len(data) > 0 and self._is_sequence_like(data[0])
68+
69+
if is_mutable_sequence_like and has_multiple_rows:
70+
return "\n".join([self._serialize_row(row) for row in data])
71+
72+
return self._serialize_row(data)
73+
74+
def _serialize_row(self, data):
75+
"""Serialize data as a CSV-formatted row.
76+
77+
Args:
78+
data (object): Data to be serialized in a row.
79+
80+
Returns:
81+
str: The data serialized as a CSV-formatted row.
82+
"""
83+
if isinstance(data, str):
84+
return data
85+
86+
if isinstance(data, np.ndarray):
87+
data = np.ndarray.flatten(data)
88+
89+
if hasattr(data, "__len__"):
90+
if len(data) == 0:
91+
raise ValueError("Cannot serialize empty array")
92+
csv_buffer = io.StringIO()
93+
csv_writer = csv.writer(csv_buffer, delimiter=",")
94+
csv_writer.writerow(data)
95+
return csv_buffer.getvalue().rstrip("\r\n")
96+
97+
raise ValueError("Unable to handle input format: ", type(data))
98+
99+
def _is_sequence_like(self, data):
100+
"""Returns true if obj is iterable and subscriptable."""
101+
return hasattr(data, "__iter__") and hasattr(data, "__getitem__")
102+
103+
47104
class NumpySerializer(BaseSerializer):
48105
"""Serialize data to a buffer using the .npy format."""
49106

src/sagemaker/sparkml/model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from sagemaker import Model, Predictor, Session
1717
from sagemaker.content_types import CONTENT_TYPE_CSV
1818
from sagemaker.fw_registry import registry
19-
from sagemaker.predictor import csv_serializer
19+
from sagemaker.serializers import CSVSerializer
2020

2121
framework_name = "sparkml-serving"
2222
repo_name = "sagemaker-sparkml-serving"
@@ -51,7 +51,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
5151
super(SparkMLPredictor, self).__init__(
5252
endpoint_name=endpoint_name,
5353
sagemaker_session=sagemaker_session,
54-
serializer=csv_serializer,
54+
serializer=CSVSerializer(),
5555
content_type=CONTENT_TYPE_CSV,
5656
)
5757

tests/integ/test_marketplace.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import sagemaker
2323
import tests.integ
2424
from sagemaker import AlgorithmEstimator, ModelPackage
25+
from sagemaker.serializers import CSVSerializer
2526
from sagemaker.tuner import IntegerParameter, HyperparameterTuner
2627
from sagemaker.utils import sagemaker_timestamp
2728
from sagemaker.utils import _aws_partition
@@ -136,10 +137,7 @@ def test_marketplace_attach(sagemaker_session, cpu_instance_type):
136137
training_job_name=training_job_name, sagemaker_session=sagemaker_session
137138
)
138139
predictor = estimator.deploy(
139-
1,
140-
cpu_instance_type,
141-
endpoint_name=endpoint_name,
142-
serializer=sagemaker.predictor.csv_serializer,
140+
1, cpu_instance_type, endpoint_name=endpoint_name, serializer=CSVSerializer()
143141
)
144142
shape = pandas.read_csv(os.path.join(data_path, "iris.csv"), header=None)
145143
a = [50 * i for i in range(3)]
@@ -165,7 +163,7 @@ def test_marketplace_model(sagemaker_session, cpu_instance_type):
165163
)
166164

167165
def predict_wrapper(endpoint, session):
168-
return sagemaker.Predictor(endpoint, session, serializer=sagemaker.predictor.csv_serializer)
166+
return sagemaker.Predictor(endpoint, session, serializer=CSVSerializer())
169167

170168
model = ModelPackage(
171169
role="SageMakerRole",

tests/integ/test_multi_variant_endpoint.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
from sagemaker.content_types import CONTENT_TYPE_CSV
2626
from sagemaker.utils import unique_name_from_base
2727
from sagemaker.amazon.amazon_estimator import get_image_uri
28-
from sagemaker.predictor import csv_serializer, Predictor
28+
from sagemaker.predictor import Predictor
29+
from sagemaker.serializers import CSVSerializer
2930

3031

3132
import tests.integ
@@ -169,7 +170,7 @@ def test_predict_invocation_with_target_variant(sagemaker_session, multi_variant
169170
predictor = Predictor(
170171
endpoint_name=multi_variant_endpoint.endpoint_name,
171172
sagemaker_session=sagemaker_session,
172-
serializer=csv_serializer,
173+
serializer=CSVSerializer(),
173174
content_type=CONTENT_TYPE_CSV,
174175
accept=CONTENT_TYPE_CSV,
175176
)
@@ -297,7 +298,7 @@ def test_predict_invocation_with_target_variant_local_mode(
297298
predictor = Predictor(
298299
endpoint_name=multi_variant_endpoint.endpoint_name,
299300
sagemaker_session=sagemaker_session,
300-
serializer=csv_serializer,
301+
serializer=CSVSerializer(),
301302
content_type=CONTENT_TYPE_CSV,
302303
accept=CONTENT_TYPE_CSV,
303304
)

tests/integ/test_tfs.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import tests.integ
2525
import tests.integ.timeout
2626
from sagemaker.tensorflow.model import TensorFlowModel, TensorFlowPredictor
27+
from sagemaker.serializers import CSVSerializer
2728

2829

2930
@pytest.fixture(scope="module")
@@ -236,9 +237,7 @@ def test_predict_csv(tfs_predictor):
236237
expected_result = {"predictions": [[3.5, 4.0, 5.5], [3.5, 4.0, 5.5]]}
237238

238239
predictor = TensorFlowPredictor(
239-
tfs_predictor.endpoint_name,
240-
tfs_predictor.sagemaker_session,
241-
serializer=sagemaker.predictor.csv_serializer,
240+
tfs_predictor.endpoint_name, tfs_predictor.sagemaker_session, serializer=CSVSerializer(),
242241
)
243242

244243
result = predictor.predict(input_data)

tests/unit/sagemaker/tensorflow/test_tfs.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import mock
2020
import pytest
2121
from mock import Mock, patch
22-
from sagemaker.predictor import csv_serializer
22+
from sagemaker.serializers import CSVSerializer
2323
from sagemaker.tensorflow import TensorFlow
2424
from sagemaker.tensorflow.model import TensorFlowModel, TensorFlowPredictor
2525

@@ -323,7 +323,7 @@ def test_predictor_jsons(sagemaker_session):
323323

324324

325325
def test_predictor_csv(sagemaker_session):
326-
predictor = TensorFlowPredictor("endpoint", sagemaker_session, serializer=csv_serializer)
326+
predictor = TensorFlowPredictor("endpoint", sagemaker_session, serializer=CSVSerializer())
327327

328328
mock_response(json.dumps(PREDICT_RESPONSE).encode("utf-8"), sagemaker_session)
329329
result = predictor.predict([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
@@ -398,14 +398,14 @@ def test_predictor_regress(sagemaker_session):
398398

399399

400400
def test_predictor_regress_bad_content_type(sagemaker_session):
401-
predictor = TensorFlowPredictor("endpoint", sagemaker_session, csv_serializer)
401+
predictor = TensorFlowPredictor("endpoint", sagemaker_session, CSVSerializer())
402402

403403
with pytest.raises(ValueError):
404404
predictor.regress(REGRESS_INPUT)
405405

406406

407407
def test_predictor_classify_bad_content_type(sagemaker_session):
408-
predictor = TensorFlowPredictor("endpoint", sagemaker_session, csv_serializer)
408+
predictor = TensorFlowPredictor("endpoint", sagemaker_session, CSVSerializer())
409409

410410
with pytest.raises(ValueError):
411411
predictor.classify(CLASSIFY_INPUT)

0 commit comments

Comments
 (0)