feature: Add overriding logic in ModelBuilder when task is provided #4460

Merged
17 commits merged on Mar 8, 2024
3 changes: 2 additions & 1 deletion src/sagemaker/huggingface/llm_utils.py
@@ -81,7 +81,8 @@ def get_huggingface_model_metadata(model_id: str, hf_hub_token: Optional[str] =
Returns:
dict: The model metadata retrieved with the HuggingFace API
"""

if not model_id:
raise ValueError("Model ID is empty. Please provide a valid Model ID.")
hf_model_metadata_url = f"https://huggingface.co/api/models/{model_id}"
hf_model_metadata_json = None
try:
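As a quick illustration of the guard added above, a minimal sketch (it assumes the package is importable; the call never reaches the Hub because the check fires first):

```python
from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata

# With the added check, an empty model ID fails fast instead of sending a
# request to the HuggingFace API with no ID appended to the URL.
try:
    get_huggingface_model_metadata("")
except ValueError as err:
    print(err)  # "Model ID is empty. Please provide a valid Model ID."
```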
14 changes: 10 additions & 4 deletions src/sagemaker/serve/builder/model_builder.py
@@ -118,7 +118,8 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
into a stream. All translations between the server and the client are handled
automatically with the specified input and output.
model (Optional[Union[object, str]]): Model object (with ``predict`` method to perform
inference) or a HuggingFace/JumpStart Model ID. Either ``model`` or
inference) or a HuggingFace/JumpStart Model ID (followed by ``:task`` if you need
to override the task, e.g. bert-base-uncased:fill-mask). Either ``model`` or
``inference_spec`` is required for the model builder to build the artifact.
inference_spec (InferenceSpec): The inference spec file with your customized
``invoke`` and ``load`` functions.
Expand Down Expand Up @@ -205,6 +206,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
"help": (
'Model object with "predict" method to perform inference '
"or HuggingFace/JumpStart Model ID"
"or if you need to override task, provide input as ModelID:Task"
)
},
)
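To make the new syntax concrete, here is a minimal usage sketch; the model ID, task, and session setup are illustrative placeholders rather than part of this diff:

```python
from sagemaker.session import Session
from sagemaker.serve.builder.model_builder import ModelBuilder

# Append ":<task>" to the model ID to force a task instead of relying on the
# pipeline_tag reported by the HuggingFace Hub.
model_builder = ModelBuilder(model="bert-base-uncased:fill-mask")
model = model_builder.build(sagemaker_session=Session())  # assumes default AWS credentials/config
```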
@@ -610,6 +612,10 @@ def build(
self._is_custom_image_uri = self.image_uri is not None

if isinstance(self.model, str):
model_task = None
if ":" in self.model:
model_task = self.model.split(":")[1]
self.model = self.model.split(":")[0]
if self._is_jumpstart_model_id():
return self._build_for_jumpstart()
if self._is_djl(): # pylint: disable=R1705
@@ -619,10 +625,10 @@
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
)

model_task = hf_model_md.get("pipeline_tag")
if self.schema_builder is None and model_task:
if model_task is None:
model_task = hf_model_md.get("pipeline_tag")
if self.schema_builder is None and model_task is not None:
self._schema_builder_init(model_task)

if model_task == "text-generation": # pylint: disable=R1705
return self._build_for_tgi()
else:
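Read together, the added lines amount to the flow below. This is a simplified, self-contained restatement rather than the exact method body (the real code runs inside ``build()`` after the JumpStart/DJL checks, and splits with ``split(":")``):

```python
from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata

def resolve_model_and_task(model: str, hf_hub_token=None):
    """Split an optional ':task' suffix off the model ID; otherwise ask the Hub."""
    model_task = None
    if ":" in model:
        # e.g. "bert-base-uncased:fill-mask" -> ("bert-base-uncased", "fill-mask")
        model, model_task = model.split(":", 1)
    if model_task is None:
        # No ":task" override was given: fall back to the pipeline_tag from Hub metadata.
        hf_model_md = get_huggingface_model_metadata(model, hf_hub_token)
        model_task = hf_model_md.get("pipeline_tag")
    return model, model_task
```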
42 changes: 21 additions & 21 deletions src/sagemaker/serve/schema/task.json
@@ -1,28 +1,28 @@
{
"fill-mask": {
"sample_inputs": {
"sample_inputs": {
"properties": {
"inputs": "Paris is the <mask> of France.",
"inputs": "Paris is the [MASK] of France.",
"parameters": {}
}
},
"sample_outputs": {
},
"sample_outputs": {
"properties": [
{
"sequence": "Paris is the capital of France.",
"score": 0.7
}
]
}
},
},
"question-answering": {
"sample_inputs": {
"sample_inputs": {
"properties": {
"context": "I have a German Shepherd dog, named Coco.",
"question": "What is my dog's breed?"
}
},
"sample_outputs": {
},
"sample_outputs": {
"properties": [
{
"answer": "German Shepherd",
Expand All @@ -32,36 +32,36 @@
}
]
}
},
},
"text-classification": {
"sample_inputs": {
"sample_inputs": {
"properties": {
"inputs": "Where is the capital of France?, Paris is the capital of France.",
"parameters": {}
}
},
"sample_outputs": {
},
"sample_outputs": {
"properties": [
{
"label": "entailment",
"score": 0.997
}
]
}
},
"text-generation": {
"sample_inputs": {
},
"text-generation": {
"sample_inputs": {
"properties": {
"inputs": "Hello, I'm a language model",
"parameters": {}
}
},
"sample_outputs": {
},
"sample_outputs": {
"properties": [
{
"generated_text": "Hello, I'm a language modeler. So while writing this, when I went out to meet my wife or come home she told me that my"
}
{
"generated_text": "Hello, I'm a language modeler. So while writing this, when I went out to meet my wife or come home she told me that my"
}
]
}
}
}
}
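The tests below consume these entries through ``sagemaker.serve.utils.task.retrieve_local_schemas``. With the updated ``fill-mask`` entry, the lookup should return the values shown above (assertions mirror the expected schema, not new behavior):

```python
from sagemaker.serve.utils import task

inputs, outputs = task.retrieve_local_schemas("fill-mask")
assert inputs == {"inputs": "Paris is the [MASK] of France.", "parameters": {}}
assert outputs == [{"sequence": "Paris is the capital of France.", "score": 0.7}]
```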
64 changes: 64 additions & 0 deletions tests/integ/sagemaker/serve/test_schema_builder.py
@@ -99,3 +99,67 @@ def test_model_builder_negative_path(sagemaker_session):
match="Error Message: Schema builder for text-to-image could not be found.",
):
model_builder.build(sagemaker_session=sagemaker_session)


@pytest.mark.skipif(
PYTHON_VERSION_IS_NOT_310,
reason="Testing Schema Builder Simplification feature",
)
@pytest.mark.parametrize(
"model_id, task_provided",
[
("bert-base-uncased", "fill-mask"),
("bert-large-uncased-whole-word-masking-finetuned-squad", "question-answering"),
],
)
def test_model_builder_happy_path_with_task_provided(
model_id, task_provided, sagemaker_session, gpu_instance_type
):
model_builder = ModelBuilder(model=f"{model_id}:{task_provided}")

model = model_builder.build(sagemaker_session=sagemaker_session)

assert model is not None
assert model_builder.schema_builder is not None

inputs, outputs = task.retrieve_local_schemas(task_provided)
assert model_builder.schema_builder.sample_input == inputs
assert model_builder.schema_builder.sample_output == outputs

with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
caught_ex = None
try:
iam_client = sagemaker_session.boto_session.client("iam")
role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]

logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
predictor = model.deploy(
role=role_arn, instance_count=1, instance_type=gpu_instance_type
)

predicted_outputs = predictor.predict(inputs)
assert predicted_outputs is not None

except Exception as e:
caught_ex = e
finally:
cleanup_model_resources(
sagemaker_session=model_builder.sagemaker_session,
model_name=model.name,
endpoint_name=model.endpoint_name,
)
if caught_ex:
logger.exception(caught_ex)
assert (
False
), f"{caught_ex} was thrown when running transformers sagemaker endpoint test"


def test_model_builder_negative_path_with_invalid_task(sagemaker_session):
model_builder = ModelBuilder(model="bert-base-uncased:invalid-task")

with pytest.raises(
TaskNotFoundException,
match="Error Message: Schema builder for invalid-task could not be found.",
):
model_builder.build(sagemaker_session=sagemaker_session)
118 changes: 117 additions & 1 deletion tests/unit/sagemaker/serve/builder/test_model_builder.py
@@ -1072,8 +1072,124 @@ def test_build_negative_path_when_schema_builder_not_present(

model_builder = ModelBuilder(model="CompVis/stable-diffusion-v1-4")

self.assertRaisesRegexp(
self.assertRaisesRegex(
TaskNotFoundException,
"Error Message: Schema builder for text-to-image could not be found.",
lambda: model_builder.build(sagemaker_session=mock_session),
)

@patch("sagemaker.serve.builder.tgi_builder.HuggingFaceModel")
@patch("sagemaker.image_uris.retrieve")
@patch("sagemaker.djl_inference.model.urllib")
@patch("sagemaker.djl_inference.model.json")
@patch("sagemaker.huggingface.llm_utils.urllib")
@patch("sagemaker.huggingface.llm_utils.json")
@patch("sagemaker.model_uris.retrieve")
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
def test_build_happy_path_override_with_task_provided(
self,
mock_serveSettings,
mock_model_uris_retrieve,
mock_llm_utils_json,
mock_llm_utils_urllib,
mock_model_json,
mock_model_urllib,
mock_image_uris_retrieve,
mock_hf_model,
):
# Setup mocks

mock_setting_object = mock_serveSettings.return_value
mock_setting_object.role_arn = mock_role_arn
mock_setting_object.s3_model_data_url = mock_s3_model_data_url

# HF Pipeline Tag
mock_model_uris_retrieve.side_effect = KeyError
mock_llm_utils_json.load.return_value = {"pipeline_tag": "fill-mask"}
mock_llm_utils_urllib.request.Request.side_effect = Mock()

# HF Model config
mock_model_json.load.return_value = {"some": "config"}
mock_model_urllib.request.Request.side_effect = Mock()

mock_image_uris_retrieve.return_value = "https://some-image-uri"

model_builder = ModelBuilder(model="bert-base-uncased:text-generation")
model_builder.build(sagemaker_session=mock_session)

self.assertIsNotNone(model_builder.schema_builder)
sample_inputs, sample_outputs = task.retrieve_local_schemas("text-generation")
self.assertEqual(
sample_inputs["inputs"], model_builder.schema_builder.sample_input["inputs"]
)
self.assertEqual(sample_outputs, model_builder.schema_builder.sample_output)

@patch("sagemaker.image_uris.retrieve")
@patch("sagemaker.djl_inference.model.urllib")
@patch("sagemaker.djl_inference.model.json")
@patch("sagemaker.huggingface.llm_utils.urllib")
@patch("sagemaker.huggingface.llm_utils.json")
@patch("sagemaker.model_uris.retrieve")
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
def test_build_task_override_with_invalid_task_provided(
self,
mock_serveSettings,
mock_model_uris_retrieve,
mock_llm_utils_json,
mock_llm_utils_urllib,
mock_model_json,
mock_model_urllib,
mock_image_uris_retrieve,
):
# Setup mocks

mock_setting_object = mock_serveSettings.return_value
mock_setting_object.role_arn = mock_role_arn
mock_setting_object.s3_model_data_url = mock_s3_model_data_url

# HF Pipeline Tag
mock_model_uris_retrieve.side_effect = KeyError
mock_llm_utils_json.load.return_value = {"pipeline_tag": "fill-mask"}
mock_llm_utils_urllib.request.Request.side_effect = Mock()

# HF Model config
mock_model_json.load.return_value = {"some": "config"}
mock_model_urllib.request.Request.side_effect = Mock()

mock_image_uris_retrieve.return_value = "https://some-image-uri"
model_ids_with_invalid_task = ["bert-base-uncased:invalid-task", "bert-base-uncased:"]
for model_id in model_ids_with_invalid_task:
model_builder = ModelBuilder(model=model_id)

provided_task = model_id.split(":")[1]
self.assertRaisesRegex(
TaskNotFoundException,
f"Error Message: Schema builder for {provided_task} could not be found.",
lambda: model_builder.build(sagemaker_session=mock_session),
)

@patch("sagemaker.image_uris.retrieve")
@patch("sagemaker.model_uris.retrieve")
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
def test_build_task_override_with_invalid_model_provided(
self,
mock_serveSettings,
mock_model_uris_retrieve,
mock_image_uris_retrieve,
):
# Setup mocks

mock_setting_object = mock_serveSettings.return_value
mock_setting_object.role_arn = mock_role_arn
mock_setting_object.s3_model_data_url = mock_s3_model_data_url

# HF Pipeline Tag
mock_model_uris_retrieve.side_effect = KeyError

mock_image_uris_retrieve.return_value = "https://some-image-uri"
invalid_model_ids_with_task = [":fill-mask", "bert-base-uncased;fill-mask"]

for model_id in invalid_model_ids_with_task:
model_builder = ModelBuilder(model=model_id)
with self.assertRaises(Exception):
model_builder.build(sagemaker_session=mock_session)
2 changes: 1 addition & 1 deletion tests/unit/sagemaker/serve/utils/test_task.py
@@ -18,7 +18,7 @@

from sagemaker.serve.utils import task

EXPECTED_INPUTS = {"inputs": "Paris is the <mask> of France.", "parameters": {}}
EXPECTED_INPUTS = {"inputs": "Paris is the [MASK] of France.", "parameters": {}}
EXPECTED_OUTPUTS = [{"sequence": "Paris is the capital of France.", "score": 0.7}]
HF_INVALID_TASK = "not-present-task"
