Skip to content

Commit cc2d047

Browse files
authored
breaking: Move _CsvDeserializer to sagemaker.deserializers and rename to CSVDeserializer (#1682)
1 parent b837dc2 commit cc2d047

File tree

5 files changed

+62
-44
lines changed

5 files changed

+62
-44
lines changed

src/sagemaker/deserializers.py

+33
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""Implements methods for deserializing data returned from an inference endpoint."""
1414
from __future__ import absolute_import
1515

16+
import csv
17+
1618
import abc
1719
import codecs
1820
import io
@@ -96,6 +98,37 @@ def deserialize(self, data, content_type):
9698
data.close()
9799

98100

101+
class CSVDeserializer(BaseDeserializer):
102+
"""Deserialize a stream of bytes into a list of lists."""
103+
104+
ACCEPT = "text/csv"
105+
106+
def __init__(self, encoding="utf-8"):
107+
"""Initialize the string encoding.
108+
109+
Args:
110+
encoding (str): The string encoding to use (default: "utf-8").
111+
"""
112+
self.encoding = encoding
113+
114+
def deserialize(self, data, content_type):
115+
"""Deserialize data from an inference endpoint into a list of lists.
116+
117+
Args:
118+
data (botocore.response.StreamingBody): Data to be deserialized.
119+
content_type (str): The MIME type of the data.
120+
121+
Returns:
122+
list: The data deserialized into a list of lists representing the
123+
contents of a CSV file.
124+
"""
125+
try:
126+
decoded_string = data.read().decode(self.encoding)
127+
return list(csv.reader(decoded_string.splitlines()))
128+
finally:
129+
data.close()
130+
131+
99132
class StreamDeserializer(BaseDeserializer):
100133
"""Returns the data and content-type received from an inference endpoint.
101134

src/sagemaker/predictor.py

-26
Original file line numberDiff line numberDiff line change
@@ -597,32 +597,6 @@ def _row_to_csv(obj):
597597
return ",".join(obj)
598598

599599

600-
class _CsvDeserializer(object):
601-
"""Placeholder docstring"""
602-
603-
def __init__(self, encoding="utf-8"):
604-
"""
605-
Args:
606-
encoding:
607-
"""
608-
self.accept = CONTENT_TYPE_CSV
609-
self.encoding = encoding
610-
611-
def __call__(self, stream, content_type):
612-
"""
613-
Args:
614-
stream:
615-
content_type:
616-
"""
617-
try:
618-
return list(csv.reader(stream.read().decode(self.encoding).splitlines()))
619-
finally:
620-
stream.close()
621-
622-
623-
csv_deserializer = _CsvDeserializer()
624-
625-
626600
class _JsonSerializer(object):
627601
"""Placeholder docstring"""
628602

src/sagemaker/xgboost/model.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616
import logging
1717

1818
import sagemaker
19+
from sagemaker.deserializers import CSVDeserializer
1920
from sagemaker.fw_utils import model_code_key_prefix
2021
from sagemaker.fw_registry import default_framework_uri
2122
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
22-
from sagemaker.predictor import Predictor, npy_serializer, csv_deserializer
23+
from sagemaker.predictor import Predictor, npy_serializer
2324
from sagemaker.xgboost.defaults import XGBOOST_NAME
2425

2526
logger = logging.getLogger("sagemaker")
@@ -42,7 +43,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
4243
chain.
4344
"""
4445
super(XGBoostPredictor, self).__init__(
45-
endpoint_name, sagemaker_session, npy_serializer, csv_deserializer
46+
endpoint_name, sagemaker_session, npy_serializer, CSVDeserializer()
4647
)
4748

4849

tests/unit/sagemaker/test_deserializers.py

+26
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from sagemaker.deserializers import (
2121
StringDeserializer,
2222
BytesDeserializer,
23+
CSVDeserializer,
2324
StreamDeserializer,
2425
NumpyDeserializer,
2526
)
@@ -41,6 +42,31 @@ def test_bytes_deserializer():
4142
assert result == b"[1, 2, 3]"
4243

4344

45+
@pytest.fixture
46+
def csv_deserializer():
47+
return CSVDeserializer()
48+
49+
50+
def test_csv_deserializer_single_element(csv_deserializer):
51+
result = csv_deserializer.deserialize(io.BytesIO(b"1"), "text/csv")
52+
assert result == [["1"]]
53+
54+
55+
def test_csv_deserializer_array(csv_deserializer):
56+
result = csv_deserializer.deserialize(io.BytesIO(b"1,2,3"), "text/csv")
57+
assert result == [["1", "2", "3"]]
58+
59+
60+
def test_csv_deserializer_2dimensional(csv_deserializer):
61+
result = csv_deserializer.deserialize(io.BytesIO(b"1,2,3\n3,4,5"), "text/csv")
62+
assert result == [["1", "2", "3"], ["3", "4", "5"]]
63+
64+
65+
def test_csv_deserializer_posix_compliant(csv_deserializer):
66+
result = csv_deserializer.deserialize(io.BytesIO(b"1,2,3\n3,4,5\n"), "text/csv")
67+
assert result == [["1", "2", "3"], ["3", "4", "5"]]
68+
69+
4470
def test_stream_deserializer():
4571
deserializer = StreamDeserializer()
4672

tests/unit/test_predictor.py

-16
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
json_serializer,
2626
json_deserializer,
2727
csv_serializer,
28-
csv_deserializer,
2928
npy_serializer,
3029
)
3130
from tests.unit import DATA_DIR
@@ -146,21 +145,6 @@ def test_csv_serializer_csv_reader():
146145
assert result == validation_data
147146

148147

149-
def test_csv_deserializer_single_element():
150-
result = csv_deserializer(io.BytesIO(b"1"), "text/csv")
151-
assert result == [["1"]]
152-
153-
154-
def test_csv_deserializer_array():
155-
result = csv_deserializer(io.BytesIO(b"1,2,3"), "text/csv")
156-
assert result == [["1", "2", "3"]]
157-
158-
159-
def test_csv_deserializer_2dimensional():
160-
result = csv_deserializer(io.BytesIO(b"1,2,3\n3,4,5"), "text/csv")
161-
assert result == [["1", "2", "3"], ["3", "4", "5"]]
162-
163-
164148
def test_json_deserializer_array():
165149
result = json_deserializer(io.BytesIO(b"[1, 2, 3]"), "application/json")
166150

0 commit comments

Comments
 (0)