Skip to content

Commit 32d44fb

Browse files
samrudsroot
authored and
root
committed
Fix: Add Image URI overrides for transformers models (aws#4693)
* Fix: Add Image URI overrides for transformers models * Increase coverage * Fix formatting
1 parent 9402a87 commit 32d44fb

File tree

3 files changed

+50
-6
lines changed

3 files changed

+50
-6
lines changed

src/sagemaker/serve/builder/transformers_builder.py

+23-5
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,25 @@ def _prepare_for_mode(self):
7878
"""Abstract method"""
7979

8080
def _create_transformers_model(self) -> Type[Model]:
81+
"""Initializes HF model with or without image_uri"""
82+
if self.image_uri is None:
83+
pysdk_model = self._get_hf_metadata_create_model()
84+
else:
85+
pysdk_model = HuggingFaceModel(
86+
image_uri=self.image_uri,
87+
vpc_config=self.vpc_config,
88+
env=self.env_vars,
89+
role=self.role_arn,
90+
sagemaker_session=self.sagemaker_session,
91+
)
92+
93+
logger.info("Detected %s. Proceeding with the the deployment.", self.image_uri)
94+
95+
self._original_deploy = pysdk_model.deploy
96+
pysdk_model.deploy = self._transformers_model_builder_deploy_wrapper
97+
return pysdk_model
98+
99+
def _get_hf_metadata_create_model(self) -> Type[Model]:
81100
"""Initializes the model after fetching image
82101
83102
1. Get the metadata for deciding framework
@@ -132,22 +151,21 @@ def _create_transformers_model(self) -> Type[Model]:
132151
vpc_config=self.vpc_config,
133152
)
134153

135-
if not self.image_uri and self.mode == Mode.LOCAL_CONTAINER:
154+
if self.mode == Mode.LOCAL_CONTAINER:
136155
self.image_uri = pysdk_model.serving_image_uri(
137156
self.sagemaker_session.boto_region_name, "local"
138157
)
139-
elif not self.image_uri:
158+
else:
140159
self.image_uri = pysdk_model.serving_image_uri(
141160
self.sagemaker_session.boto_region_name, self.instance_type
142161
)
143162

144-
logger.info("Detected %s. Proceeding with the the deployment.", self.image_uri)
163+
if pysdk_model is None or self.image_uri is None:
164+
raise ValueError("PySDK model unable to be created, try overriding image_uri")
145165

146166
if not pysdk_model.image_uri:
147167
pysdk_model.image_uri = self.image_uri
148168

149-
self._original_deploy = pysdk_model.deploy
150-
pysdk_model.deploy = self._transformers_model_builder_deploy_wrapper
151169
return pysdk_model
152170

153171
@_capture_telemetry("transformers.deploy")

tests/integ/sagemaker/serve/test_serve_transformers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -127,4 +127,4 @@ def test_pytorch_transformers_sagemaker_endpoint(
127127
logger.exception(caught_ex)
128128
assert (
129129
False
130-
), f"{caught_ex} was thrown when running pytorch transformers sagemaker endpoint test"
130+
), f"{caught_ex} thrown when running pytorch transformers sagemaker endpoint test"

tests/unit/sagemaker/serve/builder/test_transformers_builder.py

+26
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,29 @@ def test_image_uri_override(
144144

145145
with self.assertRaises(ValueError) as _:
146146
model.deploy(mode=Mode.IN_PROCESS)
147+
148+
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers")
149+
@patch(
150+
"sagemaker.serve.builder.transformers_builder._get_nb_instance",
151+
return_value="ml.g5.24xlarge",
152+
)
153+
@patch("sagemaker.serve.builder.transformers_builder._capture_telemetry", side_effect=None)
154+
@patch(
155+
"sagemaker.huggingface.llm_utils.get_huggingface_model_metadata",
156+
return_value=None,
157+
)
158+
def test_failure_hf_md(
159+
self, mock_model_md, mock_get_nb_instance, mock_telemetry, mock_build_for_transformers
160+
):
161+
builder = ModelBuilder(
162+
model=mock_model_id,
163+
schema_builder=mock_schema_builder,
164+
mode=Mode.LOCAL_CONTAINER,
165+
)
166+
167+
builder._prepare_for_mode = MagicMock()
168+
builder._prepare_for_mode.side_effect = None
169+
170+
builder.build()
171+
172+
mock_build_for_transformers.assert_called_once()

0 commit comments

Comments
 (0)