Skip to content

Commit af7fb97

Browse files
bencrabtreemalav-shastrimalavhs
authored
fix: altconfig hubcontent and reenable integ test (#5051)
* fix altconfig hubcontent and reenable integ test * linting * update exception thrown * feat: Add support for TGI Neuronx 0.0.27 and HF PT 2.3.0 image in PySDK (#5050) Co-authored-by: malavhs <[email protected]> * add test * update predictor spec accessor * lint * set custom field from HCD config to model spec data class * lint * remove logs * last update --------- Co-authored-by: Malav Shastri <[email protected]> Co-authored-by: malavhs <[email protected]>
1 parent efd6983 commit af7fb97

File tree

5 files changed

+93
-24
lines changed

5 files changed

+93
-24
lines changed

src/sagemaker/jumpstart/accessors.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ def get_model_specs(
288288
)
289289
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
290290

291+
# Users only input model id, not contentType, so first try to describe with ModelReference, then with Model
291292
if hub_arn:
292293
try:
293294
hub_model_arn = construct_hub_model_reference_arn_from_inputs(
@@ -308,11 +309,22 @@ def get_model_specs(
308309
hub_model_arn = construct_hub_model_arn_from_inputs(
309310
hub_arn=hub_arn, model_name=model_id, version=version
310311
)
311-
model_specs = JumpStartModelsAccessor._cache.get_hub_model(
312-
hub_model_arn=hub_model_arn
313-
)
314-
model_specs.set_hub_content_type(HubContentType.MODEL)
315-
return model_specs
312+
313+
# Failed to describe ModelReference, try with Model
314+
try:
315+
model_specs = JumpStartModelsAccessor._cache.get_hub_model(
316+
hub_model_arn=hub_model_arn
317+
)
318+
model_specs.set_hub_content_type(HubContentType.MODEL)
319+
320+
return model_specs
321+
except Exception as ex:
322+
# Failed with both, throw a custom error message
323+
raise RuntimeError(
324+
f"Cannot get details for {model_id} in Hub {hub_arn}. \
325+
{model_id} does not exist as a Model or ModelReference: \n"
326+
+ str(ex)
327+
)
316328

317329
return JumpStartModelsAccessor._cache.get_specs( # type: ignore
318330
model_id=model_id, version_str=version, model_type=model_type

src/sagemaker/jumpstart/hub/hub.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -272,18 +272,21 @@ def delete_model_reference(self, model_name: str) -> None:
272272
def describe_model(
273273
self, model_name: str, hub_name: Optional[str] = None, model_version: Optional[str] = None
274274
) -> DescribeHubContentResponse:
275-
"""Describe model in the SageMaker Hub."""
275+
"""Describe Model or ModelReference in a Hub."""
276+
hub_name = hub_name or self.hub_name
277+
278+
# Users only input model id, not contentType, so first try to describe with ModelReference, then with Model
276279
try:
277280
model_version = get_hub_model_version(
278281
hub_model_name=model_name,
279282
hub_model_type=HubContentType.MODEL_REFERENCE.value,
280-
hub_name=self.hub_name if not hub_name else hub_name,
283+
hub_name=hub_name,
281284
sagemaker_session=self._sagemaker_session,
282285
hub_model_version=model_version,
283286
)
284287

285288
hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
286-
hub_name=self.hub_name if not hub_name else hub_name,
289+
hub_name=hub_name,
287290
hub_content_name=model_name,
288291
hub_content_version=model_version,
289292
hub_content_type=HubContentType.MODEL_REFERENCE.value,
@@ -294,19 +297,32 @@ def describe_model(
294297
"Received exeption while calling APIs for ContentType ModelReference, retrying with ContentType Model: "
295298
+ str(ex)
296299
)
297-
model_version = get_hub_model_version(
298-
hub_model_name=model_name,
299-
hub_model_type=HubContentType.MODEL.value,
300-
hub_name=self.hub_name if not hub_name else hub_name,
301-
sagemaker_session=self._sagemaker_session,
302-
hub_model_version=model_version,
303-
)
304300

305-
hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
306-
hub_name=self.hub_name if not hub_name else hub_name,
307-
hub_content_name=model_name,
308-
hub_content_version=model_version,
309-
hub_content_type=HubContentType.MODEL.value,
310-
)
301+
# Failed to describe ModelReference, try with Model
302+
try:
303+
model_version = get_hub_model_version(
304+
hub_model_name=model_name,
305+
hub_model_type=HubContentType.MODEL.value,
306+
hub_name=hub_name,
307+
sagemaker_session=self._sagemaker_session,
308+
hub_model_version=model_version,
309+
)
310+
311+
hub_content_description: Dict[str, Any] = (
312+
self._sagemaker_session.describe_hub_content(
313+
hub_name=hub_name,
314+
hub_content_name=model_name,
315+
hub_content_version=model_version,
316+
hub_content_type=HubContentType.MODEL.value,
317+
)
318+
)
319+
320+
except Exception as ex:
321+
# Failed with both, throw a custom error message
322+
raise RuntimeError(
323+
f"Cannot get details for {model_name} in Hub {hub_name}. \
324+
{model_name} does not exist as a Model or ModelReference in {hub_name}: \n"
325+
+ str(ex)
326+
)
311327

312328
return DescribeHubContentResponse(hub_content_description)

src/sagemaker/jumpstart/types.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1363,9 +1363,10 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
13631363
self.deploy_kwargs = deepcopy(json_obj.get("deploy_kwargs", {}))
13641364
self.predictor_specs: Optional[JumpStartPredictorSpecs] = (
13651365
JumpStartPredictorSpecs(
1366-
json_obj["predictor_specs"], is_hub_content=self._is_hub_content
1366+
json_obj.get("predictor_specs"),
1367+
is_hub_content=self._is_hub_content,
13671368
)
1368-
if "predictor_specs" in json_obj
1369+
if json_obj.get("predictor_specs")
13691370
else None
13701371
)
13711372
self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = (
@@ -1501,6 +1502,9 @@ class JumpStartConfigComponent(JumpStartMetadataBaseFields):
15011502
"incremental_training_supported",
15021503
]
15031504

1505+
# Map of HubContent fields that map to custom names in MetadataBaseFields
1506+
CUSTOM_FIELD_MAP = {"sage_maker_sdk_predictor_specifications": "predictor_specs"}
1507+
15041508
__slots__ = slots + JumpStartMetadataBaseFields.__slots__
15051509

15061510
def __init__(
@@ -1532,6 +1536,11 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
15321536
if field in self.__slots__:
15331537
setattr(self, field, json_obj[field])
15341538

1539+
# Handle custom fields
1540+
for custom_field, field in self.CUSTOM_FIELD_MAP.items():
1541+
if custom_field in json_obj:
1542+
setattr(self, field, json_obj.get(custom_field))
1543+
15351544

15361545
class JumpStartMetadataConfig(JumpStartDataHolderType):
15371546
"""Data class of JumpStart metadata config."""

tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,6 @@ def test_jumpstart_hub_gated_model(setup, add_model_references):
122122
assert response is not None
123123

124124

125-
@pytest.mark.skip(reason="blocking PR checks and release pipeline.")
126125
def test_jumpstart_gated_model_inference_component_enabled(setup, add_model_references):
127126

128127
model_id = "meta-textgeneration-llama-2-7b"

tests/unit/sagemaker/jumpstart/hub/test_hub.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,39 @@ def test_describe_model_success(mock_describe_hub_content_response, sagemaker_se
192192
)
193193

194194

195+
@patch("sagemaker.jumpstart.hub.interfaces.DescribeHubContentResponse.from_json")
196+
def test_describe_model_one_thrown_error(mock_describe_hub_content_response, sagemaker_session):
197+
mock_describe_hub_content_response.return_value = Mock()
198+
mock_list_hub_content_versions = sagemaker_session.list_hub_content_versions
199+
mock_list_hub_content_versions.return_value = {
200+
"HubContentSummaries": [
201+
{"HubContentVersion": "1.0"},
202+
{"HubContentVersion": "2.0"},
203+
{"HubContentVersion": "3.0"},
204+
]
205+
}
206+
mock_describe_hub_content = sagemaker_session.describe_hub_content
207+
mock_describe_hub_content.side_effect = [
208+
Exception("Some exception"),
209+
{"HubContentName": "test-model", "HubContentVersion": "3.0"},
210+
]
211+
212+
hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session)
213+
214+
with patch("sagemaker.jumpstart.hub.utils.get_hub_model_version") as mock_get_hub_model_version:
215+
mock_get_hub_model_version.return_value = "3.0"
216+
217+
hub.describe_model("test-model")
218+
219+
mock_describe_hub_content.asssert_called_times(2)
220+
mock_describe_hub_content.assert_called_with(
221+
hub_name=HUB_NAME,
222+
hub_content_name="test-model",
223+
hub_content_version="3.0",
224+
hub_content_type="Model",
225+
)
226+
227+
195228
def test_create_hub_content_reference(sagemaker_session):
196229
hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session)
197230
model_name = "mock-model-one-huggingface"

0 commit comments

Comments
 (0)