diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 20a2d16c15..2ed2deb803 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -288,6 +288,7 @@ def get_model_specs( ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) + # Users only input model id, not contentType, so first try to describe with ModelReference, then with Model if hub_arn: try: hub_model_arn = construct_hub_model_reference_arn_from_inputs( @@ -308,11 +309,22 @@ def get_model_specs( hub_model_arn = construct_hub_model_arn_from_inputs( hub_arn=hub_arn, model_name=model_id, version=version ) - model_specs = JumpStartModelsAccessor._cache.get_hub_model( - hub_model_arn=hub_model_arn - ) - model_specs.set_hub_content_type(HubContentType.MODEL) - return model_specs + + # Failed to describe ModelReference, try with Model + try: + model_specs = JumpStartModelsAccessor._cache.get_hub_model( + hub_model_arn=hub_model_arn + ) + model_specs.set_hub_content_type(HubContentType.MODEL) + + return model_specs + except Exception as ex: + # Failed with both, throw a custom error message + raise RuntimeError( + f"Cannot get details for {model_id} in Hub {hub_arn}. 
\ + {model_id} does not exist as a Model or ModelReference: \n" + + str(ex) + ) return JumpStartModelsAccessor._cache.get_specs( # type: ignore model_id=model_id, version_str=version, model_type=model_type diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index bc42eebea0..402b2ce534 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -272,18 +272,21 @@ def delete_model_reference(self, model_name: str) -> None: def describe_model( self, model_name: str, hub_name: Optional[str] = None, model_version: Optional[str] = None ) -> DescribeHubContentResponse: - """Describe model in the SageMaker Hub.""" + """Describe Model or ModelReference in a Hub.""" + hub_name = hub_name or self.hub_name + + # Users only input model id, not contentType, so first try to describe with ModelReference, then with Model try: model_version = get_hub_model_version( hub_model_name=model_name, hub_model_type=HubContentType.MODEL_REFERENCE.value, - hub_name=self.hub_name if not hub_name else hub_name, + hub_name=hub_name, sagemaker_session=self._sagemaker_session, hub_model_version=model_version, ) hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( - hub_name=self.hub_name if not hub_name else hub_name, + hub_name=hub_name, hub_content_name=model_name, hub_content_version=model_version, hub_content_type=HubContentType.MODEL_REFERENCE.value, @@ -294,19 +297,32 @@ def describe_model( "Received exeption while calling APIs for ContentType ModelReference, retrying with ContentType Model: " + str(ex) ) - model_version = get_hub_model_version( - hub_model_name=model_name, - hub_model_type=HubContentType.MODEL.value, - hub_name=self.hub_name if not hub_name else hub_name, - sagemaker_session=self._sagemaker_session, - hub_model_version=model_version, - ) - hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( - hub_name=self.hub_name if not hub_name else 
hub_name, - hub_content_name=model_name, - hub_content_version=model_version, - hub_content_type=HubContentType.MODEL.value, - ) + # Failed to describe ModelReference, try with Model + try: + model_version = get_hub_model_version( + hub_model_name=model_name, + hub_model_type=HubContentType.MODEL.value, + hub_name=hub_name, + sagemaker_session=self._sagemaker_session, + hub_model_version=model_version, + ) + + hub_content_description: Dict[str, Any] = ( + self._sagemaker_session.describe_hub_content( + hub_name=hub_name, + hub_content_name=model_name, + hub_content_version=model_version, + hub_content_type=HubContentType.MODEL.value, + ) + ) + + except Exception as ex: + # Failed with both, throw a custom error message + raise RuntimeError( + f"Cannot get details for {model_name} in Hub {hub_name}. \ + {model_name} does not exist as a Model or ModelReference in {hub_name}: \n" + + str(ex) + ) return DescribeHubContentResponse(hub_content_description) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 3dee2b3553..908241812e 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1363,9 +1363,10 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.deploy_kwargs = deepcopy(json_obj.get("deploy_kwargs", {})) self.predictor_specs: Optional[JumpStartPredictorSpecs] = ( JumpStartPredictorSpecs( - json_obj["predictor_specs"], is_hub_content=self._is_hub_content + json_obj.get("predictor_specs"), + is_hub_content=self._is_hub_content, ) - if "predictor_specs" in json_obj + if json_obj.get("predictor_specs") else None ) self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = ( @@ -1501,6 +1502,9 @@ class JumpStartConfigComponent(JumpStartMetadataBaseFields): "incremental_training_supported", ] + # Map of HubContent fields that map to custom names in MetadataBaseFields + CUSTOM_FIELD_MAP = {"sage_maker_sdk_predictor_specifications": "predictor_specs"} + __slots__ = slots + 
JumpStartMetadataBaseFields.__slots__ def __init__( @@ -1532,6 +1536,11 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if field in self.__slots__: setattr(self, field, json_obj[field]) + # Handle custom fields + for custom_field, field in self.CUSTOM_FIELD_MAP.items(): + if custom_field in json_obj: + setattr(self, field, json_obj.get(custom_field)) + class JumpStartMetadataConfig(JumpStartDataHolderType): """Data class of JumpStart metadata config.""" diff --git a/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py b/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py index c378520196..e8e5cc0942 100644 --- a/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py +++ b/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py @@ -122,7 +122,6 @@ def test_jumpstart_hub_gated_model(setup, add_model_references): assert response is not None -@pytest.mark.skip(reason="blocking PR checks and release pipeline.") def test_jumpstart_gated_model_inference_component_enabled(setup, add_model_references): model_id = "meta-textgeneration-llama-2-7b" diff --git a/tests/unit/sagemaker/jumpstart/hub/test_hub.py b/tests/unit/sagemaker/jumpstart/hub/test_hub.py index 8522b33bc3..06f5473322 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_hub.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_hub.py @@ -192,6 +192,39 @@ def test_describe_model_success(mock_describe_hub_content_response, sagemaker_se ) +@patch("sagemaker.jumpstart.hub.interfaces.DescribeHubContentResponse.from_json") +def test_describe_model_one_thrown_error(mock_describe_hub_content_response, sagemaker_session): + mock_describe_hub_content_response.return_value = Mock() + mock_list_hub_content_versions = sagemaker_session.list_hub_content_versions + mock_list_hub_content_versions.return_value = { + "HubContentSummaries": [ + {"HubContentVersion": "1.0"}, + {"HubContentVersion": 
"2.0"}, + {"HubContentVersion": "3.0"}, + ] + } + mock_describe_hub_content = sagemaker_session.describe_hub_content + mock_describe_hub_content.side_effect = [ + Exception("Some exception"), + {"HubContentName": "test-model", "HubContentVersion": "3.0"}, + ] + + hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) + + with patch("sagemaker.jumpstart.hub.utils.get_hub_model_version") as mock_get_hub_model_version: + mock_get_hub_model_version.return_value = "3.0" + + hub.describe_model("test-model") + + assert mock_describe_hub_content.call_count == 2 + mock_describe_hub_content.assert_called_with( + hub_name=HUB_NAME, + hub_content_name="test-model", + hub_content_version="3.0", + hub_content_type="Model", + ) + + def test_create_hub_content_reference(sagemaker_session): hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) model_name = "mock-model-one-huggingface"