Skip to content

Commit 1f60714

Browse files
authored
Merge branch 'master' into master
2 parents e450813 + 615a8ad commit 1f60714

File tree

14 files changed

+292
-47
lines changed

14 files changed

+292
-47
lines changed

src/sagemaker/huggingface/llm_utils.py

+2-1
Original file line number | Diff line number | Diff line change
@@ -81,7 +81,8 @@ def get_huggingface_model_metadata(model_id: str, hf_hub_token: Optional[str] =
8181
Returns:
8282
dict: The model metadata retrieved with the HuggingFace API
8383
"""
84-
84+
if not model_id:
85+
raise ValueError("Model ID is empty. Please provide a valid Model ID.")
8586
hf_model_metadata_url = f"https://huggingface.co/api/models/{model_id}"
8687
hf_model_metadata_json = None
8788
try:

src/sagemaker/model.py

+4-4
Original file line number | Diff line number | Diff line change
@@ -766,8 +766,8 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
766766

767767
def _script_mode_env_vars(self):
768768
"""Returns a mapping of environment variables for script mode execution"""
769-
script_name = None
770-
dir_name = None
769+
script_name = self.env.get(SCRIPT_PARAM_NAME.upper(), "")
770+
dir_name = self.env.get(DIR_PARAM_NAME.upper(), "")
771771
if self.uploaded_code:
772772
script_name = self.uploaded_code.script_name
773773
if self.repacked_model_data or self.enable_network_isolation():
@@ -783,8 +783,8 @@ def _script_mode_env_vars(self):
783783
else "file://" + self.source_dir
784784
)
785785
return {
786-
SCRIPT_PARAM_NAME.upper(): script_name or str(),
787-
DIR_PARAM_NAME.upper(): dir_name or str(),
786+
SCRIPT_PARAM_NAME.upper(): script_name,
787+
DIR_PARAM_NAME.upper(): dir_name,
788788
CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): to_string(self.container_log_level),
789789
SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_region_name,
790790
}

src/sagemaker/serve/builder/model_builder.py

+14-5
Original file line number | Diff line number | Diff line change
@@ -124,8 +124,8 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
124124
into a stream. All translations between the server and the client are handled
125125
automatically with the specified input and output.
126126
model (Optional[Union[object, str]]): Model object (with ``predict`` method to perform
127-
inference) or a HuggingFace/JumpStart Model ID. Either ``model`` or
128-
``inference_spec`` is required for the model builder to build the artifact.
127+
inference) or a HuggingFace/JumpStart Model ID. Either ``model`` or ``inference_spec``
128+
is required for the model builder to build the artifact.
129129
inference_spec (InferenceSpec): The inference spec file with your customized
130130
``invoke`` and ``load`` functions.
131131
image_uri (Optional[str]): The container image uri (which is derived from a
@@ -145,6 +145,8 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
145145
to the model server). Possible values for this argument are
146146
``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``,
147147
``TRITON``, and ``TGI``.
148+
model_metadata (Optional[Dict[str, Any]]): Dictionary used to override the HuggingFace
149+
model metadata. Currently ``HF_TASK`` is overridable.
148150
"""
149151

150152
model_path: Optional[str] = field(
@@ -241,6 +243,10 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
241243
model_server: Optional[ModelServer] = field(
242244
default=None, metadata={"help": "Define the model server to deploy to."}
243245
)
246+
model_metadata: Optional[Dict[str, Any]] = field(
247+
default=None,
248+
metadata={"help": "Define the model metadata to override, currently supports `HF_TASK`"},
249+
)
244250

245251
def _build_validations(self):
246252
"""Placeholder docstring"""
@@ -616,6 +622,9 @@ def build( # pylint: disable=R0911
616622
self._is_custom_image_uri = self.image_uri is not None
617623

618624
if isinstance(self.model, str):
625+
model_task = None
626+
if self.model_metadata:
627+
model_task = self.model_metadata.get("HF_TASK")
619628
if self._is_jumpstart_model_id():
620629
return self._build_for_jumpstart()
621630
if self._is_djl(): # pylint: disable=R1705
@@ -625,10 +634,10 @@ def build( # pylint: disable=R0911
625634
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
626635
)
627636

628-
model_task = hf_model_md.get("pipeline_tag")
629-
if self.schema_builder is None and model_task:
637+
if model_task is None:
638+
model_task = hf_model_md.get("pipeline_tag")
639+
if self.schema_builder is None and model_task is not None:
630640
self._schema_builder_init(model_task)
631-
632641
if model_task == "text-generation": # pylint: disable=R1705
633642
return self._build_for_tgi()
634643
elif self._can_fit_on_single_gpu():

src/sagemaker/serve/schema/task.json

+20-20
Original file line number | Diff line number | Diff line change
@@ -1,28 +1,28 @@
11
{
22
"fill-mask": {
3-
"sample_inputs": {
3+
"sample_inputs": {
44
"properties": {
55
"inputs": "Paris is the [MASK] of France.",
66
"parameters": {}
77
}
8-
},
9-
"sample_outputs": {
8+
},
9+
"sample_outputs": {
1010
"properties": [
1111
{
1212
"sequence": "Paris is the capital of France.",
1313
"score": 0.7
1414
}
1515
]
1616
}
17-
},
17+
},
1818
"question-answering": {
19-
"sample_inputs": {
19+
"sample_inputs": {
2020
"properties": {
2121
"context": "I have a German Shepherd dog, named Coco.",
2222
"question": "What is my dog's breed?"
2323
}
24-
},
25-
"sample_outputs": {
24+
},
25+
"sample_outputs": {
2626
"properties": [
2727
{
2828
"answer": "German Shepherd",
@@ -32,36 +32,36 @@
3232
}
3333
]
3434
}
35-
},
35+
},
3636
"text-classification": {
37-
"sample_inputs": {
37+
"sample_inputs": {
3838
"properties": {
3939
"inputs": "Where is the capital of France?, Paris is the capital of France.",
4040
"parameters": {}
4141
}
42-
},
43-
"sample_outputs": {
42+
},
43+
"sample_outputs": {
4444
"properties": [
4545
{
4646
"label": "entailment",
4747
"score": 0.997
4848
}
4949
]
5050
}
51-
},
52-
"text-generation": {
53-
"sample_inputs": {
51+
},
52+
"text-generation": {
53+
"sample_inputs": {
5454
"properties": {
5555
"inputs": "Hello, I'm a language model",
5656
"parameters": {}
5757
}
58-
},
59-
"sample_outputs": {
58+
},
59+
"sample_outputs": {
6060
"properties": [
61-
{
62-
"generated_text": "Hello, I'm a language modeler. So while writing this, when I went out to meet my wife or come home she told me that my"
63-
}
61+
{
62+
"generated_text": "Hello, I'm a language modeler. So while writing this, when I went out to meet my wife or come home she told me that my"
63+
}
6464
]
6565
}
66-
}
66+
}
6767
}

tests/integ/sagemaker/serve/test_schema_builder.py

+66
Original file line number | Diff line number | Diff line change
@@ -99,3 +99,69 @@ def test_model_builder_negative_path(sagemaker_session):
9999
match="Error Message: Schema builder for text-to-image could not be found.",
100100
):
101101
model_builder.build(sagemaker_session=sagemaker_session)
102+
103+
104+
@pytest.mark.skipif(
105+
PYTHON_VERSION_IS_NOT_310,
106+
reason="Testing Schema Builder Simplification feature",
107+
)
108+
@pytest.mark.parametrize(
109+
"model_id, task_provided",
110+
[
111+
("bert-base-uncased", "fill-mask"),
112+
("bert-large-uncased-whole-word-masking-finetuned-squad", "question-answering"),
113+
],
114+
)
115+
def test_model_builder_happy_path_with_task_provided(
116+
model_id, task_provided, sagemaker_session, gpu_instance_type
117+
):
118+
model_builder = ModelBuilder(model=model_id, model_metadata={"HF_TASK": task_provided})
119+
120+
model = model_builder.build(sagemaker_session=sagemaker_session)
121+
122+
assert model is not None
123+
assert model_builder.schema_builder is not None
124+
125+
inputs, outputs = task.retrieve_local_schemas(task_provided)
126+
assert model_builder.schema_builder.sample_input == inputs
127+
assert model_builder.schema_builder.sample_output == outputs
128+
129+
with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
130+
caught_ex = None
131+
try:
132+
iam_client = sagemaker_session.boto_session.client("iam")
133+
role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]
134+
135+
logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
136+
predictor = model.deploy(
137+
role=role_arn, instance_count=1, instance_type=gpu_instance_type
138+
)
139+
140+
predicted_outputs = predictor.predict(inputs)
141+
assert predicted_outputs is not None
142+
143+
except Exception as e:
144+
caught_ex = e
145+
finally:
146+
cleanup_model_resources(
147+
sagemaker_session=model_builder.sagemaker_session,
148+
model_name=model.name,
149+
endpoint_name=model.endpoint_name,
150+
)
151+
if caught_ex:
152+
logger.exception(caught_ex)
153+
assert (
154+
False
155+
), f"{caught_ex} was thrown when running transformers sagemaker endpoint test"
156+
157+
158+
def test_model_builder_negative_path_with_invalid_task(sagemaker_session):
159+
model_builder = ModelBuilder(
160+
model="bert-base-uncased", model_metadata={"HF_TASK": "invalid-task"}
161+
)
162+
163+
with pytest.raises(
164+
TaskNotFoundException,
165+
match="Error Message: Schema builder for invalid-task could not be found.",
166+
):
167+
model_builder.build(sagemaker_session=sagemaker_session)

tests/unit/sagemaker/feature_store/feature_processor/test_data_helpers.py

+14
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,20 @@
5252
"some-other-key": {"some-key": "some-value"},
5353
}
5454

55+
DATA_SOURCE_UNIQUE_ID_TOO_LONG = """
56+
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\
57+
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\
58+
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\
59+
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\
60+
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\
61+
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\
62+
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\
63+
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\
64+
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\
65+
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\
66+
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\
67+
"""
68+
5569
DESCRIBE_FEATURE_GROUP_RESPONSE = {
5670
"FeatureGroupArn": INPUT_FEATURE_GROUP_ARN,
5771
"FeatureGroupName": INPUT_FEATURE_GROUP_NAME,

tests/unit/sagemaker/feature_store/feature_processor/test_validation.py

+1-3
Original file line number | Diff line number | Diff line change
@@ -19,8 +19,6 @@
1919
import pytest
2020

2121
import test_data_helpers as tdh
22-
import string
23-
import random
2422
from mock import Mock
2523

2624
from sagemaker.feature_store.feature_processor._validation import (
@@ -164,7 +162,7 @@ def invalid_spark_position(spark, fg_data_source, s3_data_source):
164162
("", "unique_id", "data_source_name of input does not match pattern '.*'."),
165163
(
166164
"source",
167-
"".join(random.choices(string.ascii_uppercase, k=2050)),
165+
tdh.DATA_SOURCE_UNIQUE_ID_TOO_LONG,
168166
"data_source_unique_id of input does not match pattern '.*'.",
169167
),
170168
("source", "", "data_source_unique_id of input does not match pattern '.*'."),

tests/unit/sagemaker/remote_function/core/test_stored_function.py

+1-2
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,6 @@
3939

4040
from sagemaker.workflow.function_step import _FunctionStep, DelayedReturn
4141
from sagemaker.workflow.parameters import ParameterFloat
42-
from sagemaker.utils import sagemaker_timestamp
4342

4443
from tests.unit.sagemaker.experiments.helpers import (
4544
TEST_EXP_DISPLAY_NAME,
@@ -55,7 +54,7 @@
5554
FUNCTION_FOLDER = "function"
5655
ARGUMENT_FOLDER = "arguments"
5756
RESULT_FOLDER = "results"
58-
PIPELINE_BUILD_TIME = sagemaker_timestamp()
57+
PIPELINE_BUILD_TIME = "2022-05-10T17:30:20Z"
5958

6059
mock_s3 = {}
6160

0 commit comments

Comments
 (0)