File tree 5 files changed +24
-16
lines changed
integ/sagemaker/jumpstart/script_mode_class
5 files changed +24
-16
lines changed Original file line number Diff line number Diff line change 26
26
BaseSerializer ,
27
27
CSVSerializer ,
28
28
DataSerializer ,
29
+ IdentitySerializer ,
29
30
JSONSerializer ,
30
- SimpleBaseSerializer ,
31
31
)
32
32
33
33
162
162
}
163
163
164
164
SERIALIZER_TYPE_TO_CLASS_MAP : Dict [SerializerType , BaseSerializer ] = {
165
- SerializerType .RAW_BYTES : DataSerializer ,
166
- SerializerType .JSON : JSONSerializer ,
167
- SerializerType .TEXT : SimpleBaseSerializer ,
168
- SerializerType .CSV : CSVSerializer ,
165
+ SerializerType .RAW_BYTES : DataSerializer () ,
166
+ SerializerType .JSON : JSONSerializer () ,
167
+ SerializerType .TEXT : IdentitySerializer () ,
168
+ SerializerType .CSV : CSVSerializer () ,
169
169
}
170
170
171
171
DESERIALIZER_TYPE_TO_CLASS_MAP : Dict [DeserializerType , BaseDeserializer ] = {
172
- DeserializerType .JSON : JSONDeserializer ,
172
+ DeserializerType .JSON : JSONDeserializer () ,
173
173
}
Original file line number Diff line number Diff line change @@ -182,10 +182,14 @@ def _create_request_args(
182
182
args ["EndpointName" ] = self .endpoint_name
183
183
184
184
if "ContentType" not in args :
185
- args ["ContentType" ] = self .content_type
185
+ args ["ContentType" ] = (
186
+ self .content_type
187
+ if isinstance (self .content_type , str )
188
+ else ", " .join (self .content_type )
189
+ )
186
190
187
191
if "Accept" not in args :
188
- args ["Accept" ] = ", " .join (self .accept )
192
+ args ["Accept" ] = self . accept if isinstance ( self . accept , str ) else ", " .join (self .accept )
189
193
190
194
if target_model :
191
195
args ["TargetModel" ] = target_model
Original file line number Diff line number Diff line change 20
20
import io
21
21
import json
22
22
import numpy as np
23
+ from pandas import DataFrame
23
24
from six import with_metaclass
24
25
25
26
from sagemaker .utils import DeferredError
@@ -100,14 +101,17 @@ def serialize(self, data):
100
101
101
102
Args:
102
103
data (object): Data to be serialized. Can be a NumPy array, list,
103
- file, or buffer.
104
+ file, Pandas DataFrame, or buffer.
104
105
105
106
Returns:
106
107
str: The data serialized as a CSV-formatted string.
107
108
"""
108
109
if hasattr (data , "read" ):
109
110
return data .read ()
110
111
112
+ if isinstance (data , DataFrame ):
113
+ return data .to_csv (header = False , index = False )
114
+
111
115
is_mutable_sequence_like = self ._is_sequence_like (data ) and hasattr (data , "__setitem__" )
112
116
has_multiple_rows = len (data ) > 0 and self ._is_sequence_like (data [0 ])
113
117
Original file line number Diff line number Diff line change 17
17
from sagemaker .jumpstart .constants import INFERENCE_ENTRY_POINT_SCRIPT_NAME
18
18
from sagemaker .jumpstart .artifacts import _retrieve_kwargs
19
19
from sagemaker .jumpstart .enums import EnvVariableUseCase , KwargUseCase
20
+ from sagemaker .jumpstart .predictor import JumpStartPredictor
20
21
from sagemaker .model import Model
21
22
from tests .integ .sagemaker .jumpstart .constants import (
22
23
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID ,
23
24
JUMPSTART_TAG ,
24
25
InferenceTabularDataname ,
25
26
)
26
27
from tests .integ .sagemaker .jumpstart .utils import (
27
- EndpointInvoker ,
28
28
download_inference_assets ,
29
29
get_sm_session ,
30
30
get_tabular_data ,
@@ -91,13 +91,13 @@ def test_jumpstart_inference_model_class(setup):
91
91
** deploy_kwargs ,
92
92
)
93
93
94
- endpoint_invoker = EndpointInvoker (
95
- endpoint_name = model .endpoint_name ,
94
+ predictor = JumpStartPredictor (
95
+ endpoint_name = model .endpoint_name , model_id = model_id , model_version = model_version
96
96
)
97
97
98
98
download_inference_assets ()
99
99
ground_truth_label , features = get_tabular_data (InferenceTabularDataname .MULTICLASS )
100
100
101
- response = endpoint_invoker . invoke_tabular_endpoint (features )
101
+ response = predictor . predict (features )
102
102
103
103
assert response is not None
Original file line number Diff line number Diff line change 7
7
8
8
9
9
from sagemaker .jumpstart .utils import verify_model_region_and_return_specs
10
- from sagemaker .serializers import SimpleBaseSerializer
10
+ from sagemaker .serializers import IdentitySerializer
11
11
from tests .unit .sagemaker .jumpstart .utils import (
12
12
get_special_model_spec ,
13
13
)
@@ -28,7 +28,7 @@ def test_list_jumpstart_scripts(
28
28
predictor = JumpStartPredictor (endpoint_name = "blah" , model_id = model_id )
29
29
30
30
assert predictor .content_type == MIMEType .X_TEXT
31
- assert predictor .serializer == SimpleBaseSerializer
31
+ assert isinstance ( predictor .serializer , IdentitySerializer )
32
32
33
- assert predictor .deserializer == JSONDeserializer
33
+ assert isinstance ( predictor .deserializer , JSONDeserializer )
34
34
assert predictor .accept == MIMEType .JSON
You can’t perform that action at this time.
0 commit comments