Skip to content

Commit b5525a5

Browse files
author
Balaji Veeramani
committed
Merge branch 'add-csv-deserializer' of github.com:bveeramani/sagemaker-python-sdk into add-csv-deserializer
2 parents d62cc8f + 450e084 commit b5525a5

File tree

10 files changed

+32
-33
lines changed

10 files changed

+32
-33
lines changed

src/sagemaker/amazon/common.py

+13-14
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import numpy as np
2222

2323
from sagemaker.amazon.record_pb2 import Record
24+
from sagemaker.deserializers import BaseDeserializer
2425
from sagemaker.utils import DeferredError
2526

2627

@@ -48,26 +49,24 @@ def __call__(self, array):
4849
return buf
4950

5051

51-
class record_deserializer(object):
52-
"""Placeholder docstring"""
52+
class RecordDeserializer(BaseDeserializer):
53+
"""Deserialize RecordIO Protobuf data from an inference endpoint."""
5354

54-
def __init__(self, accept="application/x-recordio-protobuf"):
55-
"""
56-
Args:
57-
accept:
58-
"""
59-
self.accept = accept
55+
ACCEPT = "application/x-recordio-protobuf"
56+
57+
def deserialize(self, data, content_type):
58+
"""Deserialize RecordIO Protobuf data from an inference endpoint.
6059
61-
def __call__(self, stream, content_type):
62-
"""
6360
Args:
64-
stream:
65-
content_type:
61+
data (object): The protobuf message to deserialize.
62+
content_type (str): The MIME type of the data.
63+
Returns:
64+
list: A list of records.
6665
"""
6766
try:
68-
return read_records(stream)
67+
return read_records(data)
6968
finally:
70-
stream.close()
69+
data.close()
7170

7271

7372
def _write_feature_tensor(resolved_type, record, vector):

src/sagemaker/amazon/factorization_machines.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer
17+
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import gt, isin, ge
2020
from sagemaker.predictor import Predictor
@@ -290,7 +290,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
290290
endpoint_name,
291291
sagemaker_session,
292292
serializer=numpy_to_record_serializer(),
293-
deserializer=record_deserializer(),
293+
deserializer=RecordDeserializer(),
294294
)
295295

296296

src/sagemaker/amazon/kmeans.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer
17+
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import gt, isin, ge, le
2020
from sagemaker.predictor import Predictor
@@ -223,7 +223,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
223223
endpoint_name,
224224
sagemaker_session,
225225
serializer=numpy_to_record_serializer(),
226-
deserializer=record_deserializer(),
226+
deserializer=RecordDeserializer(),
227227
)
228228

229229

src/sagemaker/amazon/knn.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer
17+
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import ge, isin
2020
from sagemaker.predictor import Predictor
@@ -211,7 +211,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
211211
endpoint_name,
212212
sagemaker_session,
213213
serializer=numpy_to_record_serializer(),
214-
deserializer=record_deserializer(),
214+
deserializer=RecordDeserializer(),
215215
)
216216

217217

src/sagemaker/amazon/lda.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer
17+
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import gt
2020
from sagemaker.predictor import Predictor
@@ -195,7 +195,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
195195
endpoint_name,
196196
sagemaker_session,
197197
serializer=numpy_to_record_serializer(),
198-
deserializer=record_deserializer(),
198+
deserializer=RecordDeserializer(),
199199
)
200200

201201

src/sagemaker/amazon/linear_learner.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer
17+
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import isin, gt, lt, ge, le
2020
from sagemaker.predictor import Predictor
@@ -454,7 +454,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
454454
endpoint_name,
455455
sagemaker_session,
456456
serializer=numpy_to_record_serializer(),
457-
deserializer=record_deserializer(),
457+
deserializer=RecordDeserializer(),
458458
)
459459

460460

src/sagemaker/amazon/ntm.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer
17+
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import ge, le, isin
2020
from sagemaker.predictor import Predictor
@@ -225,7 +225,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
225225
endpoint_name,
226226
sagemaker_session,
227227
serializer=numpy_to_record_serializer(),
228-
deserializer=record_deserializer(),
228+
deserializer=RecordDeserializer(),
229229
)
230230

231231

src/sagemaker/amazon/pca.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer
17+
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import gt, isin
2020
from sagemaker.predictor import Predictor
@@ -207,7 +207,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
207207
endpoint_name,
208208
sagemaker_session,
209209
serializer=numpy_to_record_serializer(),
210-
deserializer=record_deserializer(),
210+
deserializer=RecordDeserializer(),
211211
)
212212

213213

src/sagemaker/amazon/randomcutforest.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer
17+
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import ge, le
2020
from sagemaker.predictor import Predictor
@@ -184,7 +184,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
184184
endpoint_name,
185185
sagemaker_session,
186186
serializer=numpy_to_record_serializer(),
187-
deserializer=record_deserializer(),
187+
deserializer=RecordDeserializer(),
188188
)
189189

190190

tests/unit/test_common.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import itertools
1919
from scipy.sparse import coo_matrix
2020
from sagemaker.amazon.common import (
21-
record_deserializer,
21+
RecordDeserializer,
2222
write_numpy_to_dense_tensor,
2323
read_recordio,
2424
numpy_to_record_serializer,
@@ -51,8 +51,8 @@ def test_deserializer():
5151
array_data = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]
5252
s = numpy_to_record_serializer()
5353
buf = s(np.array(array_data))
54-
d = record_deserializer()
55-
for record, expected in zip(d(buf, "who cares"), array_data):
54+
d = RecordDeserializer()
55+
for record, expected in zip(d.deserialize(buf, "who cares"), array_data):
5656
assert record.features["values"].float64_tensor.values == expected
5757

5858

0 commit comments

Comments
 (0)