Skip to content

Commit 71b2055

Browse files
committed
feat: content_type for RecordIO de/serializers
Extend content_type and accept args to the RecordIO de/serializers hiding over in the sagemaker.amazon package that were missed in the initial commit.
1 parent 876a304 commit 71b2055

File tree

1 file changed

+21
-6
lines changed

1 file changed

+21
-6
lines changed

src/sagemaker/amazon/common.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,22 @@
2222

2323
from sagemaker.amazon.record_pb2 import Record
2424
from sagemaker.deprecations import deprecated_class
25-
from sagemaker.deserializers import BaseDeserializer
26-
from sagemaker.serializers import BaseSerializer
25+
from sagemaker.deserializers import SimpleBaseDeserializer
26+
from sagemaker.serializers import SimpleBaseSerializer
2727
from sagemaker.utils import DeferredError
2828

2929

30-
class RecordSerializer(BaseSerializer):
30+
class RecordSerializer(SimpleBaseSerializer):
3131
"""Serialize a NumPy array for an inference request."""
3232

33-
CONTENT_TYPE = "application/x-recordio-protobuf"
33+
def __init__(self, content_type="application/x-recordio-protobuf"):
34+
"""Initialize a ``RecordSerializer`` instance.
35+
36+
Args:
37+
content_type (str): The MIME type to signal to the inference endpoint when sending
38+
request data (default: "application/x-recordio-protobuf").
39+
"""
40+
super(RecordSerializer, self).__init__(content_type=content_type)
3441

3542
def serialize(self, data):
3643
"""Serialize a NumPy array into a buffer containing RecordIO records.
@@ -56,10 +63,18 @@ def serialize(self, data):
5663
return buffer
5764

5865

59-
class RecordDeserializer(BaseDeserializer):
66+
class RecordDeserializer(SimpleBaseDeserializer):
6067
"""Deserialize RecordIO Protobuf data from an inference endpoint."""
6168

62-
ACCEPT = ("application/x-recordio-protobuf",)
69+
def __init__(self, accept="application/x-recordio-protobuf"):
70+
"""Initialize a ``RecordDeserializer`` instance.
71+
72+
Args:
73+
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
74+
is expected from the inference endpoint (default:
75+
"application/x-recordio-protobuf").
76+
"""
77+
super(RecordDeserializer, self).__init__(accept=accept)
6378

6479
def deserialize(self, data, content_type):
6580
"""Deserialize RecordIO Protobuf data from an inference endpoint.

0 commit comments

Comments
 (0)