Skip to content

Commit 0092ff4

Browse files
committed
feat: Marketplace model support in HubService
1 parent a14bd40 commit 0092ff4

File tree

12 files changed

+343
-68
lines changed

12 files changed

+343
-68
lines changed

src/sagemaker/jumpstart/factory/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from sagemaker.model_metrics import ModelMetrics
4242
from sagemaker.metadata_properties import MetadataProperties
4343
from sagemaker.drift_check_baselines import DriftCheckBaselines
44-
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType
44+
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType, HubContentCapability
4545
from sagemaker.jumpstart.types import (
4646
HubContentType,
4747
JumpStartModelDeployKwargs,

src/sagemaker/jumpstart/hub/interfaces.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -451,12 +451,14 @@ def from_json(self, json_obj: str) -> None:
451451
class HubModelDocument(HubDataHolderType):
452452
"""Data class for model type HubContentDocument from session.describe_hub_content()."""
453453

454-
SCHEMA_VERSION = "2.2.0"
454+
SCHEMA_VERSION = "2.3.0"
455455

456456
__slots__ = [
457457
"url",
458458
"min_sdk_version",
459459
"training_supported",
460+
"model_types",
461+
"capabilities",
460462
"incremental_training_supported",
461463
"dynamic_container_deployment_supported",
462464
"hosting_ecr_uri",
@@ -469,6 +471,7 @@ class HubModelDocument(HubDataHolderType):
469471
"hosting_use_script_uri",
470472
"hosting_eula_uri",
471473
"hosting_model_package_arn",
474+
"model_subscription_link",
472475
"inference_configs",
473476
"inference_config_components",
474477
"inference_config_rankings",
@@ -549,18 +552,22 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
549552
Args:
550553
json_obj (Dict[str, Any]): Dictionary representation of hub model document.
551554
"""
552-
self.url: str = json_obj["Url"]
553-
self.min_sdk_version: str = json_obj["MinSdkVersion"]
554-
self.hosting_ecr_uri: Optional[str] = json_obj["HostingEcrUri"]
555-
self.hosting_artifact_uri = json_obj["HostingArtifactUri"]
556-
self.hosting_script_uri = json_obj["HostingScriptUri"]
557-
self.inference_dependencies: List[str] = json_obj["InferenceDependencies"]
555+
self.url: str = json_obj.get("Url")
556+
self.min_sdk_version: str = json_obj.get("MinSdkVersion")
557+
self.hosting_ecr_uri: Optional[str] = json_obj.get("HostingEcrUri")
558+
self.hosting_artifact_uri = json_obj.get("HostingArtifactUri")
559+
self.hosting_script_uri = json_obj.get("HostingScriptUri")
560+
self.inference_dependencies: List[str] = json_obj.get("InferenceDependencies")
558561
self.inference_environment_variables: List[JumpStartEnvironmentVariable] = [
559562
JumpStartEnvironmentVariable(env_variable, is_hub_content=True)
560-
for env_variable in json_obj["InferenceEnvironmentVariables"]
563+
for env_variable in json_obj.get("InferenceEnvironmentVariables", [])
561564
]
562-
self.training_supported: bool = bool(json_obj["TrainingSupported"])
563-
self.incremental_training_supported: bool = bool(json_obj["IncrementalTrainingSupported"])
565+
self.model_types: Optional[List[str]] = json_obj.get("ModelTypes")
566+
self.capabilities: Optional[List[str]] = json_obj.get("Capabilities")
567+
self.training_supported: bool = bool(json_obj.get("TrainingSupported"))
568+
self.incremental_training_supported: bool = bool(
569+
json_obj.get("IncrementalTrainingSupported")
570+
)
564571
self.dynamic_container_deployment_supported: Optional[bool] = (
565572
bool(json_obj.get("DynamicContainerDeploymentSupported"))
566573
if json_obj.get("DynamicContainerDeploymentSupported")
@@ -586,6 +593,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
586593
self.hosting_eula_uri: Optional[str] = json_obj.get("HostingEulaUri")
587594
self.hosting_model_package_arn: Optional[str] = json_obj.get("HostingModelPackageArn")
588595

596+
self.model_subscription_link: Optional[str] = json_obj.get("ModelSubscriptionLink")
597+
589598
self.inference_config_rankings = self._get_config_rankings(json_obj)
590599
self.inference_config_components = self._get_config_components(json_obj)
591600
self.inference_configs = self._get_configs(json_obj)

src/sagemaker/jumpstart/hub/parser_utils.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,12 @@
1818
from typing import Any, Dict, List, Optional
1919

2020

21-
def camel_to_snake(camel_case_string: str) -> str:
22-
"""Converts camelCaseString or UpperCamelCaseString to snake_case_string."""
23-
snake_case_string = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_string)
24-
if "-" in snake_case_string:
25-
# remove any hyphen from the string for accurate conversion.
26-
snake_case_string = snake_case_string.replace("-", "")
27-
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake_case_string).lower()
21+
def pascal_to_snake(camel_case_string: str) -> str:
    """Convert a PascalCase identifier into its snake_case equivalent.

    An underscore is inserted in front of every uppercase ASCII letter except
    one at the very start of the string, and the result is lower-cased.
    Each capital is treated on its own, so runs of capitals are split letter
    by letter ("HTTPCode" -> "h_t_t_p_code"). Whitespace is not treated
    specially ("PascalString TwoWords" is not handled).

    Args:
        camel_case_string (str): The PascalCase text to convert.

    Returns:
        str: The snake_case form of the input.
    """
    # Zero-width match: any position followed by a capital, except the start.
    capital_boundary = r"(?<!^)(?=[A-Z])"
    underscored = re.sub(capital_boundary, "_", camel_case_string)
    return underscored.lower()
2827

2928

3029
def snake_to_upper_camel(snake_case_string: str) -> str:

src/sagemaker/jumpstart/hub/parsers.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
HubModelDocument,
2828
)
2929
from sagemaker.jumpstart.hub.parser_utils import (
30-
camel_to_snake,
30+
pascal_to_snake,
3131
snake_to_upper_camel,
3232
walk_and_apply_json,
3333
)
@@ -86,7 +86,7 @@ def get_model_spec_arg_keys(
8686
arg_keys = []
8787

8888
if naming_convention == NamingConventionType.SNAKE_CASE:
89-
arg_keys = [camel_to_snake(key) for key in arg_keys]
89+
arg_keys = [pascal_to_snake(key) for key in arg_keys]
9090
elif naming_convention == NamingConventionType.UPPER_CAMEL_CASE:
9191
return arg_keys
9292
else:
@@ -137,6 +137,8 @@ def make_model_specs_from_describe_hub_content_response(
137137
hub_model_document: HubModelDocument = response.hub_content_document
138138
specs["url"] = hub_model_document.url
139139
specs["min_sdk_version"] = hub_model_document.min_sdk_version
140+
specs["model_types"] = hub_model_document.model_types
141+
specs["capabilities"] = hub_model_document.capabilities
140142
specs["training_supported"] = bool(hub_model_document.training_supported)
141143
specs["incremental_training_supported"] = bool(
142144
hub_model_document.incremental_training_supported
@@ -146,15 +148,19 @@ def make_model_specs_from_describe_hub_content_response(
146148
specs["inference_config_components"] = hub_model_document.inference_config_components
147149
specs["inference_config_rankings"] = hub_model_document.inference_config_rankings
148150

149-
hosting_artifact_bucket, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable
150-
hub_model_document.hosting_artifact_uri
151-
)
152-
specs["hosting_artifact_key"] = hosting_artifact_key
153-
specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri
154-
hosting_script_bucket, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable
155-
hub_model_document.hosting_script_uri
156-
)
157-
specs["hosting_script_key"] = hosting_script_key
151+
if hub_model_document.hosting_artifact_uri:
152+
_, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable
153+
hub_model_document.hosting_artifact_uri
154+
)
155+
specs["hosting_artifact_key"] = hosting_artifact_key
156+
specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri
157+
158+
if hub_model_document.hosting_script_uri:
159+
_, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable
160+
hub_model_document.hosting_script_uri
161+
)
162+
specs["hosting_script_key"] = hosting_script_key
163+
158164
specs["inference_environment_variables"] = hub_model_document.inference_environment_variables
159165
specs["inference_vulnerable"] = False
160166
specs["inference_dependencies"] = hub_model_document.inference_dependencies
@@ -201,7 +207,7 @@ def make_model_specs_from_describe_hub_content_response(
201207
default_payloads: Dict[str, Any] = {}
202208
if hub_model_document.default_payloads is not None:
203209
for alias, payload in hub_model_document.default_payloads.items():
204-
default_payloads[alias] = walk_and_apply_json(payload.to_json(), camel_to_snake)
210+
default_payloads[alias] = walk_and_apply_json(payload.to_json(), pascal_to_snake)
205211
specs["default_payloads"] = default_payloads
206212
specs["gated_bucket"] = hub_model_document.gated_bucket
207213
specs["inference_volume_size"] = hub_model_document.inference_volume_size
@@ -220,6 +226,8 @@ def make_model_specs_from_describe_hub_content_response(
220226
if hub_model_document.hosting_model_package_arn:
221227
specs["hosting_model_package_arns"] = {region: hub_model_document.hosting_model_package_arn}
222228

229+
specs["model_subscription_link"] = hub_model_document.model_subscription_link
230+
223231
specs["hosting_use_script_uri"] = hub_model_document.hosting_use_script_uri
224232

225233
specs["hosting_instance_type_variants"] = hub_model_document.hosting_instance_type_variants

src/sagemaker/jumpstart/hub/utils.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
"""This module contains utilities related to SageMaker JumpStart Hub."""
1515
from __future__ import absolute_import
1616
import re
17-
from typing import Optional
17+
from typing import Optional, List, Any
1818
from sagemaker.jumpstart.hub.types import S3ObjectLocation
1919
from sagemaker.s3_utils import parse_s3_url
2020
from sagemaker.session import Session
@@ -23,6 +23,14 @@
2323
from sagemaker.jumpstart import constants
2424
from packaging.specifiers import SpecifierSet, InvalidSpecifier
2525

26+
PROPRIETARY_VERSION_KEYWORD = "@marketplace-version:"
27+
28+
29+
def _convert_str_to_optional(string: str) -> Optional[str]:
30+
if string == "None":
31+
string = None
32+
return string
33+
2634

2735
def get_info_from_hub_resource_arn(
2836
arn: str,
@@ -37,7 +45,7 @@ def get_info_from_hub_resource_arn(
3745
hub_name = match.group(4)
3846
hub_content_type = match.group(5)
3947
hub_content_name = match.group(6)
40-
hub_content_version = match.group(7)
48+
hub_content_version = _convert_str_to_optional(match.group(7))
4149

4250
return HubArnExtractedInfo(
4351
partition=partition,
@@ -194,10 +202,14 @@ def get_hub_model_version(
194202
hub_model_version: Optional[str] = None,
195203
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
196204
) -> str:
197-
"""Returns available Jumpstart hub model version
205+
"""Returns available Jumpstart hub model version.
206+
207+
It will attempt both a semantic HubContent version search and Marketplace version search.
208+
If the Marketplace version is also semantic, this function will default to HubContent version.
198209
199210
Raises:
200211
ClientError: If the specified model is not found in the hub.
212+
KeyError: If the specified model version is not found.
201213
"""
202214

203215
try:
@@ -207,6 +219,23 @@ def get_hub_model_version(
207219
except Exception as ex:
208220
raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}")
209221

222+
marketplace_hub_content_version = _get_hub_model_version_for_marketplace_version(
223+
hub_content_summaries, hub_model_version
224+
)
225+
226+
try:
227+
return _get_hub_model_version_for_open_weight_version(
228+
hub_content_summaries, hub_model_version
229+
)
230+
except KeyError as e:
231+
if marketplace_hub_content_version:
232+
return marketplace_hub_content_version
233+
raise e
234+
235+
236+
def _get_hub_model_version_for_open_weight_version(
237+
hub_content_summaries: List[Any], hub_model_version: Optional[str] = None
238+
) -> str:
210239
available_model_versions = [model.get("HubContentVersion") for model in hub_content_summaries]
211240

212241
if hub_model_version == "*" or hub_model_version is None:
@@ -222,3 +251,37 @@ def get_hub_model_version(
222251
hub_model_version = str(max(available_versions_filtered))
223252

224253
return hub_model_version
254+
255+
256+
def _get_hub_model_version_for_marketplace_version(
257+
hub_content_summaries: List[Any], marketplace_version: str
258+
) -> Optional[str]:
259+
"""Returns the HubContent version associated with the Marketplace version.
260+
261+
This function will check within the HubContentSearchKeywords for the proprietary version.
262+
"""
263+
for model in hub_content_summaries:
264+
model_search_keywords = model.get("HubContentSearchKeywords", [])
265+
if _hub_search_keywords_contains_marketplace_version(
266+
model_search_keywords, marketplace_version
267+
):
268+
return model.get("HubContentVersion")
269+
270+
return None
271+
272+
273+
def _hub_search_keywords_contains_marketplace_version(
274+
model_search_keywords: List[str], marketplace_version: str
275+
) -> bool:
276+
proprietary_version_keyword = next(
277+
filter(lambda s: s.startswith(PROPRIETARY_VERSION_KEYWORD), model_search_keywords), None
278+
)
279+
280+
if not proprietary_version_keyword:
281+
return False
282+
283+
proprietary_version = proprietary_version_keyword.lstrip(PROPRIETARY_VERSION_KEYWORD)
284+
if proprietary_version == marketplace_version:
285+
return True
286+
287+
return False

0 commit comments

Comments
 (0)