Skip to content

fix: altconfig hubcontent and reenable integ test #5051

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions src/sagemaker/jumpstart/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def get_model_specs(
)
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)

# Users only input model id, not contentType, so first try to describe with ModelReference, then with Model
if hub_arn:
try:
hub_model_arn = construct_hub_model_reference_arn_from_inputs(
Expand All @@ -308,11 +309,22 @@ def get_model_specs(
hub_model_arn = construct_hub_model_arn_from_inputs(
hub_arn=hub_arn, model_name=model_id, version=version
)
model_specs = JumpStartModelsAccessor._cache.get_hub_model(
hub_model_arn=hub_model_arn
)
model_specs.set_hub_content_type(HubContentType.MODEL)
return model_specs

# Failed to describe ModelReference, try with Model
try:
model_specs = JumpStartModelsAccessor._cache.get_hub_model(
hub_model_arn=hub_model_arn
)
model_specs.set_hub_content_type(HubContentType.MODEL)

return model_specs
except Exception as ex:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we catch a more specific exception here instead of the bare `Exception`?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which exception type do you suggest? What does botocore actually throw in this case?

# Failed with both, throw a custom error message
raise RuntimeError(
f"Cannot get details for {model_id} in Hub {hub_arn}. \
{model_id} does not exist as a Model or ModelReference: \n"
+ str(ex)
)

return JumpStartModelsAccessor._cache.get_specs( # type: ignore
model_id=model_id, version_str=version, model_type=model_type
Expand Down
48 changes: 32 additions & 16 deletions src/sagemaker/jumpstart/hub/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,18 +272,21 @@ def delete_model_reference(self, model_name: str) -> None:
def describe_model(
self, model_name: str, hub_name: Optional[str] = None, model_version: Optional[str] = None
) -> DescribeHubContentResponse:
"""Describe model in the SageMaker Hub."""
"""Describe Model or ModelReference in a Hub."""
hub_name = hub_name or self.hub_name

# Users only input model id, not contentType, so first try to describe with ModelReference, then with Model
try:
model_version = get_hub_model_version(
hub_model_name=model_name,
hub_model_type=HubContentType.MODEL_REFERENCE.value,
hub_name=self.hub_name if not hub_name else hub_name,
hub_name=hub_name,
sagemaker_session=self._sagemaker_session,
hub_model_version=model_version,
)

hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
hub_name=self.hub_name if not hub_name else hub_name,
hub_name=hub_name,
hub_content_name=model_name,
hub_content_version=model_version,
hub_content_type=HubContentType.MODEL_REFERENCE.value,
Expand All @@ -294,19 +297,32 @@ def describe_model(
"Received exeption while calling APIs for ContentType ModelReference, retrying with ContentType Model: "
+ str(ex)
)
model_version = get_hub_model_version(
hub_model_name=model_name,
hub_model_type=HubContentType.MODEL.value,
hub_name=self.hub_name if not hub_name else hub_name,
sagemaker_session=self._sagemaker_session,
hub_model_version=model_version,
)

hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
hub_name=self.hub_name if not hub_name else hub_name,
hub_content_name=model_name,
hub_content_version=model_version,
hub_content_type=HubContentType.MODEL.value,
)
# Failed to describe ModelReference, try with Model
try:
model_version = get_hub_model_version(
hub_model_name=model_name,
hub_model_type=HubContentType.MODEL.value,
hub_name=hub_name,
sagemaker_session=self._sagemaker_session,
hub_model_version=model_version,
)

hub_content_description: Dict[str, Any] = (
self._sagemaker_session.describe_hub_content(
hub_name=hub_name,
hub_content_name=model_name,
hub_content_version=model_version,
hub_content_type=HubContentType.MODEL.value,
)
)

except Exception as ex:
# Failed with both, throw a custom error message
raise RuntimeError(
f"Cannot get details for {model_name} in Hub {hub_name}. \
{model_name} does not exist as a Model or ModelReference in {hub_name}: \n"
+ str(ex)
)

return DescribeHubContentResponse(hub_content_description)
13 changes: 11 additions & 2 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1363,9 +1363,10 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
self.deploy_kwargs = deepcopy(json_obj.get("deploy_kwargs", {}))
self.predictor_specs: Optional[JumpStartPredictorSpecs] = (
JumpStartPredictorSpecs(
json_obj["predictor_specs"], is_hub_content=self._is_hub_content
json_obj.get("predictor_specs"),
is_hub_content=self._is_hub_content,
)
if "predictor_specs" in json_obj
if json_obj.get("predictor_specs")
else None
)
self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = (
Expand Down Expand Up @@ -1501,6 +1502,9 @@ class JumpStartConfigComponent(JumpStartMetadataBaseFields):
"incremental_training_supported",
]

# Map of HubContent fields that map to custom names in MetadataBaseFields
CUSTOM_FIELD_MAP = {"sage_maker_sdk_predictor_specifications": "predictor_specs"}

__slots__ = slots + JumpStartMetadataBaseFields.__slots__

def __init__(
Expand Down Expand Up @@ -1532,6 +1536,11 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
if field in self.__slots__:
setattr(self, field, json_obj[field])

# Handle custom fields
for custom_field, field in self.CUSTOM_FIELD_MAP.items():
if custom_field in json_obj:
setattr(self, field, json_obj.get(custom_field))


class JumpStartMetadataConfig(JumpStartDataHolderType):
"""Data class of JumpStart metadata config."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ def test_jumpstart_hub_gated_model(setup, add_model_references):
assert response is not None


@pytest.mark.skip(reason="blocking PR checks and release pipeline.")
def test_jumpstart_gated_model_inference_component_enabled(setup, add_model_references):

model_id = "meta-textgeneration-llama-2-7b"
Expand Down
33 changes: 33 additions & 0 deletions tests/unit/sagemaker/jumpstart/hub/test_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,39 @@ def test_describe_model_success(mock_describe_hub_content_response, sagemaker_se
)


@patch("sagemaker.jumpstart.hub.interfaces.DescribeHubContentResponse.from_json")
def test_describe_model_one_thrown_error(mock_describe_hub_content_response, sagemaker_session):
    """describe_model retries with another content type after one describe call fails.

    The first ``describe_hub_content`` call raises and the second returns a
    description. The test checks that exactly two describe calls were made and
    that the last one used ``hub_content_type="Model"``.
    """
    mock_describe_hub_content_response.return_value = Mock()
    mock_list_hub_content_versions = sagemaker_session.list_hub_content_versions
    mock_list_hub_content_versions.return_value = {
        "HubContentSummaries": [
            {"HubContentVersion": "1.0"},
            {"HubContentVersion": "2.0"},
            {"HubContentVersion": "3.0"},
        ]
    }
    mock_describe_hub_content = sagemaker_session.describe_hub_content
    # First describe call fails; the second one succeeds.
    mock_describe_hub_content.side_effect = [
        Exception("Some exception"),
        {"HubContentName": "test-model", "HubContentVersion": "3.0"},
    ]

    hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session)

    with patch("sagemaker.jumpstart.hub.utils.get_hub_model_version") as mock_get_hub_model_version:
        mock_get_hub_model_version.return_value = "3.0"

        hub.describe_model("test-model")

        # A misspelled Mock "assert" method is auto-created and silently passes,
        # so the call count is checked with a real assertion instead.
        assert mock_describe_hub_content.call_count == 2
        # assert_called_with inspects only the most recent call, i.e. the retry.
        mock_describe_hub_content.assert_called_with(
            hub_name=HUB_NAME,
            hub_content_name="test-model",
            hub_content_version="3.0",
            hub_content_type="Model",
        )


def test_create_hub_content_reference(sagemaker_session):
hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session)
model_name = "mock-model-one-huggingface"
Expand Down