diff --git a/src/sagemaker/amazon/common.py b/src/sagemaker/amazon/common.py index 6e68012c90..d69d8dacb4 100644 --- a/src/sagemaker/amazon/common.py +++ b/src/sagemaker/amazon/common.py @@ -22,31 +22,37 @@ from sagemaker.amazon.record_pb2 import Record from sagemaker.deserializers import BaseDeserializer +from sagemaker.serializers import BaseSerializer from sagemaker.utils import DeferredError -class numpy_to_record_serializer(object): - """Placeholder docstring""" +class RecordSerializer(BaseSerializer): + """Serialize a NumPy array for an inference request.""" - def __init__(self, content_type="application/x-recordio-protobuf"): - """ - Args: - content_type: - """ - self.content_type = content_type + CONTENT_TYPE = "application/x-recordio-protobuf" + + def serialize(self, data): + """Serialize a NumPy array into a buffer containing RecordIO records. - def __call__(self, array): - """ Args: - array: + data (numpy.ndarray): The data to serialize. + + Returns: + io.BytesIO: A buffer containing the data serialized as records. """ - if len(array.shape) == 1: - array = array.reshape(1, array.shape[0]) - assert len(array.shape) == 2, "Expecting a 1 or 2 dimensional array" - buf = io.BytesIO() - write_numpy_to_dense_tensor(buf, array) - buf.seek(0) - return buf + if len(data.shape) == 1: + data = data.reshape(1, data.shape[0]) + + if len(data.shape) != 2: + raise ValueError( + "Expected a 1D or 2D array, but got a %dD array instead." % len(data.shape) + ) + + buffer = io.BytesIO() + write_numpy_to_dense_tensor(buffer, data) + buffer.seek(0) + + return buffer class RecordDeserializer(BaseDeserializer): diff --git a/src/sagemaker/amazon/factorization_machines.py b/src/sagemaker/amazon/factorization_machines.py index 971122a75b..a820ac2dc7 100644 --- a/src/sagemaker/amazon/factorization_machines.py +++ b/src/sagemaker/amazon/factorization_machines.py @@ -14,7 +14,7 @@ from __future__ import absolute_import from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry -from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer +from sagemaker.amazon.common import RecordSerializer, RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import gt, isin, ge from sagemaker.predictor import Predictor @@ -289,7 +289,7 @@ def __init__(self, endpoint_name, sagemaker_session=None): super(FactorizationMachinesPredictor, self).__init__( endpoint_name, sagemaker_session, - serializer=numpy_to_record_serializer(), + serializer=RecordSerializer(), deserializer=RecordDeserializer(), ) diff --git a/src/sagemaker/amazon/kmeans.py b/src/sagemaker/amazon/kmeans.py index a50b350610..231e9e8247 100644 --- a/src/sagemaker/amazon/kmeans.py +++ b/src/sagemaker/amazon/kmeans.py @@ -14,7 +14,7 @@ from __future__ import absolute_import from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry -from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer +from sagemaker.amazon.common import RecordSerializer, RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import gt, isin, ge, le from sagemaker.predictor import Predictor @@ -222,7 +222,7 @@ def __init__(self, endpoint_name, sagemaker_session=None): super(KMeansPredictor, self).__init__( endpoint_name, sagemaker_session, - serializer=numpy_to_record_serializer(), + serializer=RecordSerializer(), deserializer=RecordDeserializer(), ) diff --git a/src/sagemaker/amazon/knn.py b/src/sagemaker/amazon/knn.py index 89f8daef27..1113ca9843 100644 --- a/src/sagemaker/amazon/knn.py +++ b/src/sagemaker/amazon/knn.py @@ -14,7 +14,7 @@ from __future__ import absolute_import from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry -from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer +from sagemaker.amazon.common import RecordSerializer, RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import ge, isin from sagemaker.predictor import Predictor @@ -210,7 +210,7 @@ def __init__(self, endpoint_name, sagemaker_session=None): super(KNNPredictor, self).__init__( endpoint_name, sagemaker_session, - serializer=numpy_to_record_serializer(), + serializer=RecordSerializer(), deserializer=RecordDeserializer(), ) diff --git a/src/sagemaker/amazon/lda.py b/src/sagemaker/amazon/lda.py index b7d697a3ed..bf20e88e3a 100644 --- a/src/sagemaker/amazon/lda.py +++ b/src/sagemaker/amazon/lda.py @@ -14,7 +14,7 @@ from __future__ import absolute_import from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry -from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer +from sagemaker.amazon.common import RecordSerializer, RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import gt from sagemaker.predictor import Predictor @@ -194,7 +194,7 @@ def __init__(self, endpoint_name, sagemaker_session=None): super(LDAPredictor, self).__init__( endpoint_name, sagemaker_session, - serializer=numpy_to_record_serializer(), + serializer=RecordSerializer(), deserializer=RecordDeserializer(), ) diff --git a/src/sagemaker/amazon/linear_learner.py b/src/sagemaker/amazon/linear_learner.py index 5d756db6d5..382491a1df 100644 --- a/src/sagemaker/amazon/linear_learner.py +++ b/src/sagemaker/amazon/linear_learner.py @@ -14,7 +14,7 @@ from __future__ import absolute_import from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry -from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer +from sagemaker.amazon.common import RecordSerializer, RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import isin, gt, lt, ge, le from sagemaker.predictor import Predictor @@ -453,7 +453,7 @@ def __init__(self, endpoint_name, sagemaker_session=None): super(LinearLearnerPredictor, self).__init__( endpoint_name, sagemaker_session, - serializer=numpy_to_record_serializer(), + serializer=RecordSerializer(), deserializer=RecordDeserializer(), ) diff --git a/src/sagemaker/amazon/ntm.py b/src/sagemaker/amazon/ntm.py index 1b22fb1902..4584dabfff 100644 --- a/src/sagemaker/amazon/ntm.py +++ b/src/sagemaker/amazon/ntm.py @@ -14,7 +14,7 @@ from __future__ import absolute_import from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry -from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer +from sagemaker.amazon.common import RecordSerializer, RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import ge, le, isin from sagemaker.predictor import Predictor @@ -224,7 +224,7 @@ def __init__(self, endpoint_name, sagemaker_session=None): super(NTMPredictor, self).__init__( endpoint_name, sagemaker_session, - serializer=numpy_to_record_serializer(), + serializer=RecordSerializer(), deserializer=RecordDeserializer(), ) diff --git a/src/sagemaker/amazon/pca.py b/src/sagemaker/amazon/pca.py index 0ab64432be..5b3ad3c5a5 100644 --- a/src/sagemaker/amazon/pca.py +++ b/src/sagemaker/amazon/pca.py @@ -14,7 +14,7 @@ from __future__ import absolute_import from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry -from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer +from sagemaker.amazon.common import RecordSerializer, RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import gt, isin from sagemaker.predictor import Predictor @@ -206,7 +206,7 @@ def __init__(self, endpoint_name, sagemaker_session=None): super(PCAPredictor, self).__init__( endpoint_name, sagemaker_session, - serializer=numpy_to_record_serializer(), + serializer=RecordSerializer(), deserializer=RecordDeserializer(), ) diff --git a/src/sagemaker/amazon/randomcutforest.py b/src/sagemaker/amazon/randomcutforest.py index 9963b4ecc1..15c13c7441 100644 --- a/src/sagemaker/amazon/randomcutforest.py +++ b/src/sagemaker/amazon/randomcutforest.py @@ -14,7 +14,7 @@ from __future__ import absolute_import from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry -from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer +from sagemaker.amazon.common import RecordSerializer, RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import ge, le from sagemaker.predictor import Predictor @@ -183,7 +183,7 @@ def __init__(self, endpoint_name, sagemaker_session=None): super(RandomCutForestPredictor, self).__init__( endpoint_name, sagemaker_session, - serializer=numpy_to_record_serializer(), + serializer=RecordSerializer(), deserializer=RecordDeserializer(), ) diff --git a/tests/unit/test_common.py b/tests/unit/test_common.py index 516babbdc8..9031d15003 100644 --- a/tests/unit/test_common.py +++ b/tests/unit/test_common.py @@ -21,16 +21,16 @@ RecordDeserializer, write_numpy_to_dense_tensor, read_recordio, - numpy_to_record_serializer, + RecordSerializer, write_spmatrix_to_sparse_tensor, ) from sagemaker.amazon.record_pb2 import Record def test_serializer(): - s = numpy_to_record_serializer() + s = RecordSerializer() array_data = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]] - buf = s(np.array(array_data)) + buf = s.serialize(np.array(array_data)) for record_data, expected in zip(read_recordio(buf), array_data): record = Record() record.ParseFromString(record_data) @@ -38,9 +38,9 @@ def test_serializer(): def test_serializer_accepts_one_dimensional_array(): - s = numpy_to_record_serializer() + s = RecordSerializer() array_data = [1.0, 2.0, 3.0] - buf = s(np.array(array_data)) + buf = s.serialize(np.array(array_data)) record_data = next(read_recordio(buf)) record = Record() record.ParseFromString(record_data) @@ -49,8 +49,8 @@ def test_serializer_accepts_one_dimensional_array(): def test_deserializer(): array_data = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]] - s = numpy_to_record_serializer() - buf = s(np.array(array_data)) + s = RecordSerializer() + buf = s.serialize(np.array(array_data)) d = RecordDeserializer() for record, expected in zip(d.deserialize(buf, "who cares"), array_data): assert record.features["values"].float64_tensor.values == expected