Skip to content

Commit f1bc99e

Browse files
authored
feat: Support Alt Configs for Public & Curated Hub (aws#1505)
* feat: add alt config support for public & curated hub
1 parent 31d70e6 commit f1bc99e

File tree

6 files changed

+246
-12
lines changed

6 files changed

+246
-12
lines changed

src/sagemaker/jumpstart/hub/interfaces.py

+91
Original file line number | Diff line number | Diff line change
@@ -13,14 +13,20 @@
1313
"""This module stores types related to SageMaker JumpStart HubAPI requests and responses."""
1414
from __future__ import absolute_import
1515

16+
from enum import Enum
1617
import re
1718
import json
1819
import datetime
1920

2021
from typing import Any, Dict, List, Union, Optional
22+
from sagemaker.jumpstart.enums import JumpStartScriptScope
2123
from sagemaker.jumpstart.types import (
2224
HubContentType,
2325
HubArnExtractedInfo,
26+
JumpStartConfigComponent,
27+
JumpStartConfigRanking,
28+
JumpStartMetadataConfig,
29+
JumpStartMetadataConfigs,
2430
JumpStartPredictorSpecs,
2531
JumpStartHyperparameter,
2632
JumpStartDataHolderType,
@@ -34,6 +40,13 @@
3440
)
3541

3642

43+
class _ComponentType(str, Enum):
    """Enum for different component types.

    The member value (e.g. "Inference") is used as the prefix when reading
    keys such as "<value>Configs", "<value>ConfigComponents" and
    "<value>ConfigRankings" from a HubContentDocument JSON payload, and,
    lower-cased, when resolving the matching attributes on the document
    object (e.g. ``inference_config_components``).
    """

    INFERENCE = "Inference"
    TRAINING = "Training"
48+
49+
3750
class HubDataHolderType(JumpStartDataHolderType):
3851
"""Base class for many Hub API interfaces."""
3952

@@ -456,6 +469,9 @@ class HubModelDocument(HubDataHolderType):
456469
"hosting_use_script_uri",
457470
"hosting_eula_uri",
458471
"hosting_model_package_arn",
472+
"inference_configs",
473+
"inference_config_components",
474+
"inference_config_rankings",
459475
"training_artifact_s3_data_type",
460476
"training_artifact_compression_type",
461477
"training_model_package_artifact_uri",
@@ -467,6 +483,9 @@ class HubModelDocument(HubDataHolderType):
467483
"training_ecr_uri",
468484
"training_metrics",
469485
"training_artifact_uri",
486+
"training_configs",
487+
"training_config_components",
488+
"training_config_rankings",
470489
"inference_dependencies",
471490
"training_dependencies",
472491
"default_inference_instance_type",
@@ -566,6 +585,11 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
566585
)
567586
self.hosting_eula_uri: Optional[str] = json_obj.get("HostingEulaUri")
568587
self.hosting_model_package_arn: Optional[str] = json_obj.get("HostingModelPackageArn")
588+
589+
self.inference_config_rankings = self._get_config_rankings(json_obj)
590+
self.inference_config_components = self._get_config_components(json_obj)
591+
self.inference_configs = self._get_configs(json_obj)
592+
569593
self.default_inference_instance_type: Optional[str] = json_obj.get(
570594
"DefaultInferenceInstanceType"
571595
)
@@ -667,6 +691,15 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
667691
"TrainingMetrics", None
668692
)
669693
self.training_artifact_uri: Optional[str] = json_obj.get("TrainingArtifactUri")
694+
695+
self.training_config_rankings = self._get_config_rankings(
696+
json_obj, _ComponentType.TRAINING
697+
)
698+
self.training_config_components = self._get_config_components(
699+
json_obj, _ComponentType.TRAINING
700+
)
701+
self.training_configs = self._get_configs(json_obj, _ComponentType.TRAINING)
702+
670703
self.training_dependencies: Optional[str] = json_obj.get("TrainingDependencies")
671704
self.default_training_instance_type: Optional[str] = json_obj.get(
672705
"DefaultTrainingInstanceType"
@@ -707,6 +740,64 @@ def get_region(self) -> str:
707740
"""Returns hub region."""
708741
return self._region
709742

743+
def _get_config_rankings(
    self, json_obj: Dict[str, Any], component_type=_ComponentType.INFERENCE
) -> Optional[Dict[str, JumpStartConfigRanking]]:
    """Return the config rankings for *component_type*, or None if absent.

    Reads the "<ComponentType>ConfigRankings" key of *json_obj* and parses
    each entry as Hub content (camelCase keys).
    """
    raw_rankings = json_obj.get(f"{component_type.value}ConfigRankings")
    if not raw_rankings:
        return None
    return {
        alias: JumpStartConfigRanking(ranking, is_hub_content=True)
        for alias, ranking in raw_rankings.items()
    }
756+
757+
def _get_config_components(
    self, json_obj: Dict[str, Any], component_type=_ComponentType.INFERENCE
) -> Optional[Dict[str, JumpStartConfigComponent]]:
    """Return the config components for *component_type*, or None if absent.

    Reads the "<ComponentType>ConfigComponents" key of *json_obj* and parses
    each entry as Hub content (camelCase keys), keyed by component alias.
    """
    raw_components = json_obj.get(f"{component_type.value}ConfigComponents")
    if not raw_components:
        return None
    return {
        alias: JumpStartConfigComponent(alias, component, is_hub_content=True)
        for alias, component in raw_components.items()
    }
770+
771+
def _get_configs(
    self, json_obj: Dict[str, Any], component_type=_ComponentType.INFERENCE
) -> Optional[JumpStartMetadataConfigs]:
    """Return the parsed configs for *component_type*, or None if absent.

    Relies on the matching ``*_config_components`` and ``*_config_rankings``
    attributes having been populated on ``self`` before this is called.
    """
    raw_configs = json_obj.get(f"{component_type.value}Configs")
    if not raw_configs:
        return None

    parsed: Dict[str, JumpStartMetadataConfig] = {}
    for alias, config in raw_configs.items():
        resolved_components = None
        if isinstance(config, dict):
            names = config.get("ComponentNames")
            if names:
                # Look up the components parsed earlier onto self, e.g.
                # self.inference_config_components / self.training_config_components.
                components = getattr(
                    self, f"{component_type.value.lower()}_config_components"
                )
                resolved_components = {name: components.get(name) for name in names}
        parsed[alias] = JumpStartMetadataConfig(
            alias, config, json_obj, resolved_components, is_hub_content=True
        )

    # NOTE: keep ``==`` — _ComponentType mixes in str, so equality (not
    # identity) is the comparison the rest of the code relies on.
    if component_type == _ComponentType.INFERENCE:
        rankings = self.inference_config_rankings
        scope = JumpStartScriptScope.INFERENCE
    else:
        rankings = self.training_config_rankings
        scope = JumpStartScriptScope.TRAINING

    return JumpStartMetadataConfigs(parsed, rankings, scope)
800+
710801

711802
class HubNotebookDocument(HubDataHolderType):
712803
"""Data class for notebook type HubContentDocument from session.describe_hub_content()."""

src/sagemaker/jumpstart/hub/parsers.py

+8
Original file line number | Diff line number | Diff line change
@@ -142,6 +142,9 @@ def make_model_specs_from_describe_hub_content_response(
142142
hub_model_document.incremental_training_supported
143143
)
144144
specs["hosting_ecr_uri"] = hub_model_document.hosting_ecr_uri
145+
specs["inference_configs"] = hub_model_document.inference_configs
146+
specs["inference_config_components"] = hub_model_document.inference_config_components
147+
specs["inference_config_rankings"] = hub_model_document.inference_config_rankings
145148

146149
hosting_artifact_bucket, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable
147150
hub_model_document.hosting_artifact_uri
@@ -233,6 +236,11 @@ def make_model_specs_from_describe_hub_content_response(
233236
training_script_key,
234237
) = parse_s3_url(hub_model_document.training_script_uri)
235238
specs["training_script_key"] = training_script_key
239+
240+
specs["training_configs"] = hub_model_document.training_configs
241+
specs["training_config_components"] = hub_model_document.training_config_components
242+
specs["training_config_rankings"] = hub_model_document.training_config_rankings
243+
236244
specs["training_dependencies"] = hub_model_document.training_dependencies
237245
specs["default_training_instance_type"] = hub_model_document.default_training_instance_type
238246
specs["supported_training_instance_types"] = (

src/sagemaker/jumpstart/types.py

+12-6
Original file line number | Diff line number | Diff line change
@@ -1169,12 +1169,14 @@ class JumpStartConfigRanking(JumpStartDataHolderType):
11691169

11701170
__slots__ = ["description", "rankings"]
11711171

1172-
def __init__(self, spec: Optional[Dict[str, Any]]):
1172+
def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content=False):
    """Initializes a JumpStartConfigRanking object.

    Args:
        spec (Dict[str, Any]): Dictionary representation of training config ranking.
        is_hub_content (bool): Whether ``spec`` originates from a Hub content
            document. Hub documents use camelCase keys, which are converted
            to snake_case here before parsing.
    """
    if is_hub_content:
        # from_json expects snake_case keys ("description", "rankings").
        spec = {camel_to_snake(key): val for key, val in spec.items()}
    self.from_json(spec)
11791181

11801182
def from_json(self, json_obj: Dict[str, Any]) -> None:
@@ -1285,7 +1287,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
12851287
json_obj.get("incremental_training_supported", False)
12861288
)
12871289
if self._is_hub_content:
1288-
self.hosting_ecr_uri: Optional[str] = json_obj["hosting_ecr_uri"]
1290+
self.hosting_ecr_uri: Optional[str] = json_obj.get("hosting_ecr_uri")
12891291
self._non_serializable_slots.append("hosting_ecr_specs")
12901292
else:
12911293
self.hosting_ecr_specs: Optional[JumpStartECRSpecs] = (
@@ -1491,9 +1493,7 @@ class JumpStartConfigComponent(JumpStartMetadataBaseFields):
14911493
__slots__ = slots + JumpStartMetadataBaseFields.__slots__
14921494

14931495
def __init__(
1494-
self,
1495-
component_name: str,
1496-
component: Optional[Dict[str, Any]],
1496+
self, component_name: str, component: Optional[Dict[str, Any]], is_hub_content=False
14971497
):
14981498
"""Initializes a JumpStartConfigComponent object from its json representation.
14991499
@@ -1504,8 +1504,10 @@ def __init__(
15041504
Raises:
15051505
ValueError: If the component field is invalid.
15061506
"""
1507-
super().__init__(component)
1507+
if is_hub_content:
1508+
component = walk_and_apply_json(component, camel_to_snake)
15081509
self.component_name = component_name
1510+
super().__init__(component, is_hub_content)
15091511
self.from_json(component)
15101512

15111513
def from_json(self, json_obj: Dict[str, Any]) -> None:
@@ -1542,6 +1544,7 @@ def __init__(
15421544
config: Dict[str, Any],
15431545
base_fields: Dict[str, Any],
15441546
config_components: Dict[str, JumpStartConfigComponent],
1547+
is_hub_content=False,
15451548
):
15461549
"""Initializes a JumpStartMetadataConfig object from its json representation.
15471550
@@ -1554,6 +1557,9 @@ def __init__(
15541557
config_components (Dict[str, JumpStartConfigComponent]):
15551558
The list of components that are used to construct the resolved config.
15561559
"""
1560+
if is_hub_content:
1561+
config = walk_and_apply_json(config, camel_to_snake)
1562+
base_fields = walk_and_apply_json(base_fields, camel_to_snake)
15571563
self.base_fields = base_fields
15581564
self.config_components: Dict[str, JumpStartConfigComponent] = config_components
15591565
self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = (

tests/unit/sagemaker/jumpstart/constants.py

+63-1
Original file line number | Diff line number | Diff line change
@@ -8703,7 +8703,17 @@
87038703
"variants": {"inf2": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}},
87048704
},
87058705
},
8706-
"neuron-budget": {"inference_environment_variables": {"BUDGET": "1234"}},
8706+
"neuron-budget": {
8707+
"inference_environment_variables": [
8708+
{
8709+
"name": "SAGEMAKER_PROGRAM",
8710+
"type": "text",
8711+
"default": "inference.py",
8712+
"scope": "container",
8713+
"required_for_model_class": True,
8714+
}
8715+
],
8716+
},
87078717
"gpu-inference": {
87088718
"supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"],
87098719
"hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-inference/model/",
@@ -9816,6 +9826,58 @@
98169826
"DynamicContainerDeploymentSupported": True,
98179827
"TrainingModelPackageArtifactUri": None,
98189828
"Dependencies": [],
9829+
"InferenceConfigRankings": {
9830+
"overall": {"Description": "default", "Rankings": ["variant1"]}
9831+
},
9832+
"InferenceConfigs": {
9833+
"variant1": {
9834+
"ComponentNames": ["variant1"],
9835+
"BenchmarkMetrics": {
9836+
"ml.g5.12xlarge": [
9837+
{"Name": "latency", "Unit": "sec", "Value": "0.19", "Concurrency": "1"},
9838+
]
9839+
},
9840+
},
9841+
},
9842+
"InferenceConfigComponents": {
9843+
"variant1": {
9844+
"HostingEcrUri": "123456789012.ecr.us-west-2.amazon.com/repository",
9845+
"HostingArtifactUri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration-llama-2-7b/artifacts/variant1/v1.0.0/", # noqa: E501
9846+
"HostingScriptUri": "s3://jumpstart-monarch-test-hub-bucket/monarch-curated-hub-1714579993.88695/curated_models/meta-textgeneration-llama-2-7b/4.0.0/source-directory-tarballs/meta/inference/textgeneration/v1.2.3/sourcedir.tar.gz", # noqa: E501
9847+
"InferenceDependencies": [],
9848+
"InferenceEnvironmentVariables": [
9849+
{
9850+
"Name": "SAGEMAKER_PROGRAM",
9851+
"Type": "text",
9852+
"Default": "inference.py",
9853+
"Scope": "container",
9854+
"RequiredForModelClass": True,
9855+
}
9856+
],
9857+
"HostingAdditionalDataSources": {
9858+
"speculative_decoding": [
9859+
{
9860+
"ArtifactVersion": 1,
9861+
"ChannelName": "speculative_decoding_channel_1",
9862+
"S3DataSource": {
9863+
"CompressionType": "None",
9864+
"S3DataType": "S3Prefix",
9865+
"S3Uri": "s3://bucket/path/1",
9866+
},
9867+
},
9868+
{
9869+
"ArtifactVersion": 1,
9870+
"ChannelName": "speculative_decoding_channel_2",
9871+
"S3DataSource": {
9872+
"CompressionType": "None",
9873+
"S3DataType": "S3Prefix",
9874+
"S3Uri": "s3://bucket/path/2",
9875+
},
9876+
},
9877+
]
9878+
},
9879+
},
9880+
},
98199881
},
98209882
"meta-textgeneration-llama-2-70b": {
98219883
"Url": "https://ai.meta.com/resources/models-and-libraries/llama-downloads/",

0 commit comments

Comments (0)