Skip to content

feat: Marketplace model support in HubService #4916

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Oct 31, 2024
29 changes: 19 additions & 10 deletions src/sagemaker/jumpstart/hub/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,12 +451,14 @@ def from_json(self, json_obj: str) -> None:
class HubModelDocument(HubDataHolderType):
"""Data class for model type HubContentDocument from session.describe_hub_content()."""

SCHEMA_VERSION = "2.2.0"
SCHEMA_VERSION = "2.3.0"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sanity check: will the pySDK still understand documents written with the previous schema version (2.2.0)?


__slots__ = [
"url",
"min_sdk_version",
"training_supported",
"model_types",
"capabilities",
"incremental_training_supported",
"dynamic_container_deployment_supported",
"hosting_ecr_uri",
Expand All @@ -469,6 +471,7 @@ class HubModelDocument(HubDataHolderType):
"hosting_use_script_uri",
"hosting_eula_uri",
"hosting_model_package_arn",
"model_subscription_link",
"inference_configs",
"inference_config_components",
"inference_config_rankings",
Expand Down Expand Up @@ -549,18 +552,22 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
Args:
json_obj (Dict[str, Any]): Dictionary representation of hub model document.
"""
self.url: str = json_obj["Url"]
self.min_sdk_version: str = json_obj["MinSdkVersion"]
self.hosting_ecr_uri: Optional[str] = json_obj["HostingEcrUri"]
self.hosting_artifact_uri = json_obj["HostingArtifactUri"]
self.hosting_script_uri = json_obj["HostingScriptUri"]
self.inference_dependencies: List[str] = json_obj["InferenceDependencies"]
self.url: str = json_obj.get("Url")
self.min_sdk_version: str = json_obj.get("MinSdkVersion")
self.hosting_ecr_uri: Optional[str] = json_obj.get("HostingEcrUri")
self.hosting_artifact_uri = json_obj.get("HostingArtifactUri")
self.hosting_script_uri = json_obj.get("HostingScriptUri")
self.inference_dependencies: List[str] = json_obj.get("InferenceDependencies")
self.inference_environment_variables: List[JumpStartEnvironmentVariable] = [
JumpStartEnvironmentVariable(env_variable, is_hub_content=True)
for env_variable in json_obj["InferenceEnvironmentVariables"]
for env_variable in json_obj.get("InferenceEnvironmentVariables", [])
]
self.training_supported: bool = bool(json_obj["TrainingSupported"])
self.incremental_training_supported: bool = bool(json_obj["IncrementalTrainingSupported"])
self.model_types: Optional[List[str]] = json_obj.get("ModelTypes")
self.capabilities: Optional[List[str]] = json_obj.get("Capabilities")
self.training_supported: bool = bool(json_obj.get("TrainingSupported"))
self.incremental_training_supported: bool = bool(
json_obj.get("IncrementalTrainingSupported")
)
self.dynamic_container_deployment_supported: Optional[bool] = (
bool(json_obj.get("DynamicContainerDeploymentSupported"))
if json_obj.get("DynamicContainerDeploymentSupported")
Expand All @@ -586,6 +593,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
self.hosting_eula_uri: Optional[str] = json_obj.get("HostingEulaUri")
self.hosting_model_package_arn: Optional[str] = json_obj.get("HostingModelPackageArn")

self.model_subscription_link: Optional[str] = json_obj.get("ModelSubscriptionLink")

self.inference_config_rankings = self._get_config_rankings(json_obj)
self.inference_config_components = self._get_config_components(json_obj)
self.inference_configs = self._get_configs(json_obj)
Expand Down
11 changes: 5 additions & 6 deletions src/sagemaker/jumpstart/hub/parser_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,11 @@


def camel_to_snake(camel_case_string: str) -> str:
"""Converts camelCaseString or UpperCamelCaseString to snake_case_string."""
snake_case_string = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_string)
if "-" in snake_case_string:
# remove any hyphen from the string for accurate conversion.
snake_case_string = snake_case_string.replace("-", "")
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake_case_string).lower()
"""Converts PascalCase to snake_case_string using a regex.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inconsistent regex and method name.

Careful with renaming the method though, it would be backward incompatible.

Copy link
Contributor Author

@chrstfu chrstfu Oct 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, let me rewrite that docstring. This regex can handle camelCase as well as PascalCase


This regex cannot handle whitespace ("PascalString TwoWords")
"""
return re.sub(r"(?<!^)(?=[A-Z])", "_", camel_case_string).lower()


def snake_to_upper_camel(snake_case_string: str) -> str:
Expand Down
26 changes: 17 additions & 9 deletions src/sagemaker/jumpstart/hub/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ def make_model_specs_from_describe_hub_content_response(
hub_model_document: HubModelDocument = response.hub_content_document
specs["url"] = hub_model_document.url
specs["min_sdk_version"] = hub_model_document.min_sdk_version
specs["model_types"] = hub_model_document.model_types
specs["capabilities"] = hub_model_document.capabilities
specs["training_supported"] = bool(hub_model_document.training_supported)
specs["incremental_training_supported"] = bool(
hub_model_document.incremental_training_supported
Expand All @@ -146,15 +148,19 @@ def make_model_specs_from_describe_hub_content_response(
specs["inference_config_components"] = hub_model_document.inference_config_components
specs["inference_config_rankings"] = hub_model_document.inference_config_rankings

hosting_artifact_bucket, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable
hub_model_document.hosting_artifact_uri
)
specs["hosting_artifact_key"] = hosting_artifact_key
specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri
hosting_script_bucket, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable
hub_model_document.hosting_script_uri
)
specs["hosting_script_key"] = hosting_script_key
if hub_model_document.hosting_artifact_uri:
_, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable
hub_model_document.hosting_artifact_uri
)
specs["hosting_artifact_key"] = hosting_artifact_key
specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri

if hub_model_document.hosting_script_uri:
_, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable
hub_model_document.hosting_script_uri
)
specs["hosting_script_key"] = hosting_script_key

specs["inference_environment_variables"] = hub_model_document.inference_environment_variables
specs["inference_vulnerable"] = False
specs["inference_dependencies"] = hub_model_document.inference_dependencies
Expand Down Expand Up @@ -220,6 +226,8 @@ def make_model_specs_from_describe_hub_content_response(
if hub_model_document.hosting_model_package_arn:
specs["hosting_model_package_arns"] = {region: hub_model_document.hosting_model_package_arn}

specs["model_subscription_link"] = hub_model_document.model_subscription_link

specs["hosting_use_script_uri"] = hub_model_document.hosting_use_script_uri

specs["hosting_instance_type_variants"] = hub_model_document.hosting_instance_type_variants
Expand Down
68 changes: 65 additions & 3 deletions src/sagemaker/jumpstart/hub/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""This module contains utilities related to SageMaker JumpStart Hub."""
from __future__ import absolute_import
import re
from typing import Optional
from typing import Optional, List, Any
from sagemaker.jumpstart.hub.types import S3ObjectLocation
from sagemaker.s3_utils import parse_s3_url
from sagemaker.session import Session
Expand All @@ -23,6 +23,14 @@
from sagemaker.jumpstart import constants
from packaging.specifiers import SpecifierSet, InvalidSpecifier

PROPRIETARY_VERSION_KEYWORD = "@marketplace-version:"


def _convert_str_to_optional(string: str) -> Optional[str]:
    """Translate the literal text ``"None"`` into ``None``.

    Any other value is handed back unchanged.
    """
    return None if string == "None" else string


def get_info_from_hub_resource_arn(
arn: str,
Expand All @@ -37,7 +45,7 @@ def get_info_from_hub_resource_arn(
hub_name = match.group(4)
hub_content_type = match.group(5)
hub_content_name = match.group(6)
hub_content_version = match.group(7)
hub_content_version = _convert_str_to_optional(match.group(7))

return HubArnExtractedInfo(
partition=partition,
Expand Down Expand Up @@ -194,10 +202,14 @@ def get_hub_model_version(
hub_model_version: Optional[str] = None,
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> str:
"""Returns available Jumpstart hub model version
"""Returns available Jumpstart hub model version.

It will attempt both a semantic HubContent version search and Marketplace version search.
If the Marketplace version is also semantic, this function will default to HubContent version.

Raises:
ClientError: If the specified model is not found in the hub.
KeyError: If the specified model version is not found.
"""

try:
Expand All @@ -207,6 +219,22 @@ def get_hub_model_version(
except Exception as ex:
raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}")

try:
return _get_hub_model_version_for_open_weight_version(
hub_content_summaries, hub_model_version
)
except KeyError:
marketplace_hub_content_version = _get_hub_model_version_for_marketplace_version(
hub_content_summaries, hub_model_version
)
if marketplace_hub_content_version:
return marketplace_hub_content_version
raise


def _get_hub_model_version_for_open_weight_version(
hub_content_summaries: List[Any], hub_model_version: Optional[str] = None
) -> str:
available_model_versions = [model.get("HubContentVersion") for model in hub_content_summaries]

if hub_model_version == "*" or hub_model_version is None:
Expand All @@ -222,3 +250,37 @@ def get_hub_model_version(
hub_model_version = str(max(available_versions_filtered))

return hub_model_version


def _get_hub_model_version_for_marketplace_version(
    hub_content_summaries: List[Any], marketplace_version: str
) -> Optional[str]:
    """Look up the HubContent version that corresponds to a Marketplace version.

    Each summary's ``HubContentSearchKeywords`` is inspected for the proprietary
    version keyword. The ``HubContentVersion`` of the first summary whose keyword
    matches ``marketplace_version`` is returned; ``None`` is returned when no
    summary matches.
    """
    # Lazy generator: summaries are examined in order and evaluation stops at
    # the first match, exactly like a loop with an early return.
    matching_versions = (
        summary.get("HubContentVersion")
        for summary in hub_content_summaries
        if _hub_search_keywords_contains_marketplace_version(
            summary.get("HubContentSearchKeywords", []), marketplace_version
        )
    )
    return next(matching_versions, None)


def _hub_search_keywords_contains_marketplace_version(
    model_search_keywords: List[str], marketplace_version: str
) -> bool:
    """Tell whether the search keywords advertise the given Marketplace version.

    Only the first keyword that starts with ``PROPRIETARY_VERSION_KEYWORD`` is
    examined. The text following that prefix must equal ``marketplace_version``
    exactly.

    Args:
        model_search_keywords (List[str]): Search keywords of one hub content summary.
        marketplace_version (str): Marketplace version to look for.

    Returns:
        bool: ``True`` if the first proprietary-version keyword carries
        ``marketplace_version``, ``False`` otherwise (including when no such
        keyword exists).
    """
    proprietary_version_keyword = next(
        (
            keyword
            for keyword in model_search_keywords
            if keyword.startswith(PROPRIETARY_VERSION_KEYWORD)
        ),
        None,
    )
    if proprietary_version_keyword is None:
        return False

    # Drop the prefix by length. ``str.lstrip`` must not be used here: it treats
    # its argument as a *set of characters*, so a version such as "a1.0" or "v2"
    # would lose its leading letters and compare equal to the wrong version.
    proprietary_version = proprietary_version_keyword[len(PROPRIETARY_VERSION_KEYWORD) :]
    return proprietary_version == marketplace_version
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,6 +1200,8 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType):
"url",
"version",
"min_sdk_version",
"model_types",
"capabilities",
"incremental_training_supported",
"hosting_ecr_specs",
"hosting_ecr_uri",
Expand Down Expand Up @@ -1287,6 +1289,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
json_obj.get("incremental_training_supported", False)
)
if self._is_hub_content:
self.capabilities: Optional[List[str]] = json_obj.get("capabilities")
self.model_types: Optional[List[str]] = json_obj.get("model_types")
self.hosting_ecr_uri: Optional[str] = json_obj.get("hosting_ecr_uri")
self._non_serializable_slots.append("hosting_ecr_specs")
else:
Expand Down
42 changes: 41 additions & 1 deletion src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,7 +856,16 @@ def validate_model_id_and_get_type(
if not isinstance(model_id, str):
return None
if hub_arn:
return None
model_types = _validate_hub_service_model_id_and_get_type(
model_id=model_id,
hub_arn=hub_arn,
region=region,
model_version=model_version,
sagemaker_session=sagemaker_session,
)
return (
model_types[0] if model_types else None
) # Currently this function only supports one model type

s3_client = sagemaker_session.s3_client if sagemaker_session else None
region = region or constants.JUMPSTART_DEFAULT_REGION_NAME
Expand All @@ -881,6 +890,37 @@ def validate_model_id_and_get_type(
return None


def _validate_hub_service_model_id_and_get_type(
    model_id: Optional[str],
    hub_arn: str,
    region: Optional[str] = None,
    model_version: Optional[str] = None,
    sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> List[enums.JumpStartModelType]:
    """Returns a list of JumpStartModelType based off the HubContent.

    The model specs are fetched through ``JumpStartModelsAccessor.get_model_specs``
    and their ``model_types`` names are mapped onto ``JumpStartModelType`` members.
    Only valid JumpStartModelType members are returned; names that do not map to
    a member are skipped. Returns an empty list if none are found.

    Args:
        model_id (Optional[str]): ID of the hub model.
            NOTE(review): raised in review as something that should be plain
            ``str``; the annotation is left as-is here to keep the signature
            unchanged.
        hub_arn (str): ARN of the hub that holds the model.
        region (Optional[str]): Region passed through to the specs lookup.
        model_version (Optional[str]): Version passed through to the specs lookup.
        sagemaker_session (Optional[Session]): Session passed through to the
            specs lookup.

    Returns:
        List[enums.JumpStartModelType]: The recognised model types, in the order
        they appear on the specs.
    """
    hub_model_specs = accessors.JumpStartModelsAccessor.get_model_specs(
        region=region,
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        sagemaker_session=sagemaker_session,
    )

    # ``model_types`` may be missing from the specs or explicitly None.
    model_types: List[str] = getattr(hub_model_specs, "model_types", None) or []

    hub_content_model_types = []
    for model_type in model_types:
        try:
            hub_content_model_types.append(enums.JumpStartModelType[model_type])
        except (KeyError, ValueError):
            # Subscript lookup on an Enum is by member *name* and raises KeyError
            # for an unknown name; catching only ValueError let unknown model
            # types propagate instead of being skipped. ValueError is still
            # caught for backward compatibility.
            continue

    return hub_content_model_types


def _extract_value_from_list_of_tags(
tag_keys: List[str],
list_tags_result: List[str],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
"ap-southeast-2",
}

TEST_HUB_WITH_REFERENCE = "mock-hub-name"


def test_non_prepacked_jumpstart_model(setup):

Expand Down
29 changes: 12 additions & 17 deletions tests/integ/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,23 +53,18 @@ def get_sm_session() -> Session:
return Session(boto_session=boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME))


# def get_sm_session_with_override() -> Session:
# # [TODO]: Remove service endpoint override before GA
# # boto3.set_stream_logger(name='botocore', level=logging.DEBUG)
# boto_session = boto3.Session(region_name="us-west-2")
# sagemaker = boto3.client(
# service_name="sagemaker-internal",
# endpoint_url="https://sagemaker.beta.us-west-2.ml-platform.aws.a2z.com",
# )
# sagemaker_runtime = boto3.client(
# service_name="runtime.maeve",
# endpoint_url="https://maeveruntime.beta.us-west-2.ml-platform.aws.a2z.com",
# )
# return Session(
# boto_session=boto_session,
# sagemaker_client=sagemaker,
# sagemaker_runtime_client=sagemaker_runtime,
# )
def get_sm_session_with_override() -> Session:
    """Create a Session whose SageMaker client points at an overridden endpoint.

    The boto session is pinned to ``us-west-2`` and the SageMaker client is built
    with an explicit ``endpoint_url``.
    """
    # [TODO]: Remove service endpoint override before GA
    # NOTE(review): a reviewer flagged this endpoint override as code that should
    # not have been pushed to the public repo and asked to block the release if
    # it reached main.
    # boto3.set_stream_logger(name='botocore', level=logging.DEBUG)
    endpoint_override = "https://sagemaker.gamma.us-west-2.ml-platform.aws.a2z.com"

    regional_boto_session = boto3.Session(region_name="us-west-2")
    sagemaker_client = boto3.client(
        service_name="sagemaker",
        endpoint_url=endpoint_override,
    )
    return Session(
        boto_session=regional_boto_session,
        sagemaker_client=sagemaker_client,
    )


def get_training_dataset_for_model_and_version(model_id: str, version: str) -> dict:
Expand Down
1 change: 1 addition & 0 deletions tests/unit/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -9178,6 +9178,7 @@
"TrainingArtifactS3DataType": "S3Prefix",
"TrainingArtifactCompressionType": "None",
"TrainingArtifactUri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501
"ModelTypes": ["OPEN_WEIGHTS", "PROPRIETARY"],
"Hyperparameters": [
{
"Name": "peft_type",
Expand Down
Loading