Skip to content

Commit c4529e3

Browse files
authored
feat: additional hosting model data source parsing (aws#1467)
* feat: Additional Model Data source parsing * address comments * address comments * format
1 parent 73bf439 commit c4529e3

File tree

6 files changed

+372
-0
lines changed

6 files changed

+372
-0
lines changed

src/sagemaker/jumpstart/types.py

+235
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,235 @@ def _get_regional_property(
743743
return alias_value
744744

745745

746+
class ModelAccessConfig(JumpStartDataHolderType):
    """Holds the model access settings of a data source, mirroring the CreateModel API."""

    __slots__ = ["accept_eula"]

    def __init__(self, spec: Dict[str, Any]):
        """Builds a ModelAccessConfig from its dictionary form.

        Args:
            spec (Dict[str, Any]): Dictionary representation of the model access config.
        """
        self.from_json(spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Populates this object's fields from a dictionary.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of the model access
                config. The ``accept_eula`` key is required.
        """
        self.accept_eula: bool = json_obj["accept_eula"]

    def to_json(self) -> Dict[str, Any]:
        """Returns a dictionary holding every slot attribute that is currently set."""
        serialized = {}
        for slot_name in self.__slots__:
            # Slots that were never assigned are left out rather than emitted as None.
            if hasattr(self, slot_name):
                serialized[slot_name] = getattr(self, slot_name)
        return serialized
771+
772+
773+
class HubAccessConfig(JumpStartDataHolderType):
    """Data class of hub access config that mirrors CreateModel API."""

    __slots__ = ["hub_content_arn"]

    def __init__(self, spec: Dict[str, Any]):
        """Initializes a HubAccessConfig object.

        Args:
            spec (Dict[str, Any]): Dictionary representation of the hub access config.
        """
        self.from_json(spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of the hub access
                config. The ``hub_content_arn`` key is required.

        Raises:
            KeyError: If ``hub_content_arn`` is missing from ``json_obj``.
        """
        # The ARN lives under its own key; "accept_eula" belongs to ModelAccessConfig.
        self.hub_content_arn: str = json_obj["hub_content_arn"]

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of HubAccessConfig object."""
        json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
        return json_obj
798+
799+
800+
class S3DataSource(JumpStartDataHolderType):
    """Describes an S3-backed data source in the shape used by the CreateModel API."""

    __slots__ = [
        "compression_type",
        "s3_data_type",
        "s3_uri",
        "model_access_config",
        "hub_access_config",
    ]

    def __init__(self, spec: Dict[str, Any]):
        """Builds an S3DataSource from its dictionary form.

        Args:
            spec (Dict[str, Any]): Dictionary representation of the S3 data source.
        """
        self.from_json(spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Populates this object's fields from a dictionary.

        ``compression_type``, ``s3_data_type`` and ``s3_uri`` are required keys.
        ``model_access_config`` and ``hub_access_config`` are optional; each is
        stored as ``None`` when its key is absent or holds a falsy value.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of the S3 data source.
        """
        self.compression_type: str = json_obj["compression_type"]
        self.s3_data_type: str = json_obj["s3_data_type"]
        self.s3_uri: str = json_obj["s3_uri"]

        if json_obj.get("model_access_config"):
            model_access = ModelAccessConfig(json_obj["model_access_config"])
        else:
            model_access = None
        self.model_access_config: Optional[ModelAccessConfig] = model_access

        if json_obj.get("hub_access_config"):
            hub_access = HubAccessConfig(json_obj["hub_access_config"])
        else:
            hub_access = None
        self.hub_access_config: Optional[HubAccessConfig] = hub_access

    def to_json(self) -> Dict[str, Any]:
        """Returns a dictionary of the set slot attributes, serializing nested holders."""
        serialized = {}
        for slot_name in self.__slots__:
            if not hasattr(self, slot_name):
                continue
            value = getattr(self, slot_name)
            # Nested data holders are converted recursively; plain values pass through.
            if isinstance(value, JumpStartDataHolderType):
                serialized[slot_name] = value.to_json()
            else:
                serialized[slot_name] = value
        return serialized
850+
851+
852+
class AdditionalModelDataSource(JumpStartDataHolderType):
    """Pairs a channel name with its S3 data source, mirroring the Hosting API."""

    __slots__ = ["channel_name", "s3_data_source"]

    def __init__(self, spec: Dict[str, Any]):
        """Builds an AdditionalModelDataSource from its dictionary form.

        Args:
            spec (Dict[str, Any]): Dictionary representation of the data source.
        """
        self.from_json(spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Populates this object's fields from a dictionary.

        Both ``channel_name`` and ``s3_data_source`` are required keys.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of the data source.
        """
        self.channel_name: str = json_obj["channel_name"]
        s3_spec = json_obj["s3_data_source"]
        self.s3_data_source: S3DataSource = S3DataSource(s3_spec)

    def to_json(self) -> Dict[str, Any]:
        """Returns a dictionary of the set slot attributes, serializing nested holders."""
        serialized = {}
        for slot_name in self.__slots__:
            if not hasattr(self, slot_name):
                continue
            value = getattr(self, slot_name)
            # Nested data holders are converted recursively; plain values pass through.
            if isinstance(value, JumpStartDataHolderType):
                serialized[slot_name] = value.to_json()
            else:
                serialized[slot_name] = value
        return serialized
885+
886+
887+
class JumpStartModelDataSource(JumpStartDataHolderType):
    """Wraps an additional model data source together with its artifact version."""

    __slots__ = ["version", "additional_model_data_source"]

    def __init__(self, spec: Dict[str, Any]):
        """Builds a JumpStartModelDataSource from its dictionary form.

        Args:
            spec (Dict[str, Any]): Dictionary representation of the data source.
        """
        self.from_json(spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Populates this object's fields from a dictionary.

        The ``artifact_version`` key is required and is stored as ``version``.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of the data source.
        """
        self.version: str = json_obj["artifact_version"]
        # The whole dictionary is handed on: the nested holder reads its own keys
        # from the same level that carries "artifact_version".
        nested_source = AdditionalModelDataSource(json_obj)
        self.additional_model_data_source: AdditionalModelDataSource = nested_source

    def to_json(self) -> Dict[str, Any]:
        """Returns a dictionary of the set slot attributes, serializing nested holders."""
        serialized = {}
        for slot_name in self.__slots__:
            if not hasattr(self, slot_name):
                continue
            value = getattr(self, slot_name)
            # Nested data holders are converted recursively; plain values pass through.
            if isinstance(value, JumpStartDataHolderType):
                serialized[slot_name] = value.to_json()
            else:
                serialized[slot_name] = value
        return serialized
922+
923+
924+
class JumpStartAdditionalDataSources(JumpStartDataHolderType):
    """Groups the additional data sources of a model by purpose."""

    __slots__ = ["speculative_decoding", "scripts"]

    def __init__(self, spec: Dict[str, Any]):
        """Builds a JumpStartAdditionalDataSources from its dictionary form.

        Args:
            spec (Dict[str, Any]): Dictionary representation of the data sources.
        """
        self.from_json(spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Populates this object's fields from a dictionary.

        ``speculative_decoding`` and ``scripts`` are both optional; each is stored
        as ``None`` when its key is absent or holds a falsy value (such as an
        empty list), otherwise as a list of JumpStartModelDataSource objects.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of the data sources.
        """

        def _parse_sources(key: str) -> Optional[List[JumpStartModelDataSource]]:
            # Missing and empty entries both collapse to None.
            if not json_obj.get(key):
                return None
            return [JumpStartModelDataSource(data_source) for data_source in json_obj[key]]

        self.speculative_decoding: Optional[List[JumpStartModelDataSource]] = _parse_sources(
            "speculative_decoding"
        )
        self.scripts: Optional[List[JumpStartModelDataSource]] = _parse_sources("scripts")

    def to_json(self) -> Dict[str, Any]:
        """Returns a dictionary of the set slot attributes, serializing list members."""
        serialized = {}
        for slot_name in self.__slots__:
            if not hasattr(self, slot_name):
                continue
            value = getattr(self, slot_name)
            if isinstance(value, list):
                # Data holders inside the list are converted; other items pass through.
                serialized[slot_name] = [
                    item.to_json() if isinstance(item, JumpStartDataHolderType) else item
                    for item in value
                ]
            else:
                serialized[slot_name] = value
        return serialized
973+
974+
746975
class JumpStartBenchmarkStat(JumpStartDataHolderType):
747976
"""Data class JumpStart benchmark stat."""
748977

@@ -857,6 +1086,7 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType):
8571086
"default_payloads",
8581087
"gated_bucket",
8591088
"model_subscription_link",
1089+
"hosting_additional_data_sources",
8601090
]
8611091

8621092
def __init__(self, fields: Dict[str, Any]):
@@ -962,6 +1192,11 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
9621192
if json_obj.get("hosting_instance_type_variants")
9631193
else None
9641194
)
1195+
self.hosting_additional_data_sources: Optional[JumpStartAdditionalDataSources] = (
1196+
JumpStartAdditionalDataSources(json_obj["hosting_additional_data_sources"])
1197+
if json_obj.get("hosting_additional_data_sources")
1198+
else None
1199+
)
9651200

9661201
if self.training_supported:
9671202
self.training_ecr_specs: Optional[JumpStartECRSpecs] = (

src/sagemaker/utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -1690,6 +1690,7 @@ def deep_override_dict(
16901690
skip_keys = []
16911691

16921692
flattened_dict1 = flatten_dict(dict1)
1693+
flattened_dict1 = {key: value for key, value in flattened_dict1.items() if value is not None}
16931694
flattened_dict2 = flatten_dict(
16941695
{key: value for key, value in dict2.items() if key not in skip_keys}
16951696
)

tests/unit/sagemaker/jumpstart/constants.py

+38
Original file line numberDiff line numberDiff line change
@@ -7515,6 +7515,7 @@
75157515
"training_config_components": None,
75167516
"inference_config_rankings": None,
75177517
"training_config_rankings": None,
7518+
"hosting_additional_data_sources": None,
75187519
}
75197520

75207521
BASE_HEADER = {
@@ -7700,6 +7701,14 @@
77007701
},
77017702
"component_names": ["gpu-inference-model-package"],
77027703
},
7704+
"gpu-accelerated": {
7705+
"benchmark_metrics": {
7706+
"ml.p3.2xlarge": [
7707+
{"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1}
7708+
]
7709+
},
7710+
"component_names": ["gpu-accelerated"],
7711+
},
77037712
},
77047713
"inference_config_components": {
77057714
"neuron-base": {
@@ -7765,6 +7774,34 @@
77657774
},
77667775
},
77677776
},
7777+
"gpu-accelerated": {
7778+
"hosting_instance_type_variants": {
7779+
"regional_aliases": {
7780+
"us-west-2": {
7781+
"gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/"
7782+
"pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04"
7783+
}
7784+
},
7785+
"variants": {
7786+
"p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}},
7787+
"p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}},
7788+
},
7789+
},
7790+
"hosting_additional_data_sources": {
7791+
"speculative_decoding": [
7792+
{
7793+
"channel_name": "draft_model_name",
7794+
"artifact_version": "1.2.1",
7795+
"s3_data_source": {
7796+
"compression_type": "None",
7797+
"model_access_config": {"accept_eula": False},
7798+
"s3_data_type": "S3Prefix",
7799+
"s3_uri": "key/to/draft/model/artifact/",
7800+
},
7801+
}
7802+
],
7803+
},
7804+
},
77687805
},
77697806
}
77707807

@@ -7907,6 +7944,7 @@
79077944
"neuron-inference-budget",
79087945
"gpu-inference",
79097946
"gpu-inference-budget",
7947+
"gpu-accelerated",
79107948
],
79117949
},
79127950
"performance": {

0 commit comments

Comments
 (0)