Skip to content

Commit 343ddd2

Browse files
committed
fix: jumpstart predictor class support following integ tests
1 parent 9fffdb8 commit 343ddd2

File tree

5 files changed

+24
-16
lines changed

5 files changed

+24
-16
lines changed

src/sagemaker/jumpstart/constants.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626
BaseSerializer,
2727
CSVSerializer,
2828
DataSerializer,
29+
IdentitySerializer,
2930
JSONSerializer,
30-
SimpleBaseSerializer,
3131
)
3232

3333

@@ -162,12 +162,12 @@
162162
}
163163

164164
SERIALIZER_TYPE_TO_CLASS_MAP: Dict[SerializerType, BaseSerializer] = {
165-
SerializerType.RAW_BYTES: DataSerializer,
166-
SerializerType.JSON: JSONSerializer,
167-
SerializerType.TEXT: SimpleBaseSerializer,
168-
SerializerType.CSV: CSVSerializer,
165+
SerializerType.RAW_BYTES: DataSerializer(),
166+
SerializerType.JSON: JSONSerializer(),
167+
SerializerType.TEXT: IdentitySerializer(),
168+
SerializerType.CSV: CSVSerializer(),
169169
}
170170

171171
DESERIALIZER_TYPE_TO_CLASS_MAP: Dict[DeserializerType, BaseDeserializer] = {
172-
DeserializerType.JSON: JSONDeserializer,
172+
DeserializerType.JSON: JSONDeserializer(),
173173
}

src/sagemaker/predictor.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -182,10 +182,14 @@ def _create_request_args(
182182
args["EndpointName"] = self.endpoint_name
183183

184184
if "ContentType" not in args:
185-
args["ContentType"] = self.content_type
185+
args["ContentType"] = (
186+
self.content_type
187+
if isinstance(self.content_type, str)
188+
else ", ".join(self.content_type)
189+
)
186190

187191
if "Accept" not in args:
188-
args["Accept"] = ", ".join(self.accept)
192+
args["Accept"] = self.accept if isinstance(self.accept, str) else ", ".join(self.accept)
189193

190194
if target_model:
191195
args["TargetModel"] = target_model

src/sagemaker/serializers.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import io
2121
import json
2222
import numpy as np
23+
from pandas import DataFrame
2324
from six import with_metaclass
2425

2526
from sagemaker.utils import DeferredError
@@ -100,14 +101,17 @@ def serialize(self, data):
100101
101102
Args:
102103
data (object): Data to be serialized. Can be a NumPy array, list,
103-
file, or buffer.
104+
file, Pandas DataFrame, or buffer.
104105
105106
Returns:
106107
str: The data serialized as a CSV-formatted string.
107108
"""
108109
if hasattr(data, "read"):
109110
return data.read()
110111

112+
if isinstance(data, DataFrame):
113+
return data.to_csv(header=False, index=False)
114+
111115
is_mutable_sequence_like = self._is_sequence_like(data) and hasattr(data, "__setitem__")
112116
has_multiple_rows = len(data) > 0 and self._is_sequence_like(data[0])
113117

tests/integ/sagemaker/jumpstart/script_mode_class/test_inference.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@
1717
from sagemaker.jumpstart.constants import INFERENCE_ENTRY_POINT_SCRIPT_NAME
1818
from sagemaker.jumpstart.artifacts import _retrieve_kwargs
1919
from sagemaker.jumpstart.enums import EnvVariableUseCase, KwargUseCase
20+
from sagemaker.jumpstart.predictor import JumpStartPredictor
2021
from sagemaker.model import Model
2122
from tests.integ.sagemaker.jumpstart.constants import (
2223
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
2324
JUMPSTART_TAG,
2425
InferenceTabularDataname,
2526
)
2627
from tests.integ.sagemaker.jumpstart.utils import (
27-
EndpointInvoker,
2828
download_inference_assets,
2929
get_sm_session,
3030
get_tabular_data,
@@ -91,13 +91,13 @@ def test_jumpstart_inference_model_class(setup):
9191
**deploy_kwargs,
9292
)
9393

94-
endpoint_invoker = EndpointInvoker(
95-
endpoint_name=model.endpoint_name,
94+
predictor = JumpStartPredictor(
95+
endpoint_name=model.endpoint_name, model_id=model_id, model_version=model_version
9696
)
9797

9898
download_inference_assets()
9999
ground_truth_label, features = get_tabular_data(InferenceTabularDataname.MULTICLASS)
100100

101-
response = endpoint_invoker.invoke_tabular_endpoint(features)
101+
response = predictor.predict(features)
102102

103103
assert response is not None

tests/unit/sagemaker/jumpstart/test_predictor.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
from sagemaker.jumpstart.utils import verify_model_region_and_return_specs
10-
from sagemaker.serializers import SimpleBaseSerializer
10+
from sagemaker.serializers import IdentitySerializer
1111
from tests.unit.sagemaker.jumpstart.utils import (
1212
get_special_model_spec,
1313
)
@@ -28,7 +28,7 @@ def test_list_jumpstart_scripts(
2828
predictor = JumpStartPredictor(endpoint_name="blah", model_id=model_id)
2929

3030
assert predictor.content_type == MIMEType.X_TEXT
31-
assert predictor.serializer == SimpleBaseSerializer
31+
assert isinstance(predictor.serializer, IdentitySerializer)
3232

33-
assert predictor.deserializer == JSONDeserializer
33+
assert isinstance(predictor.deserializer, JSONDeserializer)
3434
assert predictor.accept == MIMEType.JSON

0 commit comments

Comments
 (0)