Skip to content

Commit 62263b7

Browse files
authored
Merge branch 'zwei' into add-csv-deserializer
2 parents 6aabe77 + b837dc2 commit 62263b7

File tree

10 files changed

+47
-41
lines changed

10 files changed

+47
-41
lines changed

src/sagemaker/amazon/common.py

+24 -18
Original file line number | Diff line number | Diff line change
@@ -22,31 +22,37 @@
2222

2323
from sagemaker.amazon.record_pb2 import Record
2424
from sagemaker.deserializers import BaseDeserializer
25+
from sagemaker.serializers import BaseSerializer
2526
from sagemaker.utils import DeferredError
2627

2728

28-
class numpy_to_record_serializer(object):
29-
"""Placeholder docstring"""
29+
class RecordSerializer(BaseSerializer):
30+
"""Serialize a NumPy array for an inference request."""
3031

31-
def __init__(self, content_type="application/x-recordio-protobuf"):
32-
"""
33-
Args:
34-
content_type:
35-
"""
36-
self.content_type = content_type
32+
CONTENT_TYPE = "application/x-recordio-protobuf"
33+
34+
def serialize(self, data):
35+
"""Serialize a NumPy array into a buffer containing RecordIO records.
3736
38-
def __call__(self, array):
39-
"""
4037
Args:
41-
array:
38+
data (numpy.ndarray): The data to serialize.
39+
40+
Returns:
41+
io.BytesIO: A buffer containing the data serialized as records.
4242
"""
43-
if len(array.shape) == 1:
44-
array = array.reshape(1, array.shape[0])
45-
assert len(array.shape) == 2, "Expecting a 1 or 2 dimensional array"
46-
buf = io.BytesIO()
47-
write_numpy_to_dense_tensor(buf, array)
48-
buf.seek(0)
49-
return buf
43+
if len(data.shape) == 1:
44+
data = data.reshape(1, data.shape[0])
45+
46+
if len(data.shape) != 2:
47+
raise ValueError(
48+
"Expected a 1D or 2D array, but got a %dD array instead." % len(data.shape)
49+
)
50+
51+
buffer = io.BytesIO()
52+
write_numpy_to_dense_tensor(buffer, data)
53+
buffer.seek(0)
54+
55+
return buffer
5056

5157

5258
class RecordDeserializer(BaseDeserializer):

src/sagemaker/amazon/factorization_machines.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
17+
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import gt, isin, ge
2020
from sagemaker.predictor import Predictor
@@ -289,7 +289,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
289289
super(FactorizationMachinesPredictor, self).__init__(
290290
endpoint_name,
291291
sagemaker_session,
292-
serializer=numpy_to_record_serializer(),
292+
serializer=RecordSerializer(),
293293
deserializer=RecordDeserializer(),
294294
)
295295

src/sagemaker/amazon/kmeans.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
17+
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import gt, isin, ge, le
2020
from sagemaker.predictor import Predictor
@@ -222,7 +222,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
222222
super(KMeansPredictor, self).__init__(
223223
endpoint_name,
224224
sagemaker_session,
225-
serializer=numpy_to_record_serializer(),
225+
serializer=RecordSerializer(),
226226
deserializer=RecordDeserializer(),
227227
)
228228

src/sagemaker/amazon/knn.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
17+
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import ge, isin
2020
from sagemaker.predictor import Predictor
@@ -210,7 +210,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
210210
super(KNNPredictor, self).__init__(
211211
endpoint_name,
212212
sagemaker_session,
213-
serializer=numpy_to_record_serializer(),
213+
serializer=RecordSerializer(),
214214
deserializer=RecordDeserializer(),
215215
)
216216

src/sagemaker/amazon/lda.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
17+
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import gt
2020
from sagemaker.predictor import Predictor
@@ -194,7 +194,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
194194
super(LDAPredictor, self).__init__(
195195
endpoint_name,
196196
sagemaker_session,
197-
serializer=numpy_to_record_serializer(),
197+
serializer=RecordSerializer(),
198198
deserializer=RecordDeserializer(),
199199
)
200200

src/sagemaker/amazon/linear_learner.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
17+
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import isin, gt, lt, ge, le
2020
from sagemaker.predictor import Predictor
@@ -453,7 +453,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
453453
super(LinearLearnerPredictor, self).__init__(
454454
endpoint_name,
455455
sagemaker_session,
456-
serializer=numpy_to_record_serializer(),
456+
serializer=RecordSerializer(),
457457
deserializer=RecordDeserializer(),
458458
)
459459

src/sagemaker/amazon/ntm.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
17+
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import ge, le, isin
2020
from sagemaker.predictor import Predictor
@@ -224,7 +224,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
224224
super(NTMPredictor, self).__init__(
225225
endpoint_name,
226226
sagemaker_session,
227-
serializer=numpy_to_record_serializer(),
227+
serializer=RecordSerializer(),
228228
deserializer=RecordDeserializer(),
229229
)
230230

src/sagemaker/amazon/pca.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
17+
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import gt, isin
2020
from sagemaker.predictor import Predictor
@@ -206,7 +206,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
206206
super(PCAPredictor, self).__init__(
207207
endpoint_name,
208208
sagemaker_session,
209-
serializer=numpy_to_record_serializer(),
209+
serializer=RecordSerializer(),
210210
deserializer=RecordDeserializer(),
211211
)
212212

src/sagemaker/amazon/randomcutforest.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
17+
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import ge, le
2020
from sagemaker.predictor import Predictor
@@ -183,7 +183,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
183183
super(RandomCutForestPredictor, self).__init__(
184184
endpoint_name,
185185
sagemaker_session,
186-
serializer=numpy_to_record_serializer(),
186+
serializer=RecordSerializer(),
187187
deserializer=RecordDeserializer(),
188188
)
189189

tests/unit/test_common.py

+7 -7
Original file line number | Diff line number | Diff line change
@@ -21,26 +21,26 @@
2121
RecordDeserializer,
2222
write_numpy_to_dense_tensor,
2323
read_recordio,
24-
numpy_to_record_serializer,
24+
RecordSerializer,
2525
write_spmatrix_to_sparse_tensor,
2626
)
2727
from sagemaker.amazon.record_pb2 import Record
2828

2929

3030
def test_serializer():
31-
s = numpy_to_record_serializer()
31+
s = RecordSerializer()
3232
array_data = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]
33-
buf = s(np.array(array_data))
33+
buf = s.serialize(np.array(array_data))
3434
for record_data, expected in zip(read_recordio(buf), array_data):
3535
record = Record()
3636
record.ParseFromString(record_data)
3737
assert record.features["values"].float64_tensor.values == expected
3838

3939

4040
def test_serializer_accepts_one_dimensional_array():
41-
s = numpy_to_record_serializer()
41+
s = RecordSerializer()
4242
array_data = [1.0, 2.0, 3.0]
43-
buf = s(np.array(array_data))
43+
buf = s.serialize(np.array(array_data))
4444
record_data = next(read_recordio(buf))
4545
record = Record()
4646
record.ParseFromString(record_data)
@@ -49,8 +49,8 @@ def test_serializer_accepts_one_dimensional_array():
4949

5050
def test_deserializer():
5151
array_data = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]
52-
s = numpy_to_record_serializer()
53-
buf = s(np.array(array_data))
52+
s = RecordSerializer()
53+
buf = s.serialize(np.array(array_data))
5454
d = RecordDeserializer()
5555
for record, expected in zip(d.deserialize(buf, "who cares"), array_data):
5656
assert record.features["values"].float64_tensor.values == expected

0 commit comments

Comments
 (0)