diff --git a/src/sagemaker/serializers.py b/src/sagemaker/serializers.py index 727912a33c..a5cc51239c 100644 --- a/src/sagemaker/serializers.py +++ b/src/sagemaker/serializers.py @@ -18,7 +18,6 @@ import csv import io import json - import numpy as np from six import with_metaclass @@ -357,3 +356,38 @@ def serialize(self, data): return data.read() raise ValueError("Unable to handle input format: %s" % type(data)) + + +class DataSerializer(SimpleBaseSerializer): + """Serialize data in any file by extracting raw bytes from the file.""" + + def __init__(self, content_type="file-path/raw-bytes"): + """Initialize a ``DataSerializer`` instance. + + Args: + content_type (str): The MIME type to signal to the inference endpoint when sending + request data (default: "file-path/raw-bytes"). + """ + super(DataSerializer, self).__init__(content_type=content_type) + + def serialize(self, data): + """Serialize file data to a raw bytes. + + Args: + data (object): Data to be serialized. The data can be a string + representing file-path or the raw bytes from a file. + Returns: + raw-bytes: The data serialized as raw-bytes from the input. + """ + if isinstance(data, str): + try: + dataFile = open(data, "rb") + dataFileInfo = dataFile.read() + dataFile.close() + return dataFileInfo + except Exception as e: + raise ValueError(f"Could not open/read file: {data}. {e}") + if isinstance(data, bytes): + return data + + raise ValueError(f"Object of type {type(data)} is not Data serializable.") diff --git a/tests/data/cuteCat.raw b/tests/data/cuteCat.raw new file mode 100644 index 0000000000..6e89b9d78f Binary files /dev/null and b/tests/data/cuteCat.raw differ diff --git a/tests/unit/sagemaker/test_serializers.py b/tests/unit/sagemaker/test_serializers.py index d2e4b7ce46..6b70c600ca 100644 --- a/tests/unit/sagemaker/test_serializers.py +++ b/tests/unit/sagemaker/test_serializers.py @@ -28,6 +28,7 @@ SparseMatrixSerializer, JSONLinesSerializer, LibSVMSerializer, + DataSerializer, ) from tests.unit import DATA_DIR @@ -331,3 +332,26 @@ def test_libsvm_serializer_file_like(libsvm_serializer): libsvm_file.seek(0) result = libsvm_serializer.serialize(libsvm_file) assert result == validation_data + + +@pytest.fixture +def data_serializer(): + return DataSerializer() + + +def test_data_serializer_raw(data_serializer): + input_image_file_path = os.path.join(DATA_DIR, "", "cuteCat.jpg") + with open(input_image_file_path, "rb") as image: + input_image = image.read() + input_image_data = data_serializer.serialize(input_image) + validation_image_file_path = os.path.join(DATA_DIR, "", "cuteCat.raw") + validation_image_data = open(validation_image_file_path, "rb").read() + assert input_image_data == validation_image_data + + +def test_data_serializer_file_like(data_serializer): + input_image_file_path = os.path.join(DATA_DIR, "", "cuteCat.jpg") + validation_image_file_path = os.path.join(DATA_DIR, "", "cuteCat.raw") + input_image_data = data_serializer.serialize(input_image_file_path) + validation_image_data = open(validation_image_file_path, "rb").read() + assert input_image_data == validation_image_data