Skip to content

Commit 0730ae1

Browse files
committed
fix altconfig hubcontent and reenable integ test
1 parent ecd89b9 commit 0730ae1

File tree

4 files changed

+51
-26
lines changed

4 files changed

+51
-26
lines changed

src/sagemaker/jumpstart/accessors.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,9 @@ def get_model_specs(
287287
{**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs}
288288
)
289289
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
290-
291-
if hub_arn:
290+
291+
# Users only input model id, not contentType, so first try to describe with ModelReference, then with Model
292+
if hub_arn:
292293
try:
293294
hub_model_arn = construct_hub_model_reference_arn_from_inputs(
294295
hub_arn=hub_arn, model_name=model_id, version=version
@@ -308,11 +309,22 @@ def get_model_specs(
308309
hub_model_arn = construct_hub_model_arn_from_inputs(
309310
hub_arn=hub_arn, model_name=model_id, version=version
310311
)
311-
model_specs = JumpStartModelsAccessor._cache.get_hub_model(
312-
hub_model_arn=hub_model_arn
313-
)
314-
model_specs.set_hub_content_type(HubContentType.MODEL)
315-
return model_specs
312+
313+
# Failed to describe ModelReference, try with Model
314+
try:
315+
model_specs = JumpStartModelsAccessor._cache.get_hub_model(
316+
hub_model_arn=hub_model_arn
317+
)
318+
model_specs.set_hub_content_type(HubContentType.MODEL)
319+
320+
return model_specs
321+
except Exception as ex:
322+
# Failed with both, throw a custom error message
323+
raise Exception(
324+
f"Cannot get details for {model_id} in Hub {hub_arn}. \
325+
{model_id} does not exist as a Model or ModelReference: \n"
326+
+ str(ex)
327+
)
316328

317329
return JumpStartModelsAccessor._cache.get_specs( # type: ignore
318330
model_id=model_id, version_str=version, model_type=model_type

src/sagemaker/jumpstart/hub/hub.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -272,18 +272,21 @@ def delete_model_reference(self, model_name: str) -> None:
272272
def describe_model(
273273
self, model_name: str, hub_name: Optional[str] = None, model_version: Optional[str] = None
274274
) -> DescribeHubContentResponse:
275-
"""Describe model in the SageMaker Hub."""
275+
"""Describe Model or ModelReference in a Hub."""
276+
hub_name = self.hub_name if not hub_name else hub_name
277+
278+
# Users only input model id, not contentType, so first try to describe with ModelReference, then with Model
276279
try:
277280
model_version = get_hub_model_version(
278281
hub_model_name=model_name,
279282
hub_model_type=HubContentType.MODEL_REFERENCE.value,
280-
hub_name=self.hub_name if not hub_name else hub_name,
283+
hub_name=hub_name,
281284
sagemaker_session=self._sagemaker_session,
282285
hub_model_version=model_version,
283286
)
284287

285288
hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
286-
hub_name=self.hub_name if not hub_name else hub_name,
289+
hub_name=hub_name,
287290
hub_content_name=model_name,
288291
hub_content_version=model_version,
289292
hub_content_type=HubContentType.MODEL_REFERENCE.value,
@@ -294,19 +297,30 @@ def describe_model(
294297
"Received exeption while calling APIs for ContentType ModelReference, retrying with ContentType Model: "
295298
+ str(ex)
296299
)
297-
model_version = get_hub_model_version(
298-
hub_model_name=model_name,
299-
hub_model_type=HubContentType.MODEL.value,
300-
hub_name=self.hub_name if not hub_name else hub_name,
301-
sagemaker_session=self._sagemaker_session,
302-
hub_model_version=model_version,
303-
)
304300

305-
hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
306-
hub_name=self.hub_name if not hub_name else hub_name,
307-
hub_content_name=model_name,
308-
hub_content_version=model_version,
309-
hub_content_type=HubContentType.MODEL.value,
310-
)
301+
# Failed to describe ModelReference, try with Model
302+
try:
303+
model_version = get_hub_model_version(
304+
hub_model_name=model_name,
305+
hub_model_type=HubContentType.MODEL.value,
306+
hub_name=hub_name,
307+
sagemaker_session=self._sagemaker_session,
308+
hub_model_version=model_version,
309+
)
310+
311+
hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
312+
hub_name=hub_name,
313+
hub_content_name=model_name,
314+
hub_content_version=model_version,
315+
hub_content_type=HubContentType.MODEL.value,
316+
)
317+
318+
except Exception as ex:
319+
# Failed with both, throw a custom error message
320+
raise Exception(
321+
f"Cannot get details for {model_name} in Hub {hub_name}. \
322+
{model_name} does not exist as a Model or ModelReference in {hub_name}: \n"
323+
+ str(ex)
324+
)
311325

312326
return DescribeHubContentResponse(hub_content_description)

src/sagemaker/jumpstart/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1363,9 +1363,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
13631363
self.deploy_kwargs = deepcopy(json_obj.get("deploy_kwargs", {}))
13641364
self.predictor_specs: Optional[JumpStartPredictorSpecs] = (
13651365
JumpStartPredictorSpecs(
1366-
json_obj["predictor_specs"], is_hub_content=self._is_hub_content
1366+
json_obj.get("predictor_specs"), is_hub_content=self._is_hub_content
13671367
)
1368-
if "predictor_specs" in json_obj
1368+
if json_obj.get("predictor_specs")
13691369
else None
13701370
)
13711371
self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = (

tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,6 @@ def test_jumpstart_hub_gated_model(setup, add_model_references):
122122
assert response is not None
123123

124124

125-
@pytest.mark.skip(reason="blocking PR checks and release pipeline.")
126125
def test_jumpstart_gated_model_inference_component_enabled(setup, add_model_references):
127126

128127
model_id = "meta-textgeneration-llama-2-7b"

0 commit comments

Comments
 (0)