22
22
23
23
from sagemaker .amazon .record_pb2 import Record
24
24
from sagemaker .deprecations import deprecated_class
25
- from sagemaker .deserializers import BaseDeserializer
26
- from sagemaker .serializers import BaseSerializer
25
+ from sagemaker .deserializers import SimpleBaseDeserializer
26
+ from sagemaker .serializers import SimpleBaseSerializer
27
27
from sagemaker .utils import DeferredError
28
28
29
29
30
- class RecordSerializer (BaseSerializer ):
30
+ class RecordSerializer (SimpleBaseSerializer ):
31
31
"""Serialize a NumPy array for an inference request."""
32
32
33
- CONTENT_TYPE = "application/x-recordio-protobuf"
33
+ def __init__ (self , content_type = "application/x-recordio-protobuf" ):
34
+ """Initialize a ``RecordSerializer`` instance.
35
+
36
+ Args:
37
+ content_type (str): The MIME type to signal to the inference endpoint when sending
38
+ request data (default: "application/x-recordio-protobuf").
39
+ """
40
+ super (RecordSerializer , self ).__init__ (content_type = content_type )
34
41
35
42
def serialize (self , data ):
36
43
"""Serialize a NumPy array into a buffer containing RecordIO records.
@@ -56,10 +63,18 @@ def serialize(self, data):
56
63
return buffer
57
64
58
65
59
- class RecordDeserializer (BaseDeserializer ):
66
+ class RecordDeserializer (SimpleBaseDeserializer ):
60
67
"""Deserialize RecordIO Protobuf data from an inference endpoint."""
61
68
62
- ACCEPT = ("application/x-recordio-protobuf" ,)
69
+ def __init__ (self , accept = "application/x-recordio-protobuf" ):
70
+ """Initialize a ``RecordDeserializer`` instance.
71
+
72
+ Args:
73
+ accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
74
+ is expected from the inference endpoint (default:
75
+ "application/x-recordio-protobuf").
76
+ """
77
+ super (RecordDeserializer , self ).__init__ (accept = accept )
63
78
64
79
def deserialize (self , data , content_type ):
65
80
"""Deserialize RecordIO Protobuf data from an inference endpoint.
0 commit comments