Commit 9e4e2ec

Author: Xiong Zeng (authored and committed)

Revert "feat: Add Optional task to Model"

This reverts commit fd3e86b.

1 parent: 8b8c081 · commit: 9e4e2ec

4 files changed (+1, -117 lines)


src/sagemaker/model.py

Lines changed: 1 addition & 5 deletions

@@ -156,7 +156,6 @@ def __init__(
         dependencies: Optional[List[str]] = None,
         git_config: Optional[Dict[str, str]] = None,
         resources: Optional[ResourceRequirements] = None,
-        task: Optional[Union[str, PipelineVariable]] = None,
     ):
         """Initialize an SageMaker ``Model``.
 
@@ -320,9 +319,7 @@ def __init__(
                 for a model to be deployed to an endpoint. Only
                 EndpointType.INFERENCE_COMPONENT_BASED supports this feature.
                 (Default: None).
-            task (str or PipelineVariable): Task values used to override the HuggingFace task
-                Examples are: "audio-classification", "depth-estimation",
-                "feature-extraction" etc. (default: None).
+
         """
         self.model_data = model_data
         self.image_uri = image_uri
@@ -399,7 +396,6 @@ def __init__(
        self.content_types = None
        self.response_types = None
        self.accept_eula = None
-        self.task = task
 
    @runnable_by_pipeline
    def register(
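
For context, the reverted change had added an optional ``task`` keyword to ``Model.__init__`` for overriding the HuggingFace task (for example "audio-classification" or "fill-mask"). The following is a minimal sketch of how that argument was exercised, reusing the placeholder image, model data, and role values from the deleted tests; after this revert the keyword no longer exists:

from sagemaker.model import Model

# Placeholder values taken from the removed test fixtures; "task" is the
# optional HuggingFace task override that this commit removes again.
model = Model(
    image_uri="mi",
    model_data="s3://bucket/model.tar.gz",
    role="some-role",
    name="bert-base-uncased",
    task="fill-mask",
)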

src/sagemaker/serve/builder/model_builder.py

Lines changed: 0 additions & 1 deletion

@@ -205,7 +205,6 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
            "help": (
                'Model object with "predict" method to perform inference '
                "or HuggingFace/JumpStart Model ID"
-                "or HuggingFace Task to override"
            )
        },
    )
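
With the revert applied, ``ModelBuilder``'s ``model`` field once again accepts only a model object with a ``predict`` method or a HuggingFace/JumpStart model ID; the task-override wording is gone. A minimal sketch of the remaining object-based usage, mirroring the deleted integration fixture minus the ``task`` argument (``model_path`` below is a placeholder for a local model directory, standing in for the fixture's ``HF_DIR``):

from sagemaker.model import Model
from sagemaker.serve.builder.model_builder import ModelBuilder

# Placeholder values mirroring the removed integration fixture; no "task"
# keyword is passed because the revert removes it from Model.__init__.
model = Model("mi", "s3://bucket/model.tar.gz", name="bert-base-uncased", role="some-role")

builder = ModelBuilder(
    model_path="/tmp/bert-base-uncased",  # placeholder for the fixture's HF_DIR
    model=model,
)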

tests/integ/sagemaker/serve/test_serve_transformers.py

Lines changed: 0 additions & 56 deletions

@@ -13,7 +13,6 @@
 from __future__ import absolute_import
 
 import pytest
-from sagemaker.model import Model
 from sagemaker.serve.builder.schema_builder import SchemaBuilder
 from sagemaker.serve.builder.model_builder import ModelBuilder, Mode
 
@@ -29,11 +28,6 @@
 
 logger = logging.getLogger(__name__)
 
-MODEL_DATA = "s3://bucket/model.tar.gz"
-MODEL_IMAGE = "mi"
-ROLE = "some-role"
-HF_TASK = "fill-mask"
-
 sample_input = {
     "inputs": "The man worked as a [MASK].",
 }
@@ -86,17 +80,6 @@ def model_builder_model_schema_builder():
     )
 
 
-@pytest.fixture
-def model_builder_model_with_task_builder():
-    model = Model(
-        MODEL_IMAGE, MODEL_DATA, task=HF_TASK, name="bert-base-uncased", role=ROLE
-    )
-    return ModelBuilder(
-        model_path=HF_DIR,
-        model=model,
-    )
-
-
 @pytest.fixture
 def model_builder(request):
     return request.getfixturevalue(request.param)
@@ -139,42 +122,3 @@ def test_pytorch_transformers_sagemaker_endpoint(
     assert (
         False
     ), f"{caught_ex} was thrown when running pytorch transformers sagemaker endpoint test"
-
-
-@pytest.mark.skipif(
-    PYTHON_VERSION_IS_NOT_310,
-    reason="Testing Optional task",
-)
-@pytest.mark.parametrize("model_builder", ["model_builder_model_with_task_builder"], indirect=True)
-def test_happy_path_with_task_sagemaker_endpoint(
-    sagemaker_session, model_builder, gpu_instance_type, input
-):
-    logger.info("Running in SAGEMAKER_ENDPOINT mode...")
-    caught_ex = None
-
-    iam_client = sagemaker_session.boto_session.client("iam")
-    role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]
-
-    model = model_builder.build(
-        mode=Mode.SAGEMAKER_ENDPOINT, role_arn=role_arn, sagemaker_session=sagemaker_session
-    )
-
-    with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
-        try:
-            logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
-            predictor = model.deploy(instance_type=gpu_instance_type, initial_instance_count=1)
-            logger.info("Endpoint successfully deployed.")
-            predictor.predict(input)
-        except Exception as e:
-            caught_ex = e
-        finally:
-            cleanup_model_resources(
-                sagemaker_session=model_builder.sagemaker_session,
-                model_name=model.name,
-                endpoint_name=model.endpoint_name,
-            )
-    if caught_ex:
-        logger.exception(caught_ex)
-        assert (
-            False
-        ), f"{caught_ex} was thrown when running pytorch transformers sagemaker endpoint test"

tests/unit/sagemaker/model/test_deploy.py

Lines changed: 0 additions & 55 deletions

@@ -85,7 +85,6 @@
     },
     limits={},
 )
-HF_TASK = "audio-classification"
 
 
 @pytest.fixture
@@ -1028,57 +1027,3 @@ def test_deploy_with_name_and_resources(sagemaker_session):
        async_inference_config_dict=None,
        live_logging=False,
    )
-
-
-@patch("sagemaker.model.Model._create_sagemaker_model", Mock())
-@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
-def test_deploy_with_name_and_task(sagemaker_session):
-    sagemaker_session.sagemaker_config = {}
-
-    model = Model(
-        MODEL_IMAGE, MODEL_DATA, task=HF_TASK, name=MODEL_NAME, role=ROLE, sagemaker_session=sagemaker_session
-    )
-
-    endpoint_name = "testing-task-input"
-    predictor = model.deploy(
-        endpoint_name=endpoint_name,
-        instance_type=INSTANCE_TYPE,
-        initial_instance_count=INSTANCE_COUNT,
-    )
-
-    sagemaker_session.create_model.assert_called_with(
-        name=MODEL_IMAGE,
-        role=ROLE,
-        task=HF_TASK
-    )
-
-    assert isinstance(predictor, sagemaker.predictor.Predictor)
-    assert predictor.endpoint_name == endpoint_name
-    assert predictor.sagemaker_session == sagemaker_session
-
-
-@patch("sagemaker.model.Model._create_sagemaker_model", Mock())
-@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
-def test_deploy_with_name_and_without_task(sagemaker_session):
-    sagemaker_session.sagemaker_config = {}
-
-    model = Model(
-        MODEL_IMAGE, MODEL_DATA, name=MODEL_NAME, role=ROLE, sagemaker_session=sagemaker_session
-    )
-
-    endpoint_name = "testing-without-task-input"
-    predictor = model.deploy(
-        endpoint_name=endpoint_name,
-        instance_type=INSTANCE_TYPE,
-        initial_instance_count=INSTANCE_COUNT,
-    )
-
-    sagemaker_session.create_model.assert_called_with(
-        name=MODEL_IMAGE,
-        role=ROLE,
-        task=None,
-    )
-
-    assert isinstance(predictor, sagemaker.predictor.Predictor)
-    assert predictor.endpoint_name == endpoint_name
-    assert predictor.sagemaker_session == sagemaker_session
