18
18
import pytest
19
19
20
20
from sagemaker .serve .utils .exceptions import TaskNotFoundException
21
+ from sagemaker_schema_inference_artifacts .huggingface import remote_schema_retriever
21
22
from tests .integ .sagemaker .serve .constants import (
22
23
PYTHON_VERSION_IS_NOT_310 ,
23
24
SERVE_SAGEMAKER_ENDPOINT_TIMEOUT ,
31
32
logger = logging .getLogger (__name__ )
32
33
33
34
34
- def test_model_builder_happy_path_with_only_model_id_fill_mask (sagemaker_session ):
35
- model_builder = ModelBuilder (model = "bert-base-uncased " )
35
+ def test_model_builder_happy_path_with_only_model_id_text_generation (sagemaker_session ):
36
+ model_builder = ModelBuilder (model = "HuggingFaceH4/zephyr-7b-beta " )
36
37
37
38
model = model_builder .build (sagemaker_session = sagemaker_session )
38
39
39
40
assert model is not None
40
41
assert model_builder .schema_builder is not None
41
42
42
- inputs , outputs = task .retrieve_local_schemas ("fill-mask " )
43
- assert model_builder .schema_builder .sample_input == inputs
43
+ inputs , outputs = task .retrieve_local_schemas ("text-generation " )
44
+ assert model_builder .schema_builder .sample_input [ "inputs" ] == inputs [ "inputs" ]
44
45
assert model_builder .schema_builder .sample_output == outputs
45
46
46
47
48
+ def test_model_builder_negative_path (sagemaker_session ):
49
+ # A model-task combo unsupported by both the local and remote schema fallback options. (eg: text-to-video)
50
+ model_builder = ModelBuilder (model = "ByteDance/AnimateDiff-Lightning" )
51
+ with pytest .raises (
52
+ TaskNotFoundException ,
53
+ match = "Error Message: HuggingFace Schema builder samples for text-to-video could not be found locally or "
54
+ "via remote." ,
55
+ ):
56
+ model_builder .build (sagemaker_session = sagemaker_session )
57
+
58
+
47
59
@pytest .mark .skipif (
48
60
PYTHON_VERSION_IS_NOT_310 ,
49
- reason = "Testing Schema Builder Simplification feature" ,
61
+ reason = "Testing Schema Builder Simplification feature - Local Schema " ,
50
62
)
51
- def test_model_builder_happy_path_with_only_model_id_question_answering (
52
- sagemaker_session , gpu_instance_type
63
+ @pytest .mark .parametrize (
64
+ "model_id, task_provided, instance_type_provided, container_startup_timeout" ,
65
+ [
66
+ (
67
+ "distilbert/distilbert-base-uncased-finetuned-sst-2-english" ,
68
+ "text-classification" ,
69
+ "ml.m5.xlarge" ,
70
+ None ,
71
+ ),
72
+ (
73
+ "cardiffnlp/twitter-roberta-base-sentiment-latest" ,
74
+ "text-classification" ,
75
+ "ml.m5.xlarge" ,
76
+ None ,
77
+ ),
78
+ ("HuggingFaceH4/zephyr-7b-beta" , "text-generation" , "ml.g5.2xlarge" , 900 ),
79
+ ("HuggingFaceH4/zephyr-7b-alpha" , "text-generation" , "ml.g5.2xlarge" , 900 ),
80
+ ],
81
+ )
82
+ def test_model_builder_happy_path_with_task_provided_local_schema_mode (
83
+ model_id , task_provided , sagemaker_session , instance_type_provided , container_startup_timeout
53
84
):
54
- model_builder = ModelBuilder (model = "bert-large-uncased-whole-word-masking-finetuned-squad" )
85
+ model_builder = ModelBuilder (
86
+ model = model_id ,
87
+ model_metadata = {"HF_TASK" : task_provided },
88
+ instance_type = instance_type_provided ,
89
+ )
55
90
56
91
model = model_builder .build (sagemaker_session = sagemaker_session )
57
92
58
93
assert model is not None
59
94
assert model_builder .schema_builder is not None
60
95
61
- inputs , outputs = task .retrieve_local_schemas ("question-answering" )
62
- assert model_builder .schema_builder .sample_input == inputs
96
+ inputs , outputs = task .retrieve_local_schemas (task_provided )
97
+ if task_provided == "text-generation" :
98
+ # ignore 'tokens' and other metadata in this case
99
+ assert model_builder .schema_builder .sample_input ["inputs" ] == inputs ["inputs" ]
100
+ else :
101
+ assert model_builder .schema_builder .sample_input == inputs
63
102
assert model_builder .schema_builder .sample_output == outputs
64
103
65
104
with timeout (minutes = SERVE_SAGEMAKER_ENDPOINT_TIMEOUT ):
@@ -69,9 +108,17 @@ def test_model_builder_happy_path_with_only_model_id_question_answering(
69
108
role_arn = iam_client .get_role (RoleName = "SageMakerRole" )["Role" ]["Arn" ]
70
109
71
110
logger .info ("Deploying and predicting in SAGEMAKER_ENDPOINT mode..." )
72
- predictor = model .deploy (
73
- role = role_arn , instance_count = 1 , instance_type = gpu_instance_type
74
- )
111
+ if container_startup_timeout :
112
+ predictor = model .deploy (
113
+ role = role_arn ,
114
+ instance_count = 1 ,
115
+ instance_type = instance_type_provided ,
116
+ container_startup_health_check_timeout = container_startup_timeout ,
117
+ )
118
+ else :
119
+ predictor = model .deploy (
120
+ role = role_arn , instance_count = 1 , instance_type = instance_type_provided
121
+ )
75
122
76
123
predicted_outputs = predictor .predict (inputs )
77
124
assert predicted_outputs is not None
@@ -91,38 +138,38 @@ def test_model_builder_happy_path_with_only_model_id_question_answering(
91
138
), f"{ caught_ex } was thrown when running transformers sagemaker endpoint test"
92
139
93
140
94
- def test_model_builder_negative_path (sagemaker_session ):
95
- model_builder = ModelBuilder (model = "CompVis/stable-diffusion-v1-4" )
96
-
97
- with pytest .raises (
98
- TaskNotFoundException ,
99
- match = "Error Message: Schema builder for text-to-image could not be found." ,
100
- ):
101
- model_builder .build (sagemaker_session = sagemaker_session )
102
-
103
-
104
141
@pytest .mark .skipif (
105
142
PYTHON_VERSION_IS_NOT_310 ,
106
- reason = "Testing Schema Builder Simplification feature" ,
143
+ reason = "Testing Schema Builder Simplification feature - Remote Schema " ,
107
144
)
108
145
@pytest .mark .parametrize (
109
- "model_id, task_provided" ,
146
+ "model_id, task_provided, instance_type_provided " ,
110
147
[
111
- ("bert-base-uncased" , "fill-mask" ),
112
- ("bert-large-uncased-whole-word-masking-finetuned-squad" , "question-answering" ),
148
+ ("google-bert/bert-base-uncased" , "fill-mask" , "ml.m5.xlarge" ),
149
+ ("google-bert/bert-base-cased" , "fill-mask" , "ml.m5.xlarge" ),
150
+ (
151
+ "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad" ,
152
+ "question-answering" ,
153
+ "ml.m5.xlarge" ,
154
+ ),
155
+ ("deepset/roberta-base-squad2" , "question-answering" , "ml.m5.xlarge" ),
113
156
],
114
157
)
115
- def test_model_builder_happy_path_with_task_provided (
116
- model_id , task_provided , sagemaker_session , gpu_instance_type
158
+ def test_model_builder_happy_path_with_task_provided_remote_schema_mode (
159
+ model_id , task_provided , sagemaker_session , instance_type_provided
117
160
):
118
- model_builder = ModelBuilder (model = model_id , model_metadata = {"HF_TASK" : task_provided })
119
-
161
+ model_builder = ModelBuilder (
162
+ model = model_id ,
163
+ model_metadata = {"HF_TASK" : task_provided },
164
+ instance_type = instance_type_provided ,
165
+ )
120
166
model = model_builder .build (sagemaker_session = sagemaker_session )
121
167
122
168
assert model is not None
123
169
assert model_builder .schema_builder is not None
124
170
125
- inputs , outputs = task .retrieve_local_schemas (task_provided )
171
+ remote_hf_schema_helper = remote_schema_retriever .RemoteSchemaRetriever ()
172
+ inputs , outputs = remote_hf_schema_helper .get_resolved_hf_schema_for_task (task_provided )
126
173
assert model_builder .schema_builder .sample_input == inputs
127
174
assert model_builder .schema_builder .sample_output == outputs
128
175
@@ -134,7 +181,7 @@ def test_model_builder_happy_path_with_task_provided(
134
181
135
182
logger .info ("Deploying and predicting in SAGEMAKER_ENDPOINT mode..." )
136
183
predictor = model .deploy (
137
- role = role_arn , instance_count = 1 , instance_type = gpu_instance_type
184
+ role = role_arn , instance_count = 1 , instance_type = instance_type_provided
138
185
)
139
186
140
187
predicted_outputs = predictor .predict (inputs )
@@ -162,6 +209,7 @@ def test_model_builder_negative_path_with_invalid_task(sagemaker_session):
162
209
163
210
with pytest .raises (
164
211
TaskNotFoundException ,
165
- match = "Error Message: Schema builder for invalid-task could not be found." ,
212
+ match = "Error Message: HuggingFace Schema builder samples for invalid-task could not be found locally or "
213
+ "via remote." ,
166
214
):
167
215
model_builder .build (sagemaker_session = sagemaker_session )
0 commit comments