Commit fd3e86b

feat: Add Optional task to Model
1 parent 9091b21 commit fd3e86b

4 files changed: +117 -1 lines changed

src/sagemaker/model.py  (+5 -1)

@@ -156,6 +156,7 @@ def __init__(
         dependencies: Optional[List[str]] = None,
         git_config: Optional[Dict[str, str]] = None,
         resources: Optional[ResourceRequirements] = None,
+        task: Optional[Union[str, PipelineVariable]] = None,
     ):
         """Initialize an SageMaker ``Model``.

@@ -319,7 +320,9 @@ def __init__(
                 for a model to be deployed to an endpoint. Only
                 EndpointType.INFERENCE_COMPONENT_BASED supports this feature.
                 (Default: None).
-
+            task (str or PipelineVariable): Task values used to override the HuggingFace task
+                Examples are: "audio-classification", "depth-estimation",
+                "feature-extraction" etc. (default: None).
         """
         self.model_data = model_data
         self.image_uri = image_uri
@@ -396,6 +399,7 @@ def __init__(
         self.content_types = None
         self.response_types = None
         self.accept_eula = None
+        self.task = task

     @runnable_by_pipeline
     def register(
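In short, the new ``task`` argument rides along on the ``Model`` object and is stored as ``model.task``. A minimal usage sketch, assuming a HuggingFace inference container; the image URI and role below are placeholders, while the S3 path, model name, and "fill-mask" task mirror the integration-test fixture added in this commit:

from sagemaker.model import Model

# Placeholder image URI and role; model_data, name, and task mirror the
# integration-test fixture in this commit.
model = Model(
    image_uri="<huggingface-inference-image-uri>",
    model_data="s3://bucket/model.tar.gz",
    role="<execution-role-arn>",
    name="bert-base-uncased",
    task="fill-mask",  # new optional argument; stored as model.task
)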

src/sagemaker/serve/builder/model_builder.py  (+1)

@@ -203,6 +203,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
             "help": (
                 'Model object with "predict" method to perform inference '
                 "or HuggingFace/JumpStart Model ID"
+                "or HuggingFace Task to override"
             )
         },
     )
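The updated help text says the wrapped ``Model`` may carry a HuggingFace task to override. A condensed sketch of that path, based on the integration-test fixture later in this commit (the local model_path is a placeholder; the test uses its own HF_DIR constant):

from sagemaker.model import Model
from sagemaker.serve.builder.model_builder import ModelBuilder

# Constants mirror the integration-test fixture; model_path is a placeholder.
model = Model(
    "mi", "s3://bucket/model.tar.gz", task="fill-mask", name="bert-base-uncased", role="some-role"
)
model_builder = ModelBuilder(
    model_path="/path/to/local/model",
    model=model,
)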

tests/integ/sagemaker/serve/test_serve_transformers.py  (+56)

@@ -13,6 +13,7 @@
 from __future__ import absolute_import

 import pytest
+from sagemaker.model import Model
 from sagemaker.serve.builder.schema_builder import SchemaBuilder
 from sagemaker.serve.builder.model_builder import ModelBuilder, Mode

@@ -28,6 +29,11 @@

 logger = logging.getLogger(__name__)

+MODEL_DATA = "s3://bucket/model.tar.gz"
+MODEL_IMAGE = "mi"
+ROLE = "some-role"
+HF_TASK = "fill-mask"
+
 sample_input = {
     "inputs": "The man worked as a [MASK].",
 }
@@ -80,6 +86,17 @@ def model_builder_model_schema_builder():
     )


+@pytest.fixture
+def model_builder_model_with_task_builder():
+    model = Model(
+        MODEL_IMAGE, MODEL_DATA, task=HF_TASK, name="bert-base-uncased", role=ROLE
+    )
+    return ModelBuilder(
+        model_path=HF_DIR,
+        model=model,
+    )
+
+
 @pytest.fixture
 def model_builder(request):
     return request.getfixturevalue(request.param)
@@ -122,3 +139,42 @@ def test_pytorch_transformers_sagemaker_endpoint(
     assert (
         False
     ), f"{caught_ex} was thrown when running pytorch transformers sagemaker endpoint test"
+
+
+@pytest.mark.skipif(
+    PYTHON_VERSION_IS_NOT_310,
+    reason="Testing Optional task",
+)
+@pytest.mark.parametrize("model_builder", ["model_builder_model_with_task_builder"], indirect=True)
+def test_happy_path_with_task_sagemaker_endpoint(
+    sagemaker_session, model_builder, gpu_instance_type, input
+):
+    logger.info("Running in SAGEMAKER_ENDPOINT mode...")
+    caught_ex = None
+
+    iam_client = sagemaker_session.boto_session.client("iam")
+    role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]
+
+    model = model_builder.build(
+        mode=Mode.SAGEMAKER_ENDPOINT, role_arn=role_arn, sagemaker_session=sagemaker_session
+    )
+
+    with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
+        try:
+            logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
+            predictor = model.deploy(instance_type=gpu_instance_type, initial_instance_count=1)
+            logger.info("Endpoint successfully deployed.")
+            predictor.predict(input)
+        except Exception as e:
+            caught_ex = e
+        finally:
+            cleanup_model_resources(
+                sagemaker_session=model_builder.sagemaker_session,
+                model_name=model.name,
+                endpoint_name=model.endpoint_name,
+            )
+            if caught_ex:
+                logger.exception(caught_ex)
+                assert (
+                    False
+                ), f"{caught_ex} was thrown when running pytorch transformers sagemaker endpoint test"

tests/unit/sagemaker/model/test_deploy.py  (+55)

@@ -85,6 +85,7 @@
     },
     limits={},
 )
+HF_TASK = "audio-classification"


 @pytest.fixture
@@ -1027,3 +1028,57 @@ def test_deploy_with_name_and_resources(sagemaker_session):
         async_inference_config_dict=None,
         live_logging=False,
     )
+
+
+@patch("sagemaker.model.Model._create_sagemaker_model", Mock())
+@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
+def test_deploy_with_name_and_task(sagemaker_session):
+    sagemaker_session.sagemaker_config = {}
+
+    model = Model(
+        MODEL_IMAGE, MODEL_DATA, task=HF_TASK, name=MODEL_NAME, role=ROLE, sagemaker_session=sagemaker_session
+    )
+
+    endpoint_name = "testing-task-input"
+    predictor = model.deploy(
+        endpoint_name=endpoint_name,
+        instance_type=INSTANCE_TYPE,
+        initial_instance_count=INSTANCE_COUNT,
+    )
+
+    sagemaker_session.create_model.assert_called_with(
+        name=MODEL_IMAGE,
+        role=ROLE,
+        task=HF_TASK
+    )
+
+    assert isinstance(predictor, sagemaker.predictor.Predictor)
+    assert predictor.endpoint_name == endpoint_name
+    assert predictor.sagemaker_session == sagemaker_session
+
+
+@patch("sagemaker.model.Model._create_sagemaker_model", Mock())
+@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
+def test_deploy_with_name_and_without_task(sagemaker_session):
+    sagemaker_session.sagemaker_config = {}
+
+    model = Model(
+        MODEL_IMAGE, MODEL_DATA, name=MODEL_NAME, role=ROLE, sagemaker_session=sagemaker_session
+    )
+
+    endpoint_name = "testing-without-task-input"
+    predictor = model.deploy(
+        endpoint_name=endpoint_name,
+        instance_type=INSTANCE_TYPE,
+        initial_instance_count=INSTANCE_COUNT,
+    )
+
+    sagemaker_session.create_model.assert_called_with(
+        name=MODEL_IMAGE,
+        role=ROLE,
+        task=None,
+    )
+
+    assert isinstance(predictor, sagemaker.predictor.Predictor)
+    assert predictor.endpoint_name == endpoint_name
+    assert predictor.sagemaker_session == sagemaker_session