Skip to content

Commit 98dd0a9

Browse files
author
Jonathan Makunga
committed
address PR comments
1 parent 17ae5d9 commit 98dd0a9

File tree

6 files changed

+88
-32
lines changed

6 files changed

+88
-32
lines changed

src/sagemaker/image_uri_config/tasks.json

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
{
2-
"description": "Sample Task Inputs and Outputs",
32
"fill-mask": {
4-
"ref": "https://huggingface.co/tasks/fill-mask",
5-
"inputs": {
6-
"ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/fill-mask/spec/input.json",
3+
"sample_inputs": {
74
"properties": {
85
"inputs": "Paris is the <mask> of France.",
96
"parameters": {}
107
}
118
},
12-
"outputs": {
13-
"ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/fill-mask/spec/output.json",
9+
"sample_outputs": {
1410
"properties": [
1511
{
1612
"sequence": "Paris is the capital of France.",
@@ -20,16 +16,13 @@
2016
}
2117
},
2218
"question-answering": {
23-
"ref": "https://huggingface.co/tasks/question-answering",
24-
"inputs": {
25-
"ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/question-answering/spec/input.json",
19+
"sample_inputs": {
2620
"properties": {
2721
"context": "I have a German Shepherd dog, named Coco.",
2822
"question": "What is my dog's breed?"
2923
}
3024
},
31-
"outputs": {
32-
"ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/question-answering/spec/output.json",
25+
"sample_outputs": {
3326
"properties": [
3427
{
3528
"answer": "German Shepherd",
@@ -41,16 +34,13 @@
4134
}
4235
},
4336
"text-classification": {
44-
"ref": "https://huggingface.co/tasks/text-classification",
45-
"inputs": {
46-
"ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/text-classification/spec/input.json",
37+
"sample_inputs": {
4738
"properties": {
4839
"inputs": "Where is the capital of France?, Paris is the capital of France.",
4940
"parameters": {}
5041
}
5142
},
52-
"outputs": {
53-
"ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/text-classification/spec/output.json",
43+
"sample_outputs": {
5444
"properties": [
5545
{
5646
"label": "entailment",
@@ -60,16 +50,13 @@
6050
}
6151
},
6252
"text-generation": {
63-
"ref": "https://huggingface.co/tasks/text-generation",
64-
"inputs": {
65-
"ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/text-generation/spec/input.json",
53+
"sample_inputs": {
6654
"properties": {
6755
"inputs": "Hello, I'm a language model",
6856
"parameters": {}
6957
}
7058
},
71-
"outputs": {
72-
"ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/text-generation/spec/output.json",
59+
"sample_outputs": {
7360
"properties": [
7461
{
7562
"generated_text": "Hello, I'm a language modeler. So while writing this, when I went out to meet my wife or come home she told me that my"

src/sagemaker/serve/builder/model_builder.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from pathlib import Path
2222

23-
from sagemaker import Session, task
23+
from sagemaker import Session
2424
from sagemaker.model import Model
2525
from sagemaker.base_predictor import PredictorBase
2626
from sagemaker.serializers import NumpySerializer, TorchTensorSerializer
@@ -38,6 +38,7 @@
3838
from sagemaker.predictor import Predictor
3939
from sagemaker.serve.save_retrive.version_1_0_0.metadata.metadata import Metadata
4040
from sagemaker.serve.spec.inference_spec import InferenceSpec
41+
from sagemaker.serve.utils import task
4142
from sagemaker.serve.utils.exceptions import TaskNotFoundException
4243
from sagemaker.serve.utils.predictors import _get_local_mode_predictor
4344
from sagemaker.serve.detector.image_detector import (
@@ -617,7 +618,7 @@ def build(
617618
)
618619

619620
model_task = hf_model_md.get("pipeline_tag")
620-
if self.schema_builder is None:
621+
if self.schema_builder is None and model_task:
621622
self._schema_builder_init(model_task)
622623

623624
if model_task == "text-generation": # pylint: disable=R1705

src/sagemaker/serve/utils/exceptions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def __init__(self, message):
6363

6464

6565
class TaskNotFoundException(ModelBuilderException):
66-
"""Raise when task could not be found"""
66+
"""Raise when HuggingFace task could not be found"""
6767

6868
fmt = "Error Message: {message}"
6969

src/sagemaker/task.py renamed to src/sagemaker/serve/utils/task.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ def retrieve_local_schemas(task: str) -> Tuple[Any, Any]:
3131
Raises:
3232
ValueError: If no tasks config found or the task does not exist in the local config.
3333
"""
34-
task_io_config_path = os.path.join(os.path.dirname(__file__), "image_uri_config", "tasks.json")
34+
# task_io_config_path = os.path.join(os.path.dirname(__file__), "image_uri_config", "tasks.json")
35+
c = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
36+
task_io_config_path = os.path.join(c, "image_uri_config", "tasks.json")
3537
try:
3638
with open(task_io_config_path) as f:
3739
task_io_config = json.load(f)
@@ -41,8 +43,8 @@ def retrieve_local_schemas(task: str) -> Tuple[Any, Any]:
4143
raise ValueError(f"Could not find {task} I/O schema.")
4244

4345
sample_schema = (
44-
task_io_schemas["inputs"]["properties"],
45-
task_io_schemas["outputs"]["properties"],
46+
task_io_schemas["sample_inputs"]["properties"],
47+
task_io_schemas["sample_outputs"]["properties"],
4648
)
4749
return sample_schema
4850
except FileNotFoundError:

tests/integ/sagemaker/serve/test_schema_builder.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,19 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15-
from sagemaker import task
1615
from sagemaker.serve.builder.model_builder import ModelBuilder
16+
from sagemaker.serve.utils import task
17+
18+
import pytest
19+
20+
from sagemaker.serve.utils.exceptions import TaskNotFoundException
21+
from tests.integ.sagemaker.serve.constants import (
22+
PYTHON_VERSION_IS_NOT_310,
23+
SERVE_SAGEMAKER_ENDPOINT_TIMEOUT,
24+
)
25+
26+
from tests.integ.timeout import timeout
27+
from tests.integ.utils import cleanup_model_resources
1728

1829
import logging
1930

@@ -33,7 +44,13 @@ def test_model_builder_happy_path_with_only_model_id_fill_mask(sagemaker_session
3344
assert model_builder.schema_builder.sample_output == outputs
3445

3546

36-
def test_model_builder_happy_path_with_only_model_id_question_answering(sagemaker_session):
47+
@pytest.mark.skipif(
48+
PYTHON_VERSION_IS_NOT_310,
49+
reason="Testing Schema Builder Simplification feature",
50+
)
51+
def test_model_builder_happy_path_with_only_model_id_question_answering(
52+
sagemaker_session, gpu_instance_type
53+
):
3754
model_builder = ModelBuilder(model="bert-large-uncased-whole-word-masking-finetuned-squad")
3855

3956
model = model_builder.build(sagemaker_session=sagemaker_session)
@@ -44,3 +61,41 @@ def test_model_builder_happy_path_with_only_model_id_question_answering(sagemake
4461
inputs, outputs = task.retrieve_local_schemas("question-answering")
4562
assert model_builder.schema_builder.sample_input == inputs
4663
assert model_builder.schema_builder.sample_output == outputs
64+
65+
with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
66+
caught_ex = None
67+
try:
68+
iam_client = sagemaker_session.boto_session.client("iam")
69+
role_arn = iam_client.get_role(RoleName="JarvisTest")["Role"]["Arn"]
70+
71+
logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
72+
predictor = model.deploy(
73+
role=role_arn, instance_count=1, instance_type=gpu_instance_type
74+
)
75+
76+
predicted_outputs = predictor.predict(inputs)
77+
assert predicted_outputs is not None
78+
79+
except Exception as e:
80+
caught_ex = e
81+
finally:
82+
cleanup_model_resources(
83+
sagemaker_session=model_builder.sagemaker_session,
84+
model_name=model.name,
85+
endpoint_name=model.endpoint_name,
86+
)
87+
if caught_ex:
88+
logger.exception(caught_ex)
89+
assert (
90+
False
91+
), f"{caught_ex} was thrown when running transformers sagemaker endpoint test"
92+
93+
94+
def test_model_builder_negative_path(sagemaker_session):
95+
model_builder = ModelBuilder(model="CompVis/stable-diffusion-v1-4")
96+
97+
with pytest.raises(
98+
TaskNotFoundException,
99+
match="Error Message: Schema builder for text-to-image could not be found.",
100+
):
101+
model_builder.build(sagemaker_session=sagemaker_session)

tests/unit/sagemaker/test_task.py renamed to tests/unit/sagemaker/serve/utils/test_task.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,15 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
from unittest.mock import patch
16+
1517
import pytest
16-
from sagemaker import task
18+
19+
from sagemaker.serve.utils import task
1720

1821
EXPECTED_INPUTS = {"inputs": "Paris is the <mask> of France.", "parameters": {}}
1922
EXPECTED_OUTPUTS = [{"sequence": "Paris is the capital of France.", "score": 0.7}]
23+
HF_INVALID_TASK = "not-present-task"
2024

2125

2226
def test_retrieve_local_schemas_success():
@@ -34,5 +38,12 @@ def test_retrieve_local_schemas_text_generation_success():
3438

3539

3640
def test_retrieve_local_schemas_throws():
37-
with pytest.raises(ValueError):
38-
task.retrieve_local_schemas("not-present-task")
41+
with pytest.raises(ValueError, match=f"Could not find {HF_INVALID_TASK} I/O schema."):
42+
task.retrieve_local_schemas(HF_INVALID_TASK)
43+
44+
45+
@patch("builtins.open")
46+
def test_retrieve_local_schemas_file_not_found(mock_open):
47+
mock_open.side_effect = FileNotFoundError
48+
with pytest.raises(ValueError, match="Could not find tasks config file."):
49+
task.retrieve_local_schemas(HF_INVALID_TASK)

0 commit comments

Comments
 (0)