Skip to content

Commit 5c64e6c

Browse files
authored
feature: Data Serializer (#2956)
1 parent 4ce6623 commit 5c64e6c

File tree

3 files changed

+59
-1
lines changed

3 files changed

+59
-1
lines changed

src/sagemaker/serializers.py

+35-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import csv
1919
import io
2020
import json
21-
2221
import numpy as np
2322
from six import with_metaclass
2423

@@ -357,3 +356,38 @@ def serialize(self, data):
357356
return data.read()
358357

359358
raise ValueError("Unable to handle input format: %s" % type(data))
359+
360+
361+
class DataSerializer(SimpleBaseSerializer):
    """Serialize data in any file by extracting raw bytes from the file."""

    def __init__(self, content_type="file-path/raw-bytes"):
        """Initialize a ``DataSerializer`` instance.

        Args:
            content_type (str): The MIME type to signal to the inference endpoint when sending
                request data (default: "file-path/raw-bytes").
        """
        super(DataSerializer, self).__init__(content_type=content_type)

    def serialize(self, data):
        """Serialize file data to raw bytes.

        Args:
            data (object): Data to be serialized. The data can be a string
                representing file-path or the raw bytes from a file.

        Returns:
            bytes: The contents of the file at ``data`` when ``data`` is a
                string, or ``data`` itself when it is already ``bytes``.

        Raises:
            ValueError: If the file cannot be opened or read, or if ``data``
                is neither a ``str`` nor ``bytes``.
        """
        if isinstance(data, str):
            try:
                # The context manager closes the handle even when read() fails,
                # which a manual open()/close() pair would not.
                with open(data, "rb") as data_file:
                    return data_file.read()
            except Exception as e:
                # Every failure is reported as ValueError so callers have one
                # exception type to handle; the cause is chained for debugging.
                raise ValueError(f"Could not open/read file: {data}. {e}") from e
        if isinstance(data, bytes):
            return data

        raise ValueError(f"Object of type {type(data)} is not Data serializable.")

tests/data/cuteCat.raw

6.43 KB
Binary file not shown.

tests/unit/sagemaker/test_serializers.py

+24
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
SparseMatrixSerializer,
2929
JSONLinesSerializer,
3030
LibSVMSerializer,
31+
DataSerializer,
3132
)
3233
from tests.unit import DATA_DIR
3334

@@ -331,3 +332,26 @@ def test_libsvm_serializer_file_like(libsvm_serializer):
331332
libsvm_file.seek(0)
332333
result = libsvm_serializer.serialize(libsvm_file)
333334
assert result == validation_data
335+
336+
337+
@pytest.fixture
def data_serializer():
    """Provide a ``DataSerializer`` constructed with its default arguments."""
    serializer = DataSerializer()
    return serializer
340+
341+
342+
def test_data_serializer_raw(data_serializer):
    """Serializing the bytes of ``cuteCat.jpg`` yields the contents of ``cuteCat.raw``."""
    input_image_file_path = os.path.join(DATA_DIR, "cuteCat.jpg")
    with open(input_image_file_path, "rb") as image:
        input_image = image.read()
    input_image_data = data_serializer.serialize(input_image)

    validation_image_file_path = os.path.join(DATA_DIR, "cuteCat.raw")
    # Read the expected bytes inside a context manager so the handle is closed.
    with open(validation_image_file_path, "rb") as validation_image:
        validation_image_data = validation_image.read()

    assert input_image_data == validation_image_data
350+
351+
352+
def test_data_serializer_file_like(data_serializer):
    """Serializing the path to ``cuteCat.jpg`` yields the contents of ``cuteCat.raw``.

    Note that the serializer is given the file path as a string here, not an
    open file object.
    """
    input_image_file_path = os.path.join(DATA_DIR, "cuteCat.jpg")
    validation_image_file_path = os.path.join(DATA_DIR, "cuteCat.raw")
    input_image_data = data_serializer.serialize(input_image_file_path)

    # Read the expected bytes inside a context manager so the handle is closed.
    with open(validation_image_file_path, "rb") as validation_image:
        validation_image_data = validation_image.read()

    assert input_image_data == validation_image_data

0 commit comments

Comments
 (0)