Skip to content

Commit 65cc586

Browse files
authored
fix: Image URI should take precedence for HF models (#4684)
* Fix: Image URI should take precedence for HF models
* Fix formatting
* Fix formatting
* Fix formatting
* Increase coverage - UT pass
1 parent d4f3c91 commit 65cc586

File tree

2 files changed

+51
-3
lines changed

2 files changed

+51
-3
lines changed

src/sagemaker/serve/builder/transformers_builder.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -132,17 +132,20 @@ def _create_transformers_model(self) -> Type[Model]:
132132
vpc_config=self.vpc_config,
133133
)
134134

135-
if self.mode == Mode.LOCAL_CONTAINER:
135+
if not self.image_uri and self.mode == Mode.LOCAL_CONTAINER:
136136
self.image_uri = pysdk_model.serving_image_uri(
137137
self.sagemaker_session.boto_region_name, "local"
138138
)
139-
else:
139+
elif not self.image_uri:
140140
self.image_uri = pysdk_model.serving_image_uri(
141141
self.sagemaker_session.boto_region_name, self.instance_type
142142
)
143143

144144
logger.info("Detected %s. Proceeding with the the deployment.", self.image_uri)
145145

146+
if not pysdk_model.image_uri:
147+
pysdk_model.image_uri = self.image_uri
148+
146149
self._original_deploy = pysdk_model.deploy
147150
pysdk_model.deploy = self._transformers_model_builder_deploy_wrapper
148151
return pysdk_model
@@ -251,13 +254,14 @@ def _set_instance(self, **kwargs):
251254
if self.mode == Mode.SAGEMAKER_ENDPOINT:
252255
if self.nb_instance_type and "instance_type" not in kwargs:
253256
kwargs.update({"instance_type": self.nb_instance_type})
257+
logger.info("Setting instance type to %s", self.nb_instance_type)
254258
elif self.instance_type and "instance_type" not in kwargs:
255259
kwargs.update({"instance_type": self.instance_type})
260+
logger.info("Setting instance type to %s", self.instance_type)
256261
else:
257262
raise ValueError(
258263
"Instance type must be provided when deploying to SageMaker Endpoint mode."
259264
)
260-
logger.info("Setting instance type to %s", self.instance_type)
261265

262266
def _get_supported_version(self, hf_config, hugging_face_version, base_fw):
263267
"""Uses the hugging face json config to pick supported versions"""

tests/unit/sagemaker/serve/builder/test_transformers_builder.py

+44
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@
5858
mock_schema_builder = MagicMock()
5959
mock_schema_builder.sample_input = mock_sample_input
6060
mock_schema_builder.sample_output = mock_sample_output
61+
# ECR URI of a Hugging Face PyTorch inference image, built from two adjacent
# string literals. test_image_uri passes it to ModelBuilder as the explicit
# ``image_uri`` and asserts the builder still holds this exact value afterwards.
MOCK_IMAGE_CONFIG = (
62+
"763104351884.dkr.ecr.us-west-2.amazonaws.com/"
63+
"huggingface-pytorch-inference:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04-v1.0"
64+
)
6165

6266

6367
class TestTransformersBuilder(unittest.TestCase):
@@ -100,3 +104,43 @@ def test_build_deploy_for_transformers_local_container_and_remote_container(
100104

101105
with self.assertRaises(ValueError) as _:
102106
model.deploy(mode=Mode.IN_PROCESS)
107+
108+
@patch(
    "sagemaker.serve.builder.transformers_builder._get_nb_instance",
    return_value="ml.g5.24xlarge",
)
@patch("sagemaker.serve.builder.transformers_builder._capture_telemetry", side_effect=None)
def test_image_uri(
    self,
    mock_telemetry,
    mock_get_nb_instance,
):
    """Check that a caller-supplied ``image_uri`` survives build and deploy.

    The builder is created in LOCAL_CONTAINER mode with an explicit
    ``image_uri`` (MOCK_IMAGE_CONFIG). After ``build()`` and a local
    ``deploy()`` the builder must still hold that exact URI. The test then
    redeploys in SAGEMAKER_ENDPOINT mode and finally checks that IN_PROCESS
    mode is rejected with ``ValueError``.

    Args:
        mock_telemetry: Mock replacing ``_capture_telemetry``.
        mock_get_nb_instance: Mock replacing ``_get_nb_instance``; it
            returns ``"ml.g5.24xlarge"``.
    """
    # NOTE: stacked @patch decorators inject mocks bottom-up, so the
    # decorator closest to ``def`` (_capture_telemetry) supplies the first
    # mock argument and the outermost (_get_nb_instance) supplies the last.
    builder = ModelBuilder(
        model=mock_model_id,
        schema_builder=mock_schema_builder,
        mode=Mode.LOCAL_CONTAINER,
        image_uri=MOCK_IMAGE_CONFIG,
    )

    # Replace mode preparation with a mock so it is not executed for real.
    builder._prepare_for_mode = MagicMock()
    builder._prepare_for_mode.side_effect = None

    model = builder.build()
    builder.serve_settings.telemetry_opt_out = True

    # Swap in a mock for the local-container mode handler before deploying.
    builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock()
    predictor = model.deploy(model_data_download_timeout=1800)

    # The explicit image URI must not have been overwritten.
    assert builder.image_uri == MOCK_IMAGE_CONFIG
    assert builder.env_vars["MODEL_LOADING_TIMEOUT"] == "1800"
    assert isinstance(predictor, TransformersLocalModePredictor)

    assert builder.nb_instance_type == "ml.g5.24xlarge"

    # Redeploy against a SageMaker endpoint with the underlying deploy mocked.
    builder._original_deploy = MagicMock()
    builder._prepare_for_mode.return_value = (None, {})
    predictor = model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn")
    assert "HF_MODEL_ID" in model.env

    # Deploying in IN_PROCESS mode is expected to raise ValueError.
    with self.assertRaises(ValueError) as _:
        model.deploy(mode=Mode.IN_PROCESS)

0 commit comments

Comments
 (0)