Skip to content

Commit 17ae5d9

Browse files
author
Jonathan Makunga
committed
Add Integ tests
1 parent 6ad7982 commit 17ae5d9

File tree

5 files changed

+84
-27
lines changed

5 files changed

+84
-27
lines changed

src/sagemaker/serve/builder/model_builder.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from sagemaker.predictor import Predictor
3939
from sagemaker.serve.save_retrive.version_1_0_0.metadata.metadata import Metadata
4040
from sagemaker.serve.spec.inference_spec import InferenceSpec
41+
from sagemaker.serve.utils.exceptions import TaskNotFoundException
4142
from sagemaker.serve.utils.predictors import _get_local_mode_predictor
4243
from sagemaker.serve.detector.image_detector import (
4344
auto_detect_container,
@@ -609,20 +610,17 @@ def build(
609610
if self._is_jumpstart_model_id():
610611
return self._build_for_jumpstart()
611612
if self._is_djl(): # pylint: disable=R1705
612-
if self.schema_builder is None:
613-
self._schema_builder_init("text-generation")
614-
615613
return self._build_for_djl()
616614
else:
617615
hf_model_md = get_huggingface_model_metadata(
618616
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
619617
)
620618

621-
hf_task = hf_model_md.get("pipeline_tag")
619+
model_task = hf_model_md.get("pipeline_tag")
622620
if self.schema_builder is None:
623-
self._schema_builder_init(hf_task)
621+
self._schema_builder_init(model_task)
624622

625-
if hf_task == "text-generation": # pylint: disable=R1705
623+
if model_task == "text-generation": # pylint: disable=R1705
626624
return self._build_for_tgi()
627625
else:
628626
return self._build_for_transformers()
@@ -682,17 +680,16 @@ def validate(self, model_dir: str) -> Type[bool]:
682680
return get_metadata(model_dir)
683681

684682
def _schema_builder_init(self, model_task: str):
685-
"""Initialize the"""
686-
sample_inputs, sample_outputs = None, None
683+
"""Initialize the schema builder
687684
685+
Args:
686+
model_task (str): Required, the task name
687+
688+
Raises:
689+
TaskNotFoundException: If the I/O schema for the given task is not found.
690+
"""
688691
try:
689692
sample_inputs, sample_outputs = task.retrieve_local_schemas(model_task)
690-
except ValueError:
691-
# TODO: try to retrieve schemas remotely
692-
pass
693-
694-
if sample_inputs and sample_outputs:
695693
self.schema_builder = SchemaBuilder(sample_inputs, sample_outputs)
696-
else:
697-
# TODO: Raise ClientError
698-
pass
694+
except ValueError:
695+
raise TaskNotFoundException(f"Schema builder for {model_task} could not be found.")

src/sagemaker/serve/utils/exceptions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,12 @@ class SkipTuningComboException(ModelBuilderException):
6060

6161
def __init__(self, message):
6262
super().__init__(message=message)
63+
64+
65+
class TaskNotFoundException(ModelBuilderException):
66+
"""Raise when task could not be found"""
67+
68+
fmt = "Error Message: {message}"
69+
70+
def __init__(self, message):
71+
super().__init__(message=message)

src/sagemaker/task.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import json
1717
import os
18-
from enum import Enum
1918
from typing import Any, Tuple
2019

2120

@@ -32,20 +31,19 @@ def retrieve_local_schemas(task: str) -> Tuple[Any, Any]:
3231
Raises:
3332
ValueError: If no tasks config found or the task does not exist in the local config.
3433
"""
35-
task_path = os.path.join(os.path.dirname(__file__), "image_uri_config", "tasks.json")
34+
task_io_config_path = os.path.join(os.path.dirname(__file__), "image_uri_config", "tasks.json")
3635
try:
37-
with open(task_path) as f:
38-
task_config = json.load(f)
39-
task_schema = task_config.get(task, None)
36+
with open(task_io_config_path) as f:
37+
task_io_config = json.load(f)
38+
task_io_schemas = task_io_config.get(task, None)
4039

41-
if task_schema is None:
42-
raise ValueError(f"Could not find {task} task schema.")
40+
if task_io_schemas is None:
41+
raise ValueError(f"Could not find {task} I/O schema.")
4342

4443
sample_schema = (
45-
task_schema["inputs"]["properties"],
46-
task_schema["outputs"]["properties"],
44+
task_io_schemas["inputs"]["properties"],
45+
task_io_schemas["outputs"]["properties"],
4746
)
4847
return sample_schema
49-
5048
except FileNotFoundError:
5149
raise ValueError("Could not find tasks config file.")
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from sagemaker import task
16+
from sagemaker.serve.builder.model_builder import ModelBuilder
17+
18+
import logging
19+
20+
logger = logging.getLogger(__name__)
21+
22+
23+
def test_model_builder_happy_path_with_only_model_id_fill_mask(sagemaker_session):
24+
model_builder = ModelBuilder(model="bert-base-uncased")
25+
26+
model = model_builder.build(sagemaker_session=sagemaker_session)
27+
28+
assert model is not None
29+
assert model_builder.schema_builder is not None
30+
31+
inputs, outputs = task.retrieve_local_schemas("fill-mask")
32+
assert model_builder.schema_builder.sample_input == inputs
33+
assert model_builder.schema_builder.sample_output == outputs
34+
35+
36+
def test_model_builder_happy_path_with_only_model_id_question_answering(sagemaker_session):
37+
model_builder = ModelBuilder(model="bert-large-uncased-whole-word-masking-finetuned-squad")
38+
39+
model = model_builder.build(sagemaker_session=sagemaker_session)
40+
41+
assert model is not None
42+
assert model_builder.schema_builder is not None
43+
44+
inputs, outputs = task.retrieve_local_schemas("question-answering")
45+
assert model_builder.schema_builder.sample_input == inputs
46+
assert model_builder.schema_builder.sample_output == outputs

tests/unit/sagemaker/test_task.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@ def test_retrieve_local_schemas_success():
2626
assert outputs == EXPECTED_OUTPUTS
2727

2828

29+
def test_retrieve_local_schemas_text_generation_success():
30+
inputs, outputs = task.retrieve_local_schemas("text-generation")
31+
32+
assert inputs is not None
33+
assert outputs is not None
34+
35+
2936
def test_retrieve_local_schemas_throws():
3037
with pytest.raises(ValueError):
31-
task.retrieve_local_schemas("invalid-task")
38+
task.retrieve_local_schemas("not-present-task")

0 commit comments

Comments
 (0)