Skip to content

feat: Marketplace model support in HubService #4916

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Oct 31, 2024
29 changes: 19 additions & 10 deletions src/sagemaker/jumpstart/hub/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,12 +451,14 @@ def from_json(self, json_obj: str) -> None:
class HubModelDocument(HubDataHolderType):
"""Data class for model type HubContentDocument from session.describe_hub_content()."""

SCHEMA_VERSION = "2.2.0"
SCHEMA_VERSION = "2.3.0"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sanity-check: will the pySDK still understand previous schema?


__slots__ = [
"url",
"min_sdk_version",
"training_supported",
"model_types",
"capabilities",
"incremental_training_supported",
"dynamic_container_deployment_supported",
"hosting_ecr_uri",
Expand All @@ -469,6 +471,7 @@ class HubModelDocument(HubDataHolderType):
"hosting_use_script_uri",
"hosting_eula_uri",
"hosting_model_package_arn",
"model_subscription_link",
"inference_configs",
"inference_config_components",
"inference_config_rankings",
Expand Down Expand Up @@ -549,18 +552,22 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
Args:
json_obj (Dict[str, Any]): Dictionary representation of hub model document.
"""
self.url: str = json_obj["Url"]
self.min_sdk_version: str = json_obj["MinSdkVersion"]
self.hosting_ecr_uri: Optional[str] = json_obj["HostingEcrUri"]
self.hosting_artifact_uri = json_obj["HostingArtifactUri"]
self.hosting_script_uri = json_obj["HostingScriptUri"]
self.inference_dependencies: List[str] = json_obj["InferenceDependencies"]
self.url: str = json_obj.get("Url")
self.min_sdk_version: str = json_obj.get("MinSdkVersion")
self.hosting_ecr_uri: Optional[str] = json_obj.get("HostingEcrUri")
self.hosting_artifact_uri = json_obj.get("HostingArtifactUri")
self.hosting_script_uri = json_obj.get("HostingScriptUri")
self.inference_dependencies: List[str] = json_obj.get("InferenceDependencies")
self.inference_environment_variables: List[JumpStartEnvironmentVariable] = [
JumpStartEnvironmentVariable(env_variable, is_hub_content=True)
for env_variable in json_obj["InferenceEnvironmentVariables"]
for env_variable in json_obj.get("InferenceEnvironmentVariables", [])
]
self.training_supported: bool = bool(json_obj["TrainingSupported"])
self.incremental_training_supported: bool = bool(json_obj["IncrementalTrainingSupported"])
self.model_types: Optional[List[str]] = json_obj.get("ModelTypes")
self.capabilities: Optional[List[str]] = json_obj.get("Capabilities")
self.training_supported: bool = bool(json_obj.get("TrainingSupported"))
self.incremental_training_supported: bool = bool(
json_obj.get("IncrementalTrainingSupported")
)
self.dynamic_container_deployment_supported: Optional[bool] = (
bool(json_obj.get("DynamicContainerDeploymentSupported"))
if json_obj.get("DynamicContainerDeploymentSupported")
Expand All @@ -586,6 +593,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
self.hosting_eula_uri: Optional[str] = json_obj.get("HostingEulaUri")
self.hosting_model_package_arn: Optional[str] = json_obj.get("HostingModelPackageArn")

self.model_subscription_link: Optional[str] = json_obj.get("ModelSubscriptionLink")

self.inference_config_rankings = self._get_config_rankings(json_obj)
self.inference_config_components = self._get_config_components(json_obj)
self.inference_configs = self._get_configs(json_obj)
Expand Down
13 changes: 6 additions & 7 deletions src/sagemaker/jumpstart/hub/parser_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,12 @@
from typing import Any, Dict, List, Optional


def camel_to_snake(camel_case_string: str) -> str:
    """Return the snake_case form of a camelCase or UpperCamelCase string.

    The conversion runs in two regex passes: the first separates a capitalized
    word (an uppercase letter followed by lowercase letters) from whatever
    character precedes it, the second separates a lowercase letter or digit
    from an uppercase letter that directly follows it. Hyphens are dropped
    between the two passes, and the final result is lowercased.
    """
    capitalized_words_split = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_string)
    # Hyphens are discarded (not turned into underscores) before the second pass;
    # replacing on a string without hyphens is a no-op.
    hyphen_free = capitalized_words_split.replace("-", "")
    remaining_boundaries_split = re.sub("([a-z0-9])([A-Z])", r"\1_\2", hyphen_free)
    return remaining_boundaries_split.lower()
def pascal_to_snake(camel_case_string: str) -> str:
"""Converts PascalCase to snake_case_string using a regex.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inconsistent regex and method name.

Careful with renaming the method though, it would be backward incompatible.

Copy link
Contributor Author

@chrstfu chrstfu Oct 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, let me rewrite that docstring. This regex can handle camelCase as well as PascalCase


This regex cannot handle whitespace ("PascalString TwoWords")
"""
return re.sub(r"(?<!^)(?=[A-Z])", "_", camel_case_string).lower()


def snake_to_upper_camel(snake_case_string: str) -> str:
Expand Down
32 changes: 20 additions & 12 deletions src/sagemaker/jumpstart/hub/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
HubModelDocument,
)
from sagemaker.jumpstart.hub.parser_utils import (
camel_to_snake,
pascal_to_snake,
snake_to_upper_camel,
walk_and_apply_json,
)
Expand Down Expand Up @@ -86,7 +86,7 @@ def get_model_spec_arg_keys(
arg_keys = []

if naming_convention == NamingConventionType.SNAKE_CASE:
arg_keys = [camel_to_snake(key) for key in arg_keys]
arg_keys = [pascal_to_snake(key) for key in arg_keys]
elif naming_convention == NamingConventionType.UPPER_CAMEL_CASE:
return arg_keys
else:
Expand Down Expand Up @@ -137,6 +137,8 @@ def make_model_specs_from_describe_hub_content_response(
hub_model_document: HubModelDocument = response.hub_content_document
specs["url"] = hub_model_document.url
specs["min_sdk_version"] = hub_model_document.min_sdk_version
specs["model_types"] = hub_model_document.model_types
specs["capabilities"] = hub_model_document.capabilities
specs["training_supported"] = bool(hub_model_document.training_supported)
specs["incremental_training_supported"] = bool(
hub_model_document.incremental_training_supported
Expand All @@ -146,15 +148,19 @@ def make_model_specs_from_describe_hub_content_response(
specs["inference_config_components"] = hub_model_document.inference_config_components
specs["inference_config_rankings"] = hub_model_document.inference_config_rankings

hosting_artifact_bucket, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable
hub_model_document.hosting_artifact_uri
)
specs["hosting_artifact_key"] = hosting_artifact_key
specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri
hosting_script_bucket, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable
hub_model_document.hosting_script_uri
)
specs["hosting_script_key"] = hosting_script_key
if hub_model_document.hosting_artifact_uri:
_, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable
hub_model_document.hosting_artifact_uri
)
specs["hosting_artifact_key"] = hosting_artifact_key
specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri

if hub_model_document.hosting_script_uri:
_, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable
hub_model_document.hosting_script_uri
)
specs["hosting_script_key"] = hosting_script_key

specs["inference_environment_variables"] = hub_model_document.inference_environment_variables
specs["inference_vulnerable"] = False
specs["inference_dependencies"] = hub_model_document.inference_dependencies
Expand Down Expand Up @@ -201,7 +207,7 @@ def make_model_specs_from_describe_hub_content_response(
default_payloads: Dict[str, Any] = {}
if hub_model_document.default_payloads is not None:
for alias, payload in hub_model_document.default_payloads.items():
default_payloads[alias] = walk_and_apply_json(payload.to_json(), camel_to_snake)
default_payloads[alias] = walk_and_apply_json(payload.to_json(), pascal_to_snake)
specs["default_payloads"] = default_payloads
specs["gated_bucket"] = hub_model_document.gated_bucket
specs["inference_volume_size"] = hub_model_document.inference_volume_size
Expand All @@ -220,6 +226,8 @@ def make_model_specs_from_describe_hub_content_response(
if hub_model_document.hosting_model_package_arn:
specs["hosting_model_package_arns"] = {region: hub_model_document.hosting_model_package_arn}

specs["model_subscription_link"] = hub_model_document.model_subscription_link

specs["hosting_use_script_uri"] = hub_model_document.hosting_use_script_uri

specs["hosting_instance_type_variants"] = hub_model_document.hosting_instance_type_variants
Expand Down
69 changes: 66 additions & 3 deletions src/sagemaker/jumpstart/hub/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""This module contains utilities related to SageMaker JumpStart Hub."""
from __future__ import absolute_import
import re
from typing import Optional
from typing import Optional, List, Any
from sagemaker.jumpstart.hub.types import S3ObjectLocation
from sagemaker.s3_utils import parse_s3_url
from sagemaker.session import Session
Expand All @@ -23,6 +23,14 @@
from sagemaker.jumpstart import constants
from packaging.specifiers import SpecifierSet, InvalidSpecifier

PROPRIETARY_VERSION_KEYWORD = "@marketplace-version:"


def _convert_str_to_optional(string: str) -> Optional[str]:
if string == "None":
string = None
return string


def get_info_from_hub_resource_arn(
arn: str,
Expand All @@ -37,7 +45,7 @@ def get_info_from_hub_resource_arn(
hub_name = match.group(4)
hub_content_type = match.group(5)
hub_content_name = match.group(6)
hub_content_version = match.group(7)
hub_content_version = _convert_str_to_optional(match.group(7))

return HubArnExtractedInfo(
partition=partition,
Expand Down Expand Up @@ -194,10 +202,14 @@ def get_hub_model_version(
hub_model_version: Optional[str] = None,
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> str:
"""Returns available Jumpstart hub model version
"""Returns available Jumpstart hub model version.

It will attempt both a semantic HubContent version search and Marketplace version search.
If the Marketplace version is also semantic, this function will default to HubContent version.

Raises:
ClientError: If the specified model is not found in the hub.
KeyError: If the specified model version is not found.
"""

try:
Expand All @@ -207,6 +219,23 @@ def get_hub_model_version(
except Exception as ex:
raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}")

marketplace_hub_content_version = _get_hub_model_version_for_marketplace_version(
hub_content_summaries, hub_model_version
)

try:
return _get_hub_model_version_for_open_weight_version(
hub_content_summaries, hub_model_version
)
except KeyError as e:
if marketplace_hub_content_version:
return marketplace_hub_content_version
raise e


def _get_hub_model_version_for_open_weight_version(
hub_content_summaries: List[Any], hub_model_version: Optional[str] = None
) -> str:
available_model_versions = [model.get("HubContentVersion") for model in hub_content_summaries]

if hub_model_version == "*" or hub_model_version is None:
Expand All @@ -222,3 +251,37 @@ def get_hub_model_version(
hub_model_version = str(max(available_versions_filtered))

return hub_model_version


def _get_hub_model_version_for_marketplace_version(
    hub_content_summaries: List[Any], marketplace_version: str
) -> Optional[str]:
    """Look up the HubContent version that corresponds to a Marketplace version.

    Each summary's ``HubContentSearchKeywords`` list is checked for the
    proprietary version keyword. The ``HubContentVersion`` of the first
    summary whose keyword matches ``marketplace_version`` is returned.

    Args:
        hub_content_summaries (List[Any]): Hub content summary dictionaries.
        marketplace_version (str): Marketplace version to search for.

    Returns:
        Optional[str]: The ``HubContentVersion`` of the first matching summary,
            or ``None`` when no summary matches.
    """
    matching_versions = (
        summary.get("HubContentVersion")
        for summary in hub_content_summaries
        if _hub_search_keywords_contains_marketplace_version(
            summary.get("HubContentSearchKeywords", []), marketplace_version
        )
    )
    # Only the first match is consumed, mirroring an early return from a loop.
    return next(matching_versions, None)


def _hub_search_keywords_contains_marketplace_version(
    model_search_keywords: List[str], marketplace_version: str
) -> bool:
    """Check whether the search keywords carry the given Marketplace version.

    Only the first keyword that starts with ``PROPRIETARY_VERSION_KEYWORD`` is
    considered; the text after that prefix is compared to ``marketplace_version``.

    Args:
        model_search_keywords (List[str]): Search keywords of a hub content summary.
            ``None`` is treated as an empty list.
        marketplace_version (str): Marketplace version to compare against.

    Returns:
        bool: ``True`` if the proprietary version keyword is present and its
            version equals ``marketplace_version``, ``False`` otherwise.
    """
    proprietary_version_keyword = next(
        (
            keyword
            for keyword in (model_search_keywords or [])
            if keyword.startswith(PROPRIETARY_VERSION_KEYWORD)
        ),
        None,
    )

    if not proprietary_version_keyword:
        return False

    # Slice the prefix off by length. ``str.lstrip`` must not be used here: it
    # treats its argument as a set of characters, so a version that begins with
    # any character of the prefix (e.g. "v1.0" or "release-2") would be
    # truncated and never match.
    proprietary_version = proprietary_version_keyword[len(PROPRIETARY_VERSION_KEYWORD) :]
    return proprietary_version == marketplace_version
30 changes: 17 additions & 13 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
from sagemaker.enums import EndpointType
from sagemaker.jumpstart.hub.parser_utils import (
camel_to_snake,
pascal_to_snake,
walk_and_apply_json,
)

Expand Down Expand Up @@ -239,7 +239,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
return

if self._is_hub_content:
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
json_obj = walk_and_apply_json(json_obj, pascal_to_snake)

self.framework = json_obj.get("framework")
self.framework_version = json_obj.get("framework_version")
Expand Down Expand Up @@ -293,7 +293,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
"""

if self._is_hub_content:
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
json_obj = walk_and_apply_json(json_obj, pascal_to_snake)
self.name = json_obj["name"]
self.type = json_obj["type"]
self.default = json_obj["default"]
Expand Down Expand Up @@ -361,7 +361,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
Args:
json_obj (Dict[str, Any]): Dictionary representation of environment variable.
"""
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
json_obj = walk_and_apply_json(json_obj, pascal_to_snake)
self.name = json_obj["name"]
self.type = json_obj["type"]
self.default = json_obj["default"]
Expand Down Expand Up @@ -411,7 +411,7 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
return

if self._is_hub_content:
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
json_obj = walk_and_apply_json(json_obj, pascal_to_snake)
self.default_content_type = json_obj["default_content_type"]
self.supported_content_types = json_obj["supported_content_types"]
self.default_accept_type = json_obj["default_accept_type"]
Expand Down Expand Up @@ -465,7 +465,7 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
return

if self._is_hub_content:
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
json_obj = walk_and_apply_json(json_obj, pascal_to_snake)
self.raw_payload = json_obj
self.content_type = json_obj["content_type"]
self.body = json_obj.get("body")
Expand Down Expand Up @@ -538,7 +538,7 @@ def from_describe_hub_content_response(self, response: Optional[Dict[str, Any]])
if response is None:
return

response = walk_and_apply_json(response, camel_to_snake)
response = walk_and_apply_json(response, pascal_to_snake)
self.aliases: Optional[dict] = response.get("aliases")
self.regional_aliases = None
self.variants: Optional[dict] = response.get("variants")
Expand Down Expand Up @@ -1174,7 +1174,7 @@ def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content=False):
spec (Dict[str, Any]): Dictionary representation of training config ranking.
"""
if is_hub_content:
spec = walk_and_apply_json(spec, camel_to_snake)
spec = walk_and_apply_json(spec, pascal_to_snake)
self.from_json(spec)

def from_json(self, json_obj: Dict[str, Any]) -> None:
Expand All @@ -1200,6 +1200,8 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType):
"url",
"version",
"min_sdk_version",
"model_types",
"capabilities",
"incremental_training_supported",
"hosting_ecr_specs",
"hosting_ecr_uri",
Expand Down Expand Up @@ -1278,7 +1280,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
json_obj (Dict[str, Any]): Dictionary representation of spec.
"""
if self._is_hub_content:
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
json_obj = walk_and_apply_json(json_obj, pascal_to_snake)
self.model_id: str = json_obj.get("model_id")
self.url: str = json_obj.get("url")
self.version: str = json_obj.get("version")
Expand All @@ -1287,6 +1289,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
json_obj.get("incremental_training_supported", False)
)
if self._is_hub_content:
self.capabilities: Optional[List[str]] = json_obj.get("capabilities")
self.model_types: Optional[List[str]] = json_obj.get("model_types")
self.hosting_ecr_uri: Optional[str] = json_obj.get("hosting_ecr_uri")
self._non_serializable_slots.append("hosting_ecr_specs")
else:
Expand Down Expand Up @@ -1505,7 +1509,7 @@ def __init__(
ValueError: If the component field is invalid.
"""
if is_hub_content:
component = walk_and_apply_json(component, camel_to_snake)
component = walk_and_apply_json(component, pascal_to_snake)
self.component_name = component_name
super().__init__(component, is_hub_content)
self.from_json(component)
Expand Down Expand Up @@ -1558,8 +1562,8 @@ def __init__(
The list of components that are used to construct the resolved config.
"""
if is_hub_content:
config = walk_and_apply_json(config, camel_to_snake)
base_fields = walk_and_apply_json(base_fields, camel_to_snake)
config = walk_and_apply_json(config, pascal_to_snake)
base_fields = walk_and_apply_json(base_fields, pascal_to_snake)
self.base_fields = base_fields
self.config_components: Dict[str, JumpStartConfigComponent] = config_components
self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = (
Expand Down Expand Up @@ -1725,7 +1729,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
"""
super().from_json(json_obj)
if self._is_hub_content:
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
json_obj = walk_and_apply_json(json_obj, pascal_to_snake)
self.inference_config_components: Optional[Dict[str, JumpStartConfigComponent]] = (
{
component_name: JumpStartConfigComponent(component_name, component)
Expand Down
Loading
Loading