Skip to content

Commit 2331dec

Browse files
Captainiaakozd
andauthored
fix: update passing additional model data sources to API (aws#1472)
* feat: Added utils for extracting JS data sources (aws#1471) * added utils for accessing hosting data sources * added utils for accessing hosting data sources * removed other changes * fixed formatting issues * remove .keys() * updated JumpStartModelDataSource * fix slots * remove print * fix tests * update tests * fix: update passing additional model data sources to API * format * format * format * format and address comments * format * format * format --------- Co-authored-by: Adam Kozdrowicz <[email protected]>
1 parent 9a410e5 commit 2331dec

File tree

6 files changed

+134
-21
lines changed

6 files changed

+134
-21
lines changed

src/sagemaker/jumpstart/factory/model.py

+39-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757

5858
from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig
5959
from sagemaker.session import Session
60-
from sagemaker.utils import name_from_base, format_tags, Tags
60+
from sagemaker.utils import camel_case_to_pascal_case, name_from_base, format_tags, Tags
6161
from sagemaker.workflow.entities import PipelineVariable
6262
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
6363
from sagemaker import resource_requirements
@@ -615,6 +615,40 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta
615615
return kwargs
616616

617617

618+
def _add_additional_model_data_sources_to_kwargs(
619+
kwargs: JumpStartModelInitKwargs,
620+
) -> JumpStartModelInitKwargs:
621+
"""Sets default additional model data sources to init kwargs"""
622+
623+
specs = verify_model_region_and_return_specs(
624+
model_id=kwargs.model_id,
625+
version=kwargs.model_version,
626+
scope=JumpStartScriptScope.INFERENCE,
627+
region=kwargs.region,
628+
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
629+
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
630+
sagemaker_session=kwargs.sagemaker_session,
631+
model_type=kwargs.model_type,
632+
config_name=kwargs.config_name,
633+
)
634+
635+
additional_data_sources = specs.get_additional_s3_data_sources()
636+
api_shape_additional_model_data_sources = (
637+
[
638+
camel_case_to_pascal_case(data_source.to_json())
639+
for data_source in additional_data_sources
640+
]
641+
if specs.get_additional_s3_data_sources()
642+
else None
643+
)
644+
645+
kwargs.additional_model_data_sources = (
646+
kwargs.additional_model_data_sources or api_shape_additional_model_data_sources
647+
)
648+
649+
return kwargs
650+
651+
618652
def _add_config_name_to_deploy_kwargs(
619653
kwargs: JumpStartModelDeployKwargs, training_config_name: Optional[str] = None
620654
) -> JumpStartModelInitKwargs:
@@ -861,6 +895,7 @@ def get_init_kwargs(
861895
disable_instance_type_logging: bool = False,
862896
resources: Optional[ResourceRequirements] = None,
863897
config_name: Optional[str] = None,
898+
additional_model_data_sources: Optional[Dict[str, Any]] = None,
864899
) -> JumpStartModelInitKwargs:
865900
"""Returns kwargs required to instantiate `sagemaker.estimator.Model` object."""
866901

@@ -893,6 +928,7 @@ def get_init_kwargs(
893928
training_instance_type=training_instance_type,
894929
resources=resources,
895930
config_name=config_name,
931+
additional_model_data_sources=additional_model_data_sources,
896932
)
897933

898934
model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs)
@@ -925,4 +961,6 @@ def get_init_kwargs(
925961

926962
model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs)
927963

964+
model_init_kwargs = _add_additional_model_data_sources_to_kwargs(kwargs=model_init_kwargs)
965+
928966
return model_init_kwargs

src/sagemaker/jumpstart/model.py

+7-10
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def __init__(
102102
model_package_arn: Optional[str] = None,
103103
resources: Optional[ResourceRequirements] = None,
104104
config_name: Optional[str] = None,
105+
additional_model_data_sources: Optional[Dict[str, Any]] = None,
105106
):
106107
"""Initializes a ``JumpStartModel``.
107108
@@ -287,8 +288,10 @@ def __init__(
287288
for a model to be deployed to an endpoint.
288289
Only EndpointType.INFERENCE_COMPONENT_BASED supports this feature.
289290
(Default: None).
290-
config_name (Optional[str]): The name of the JumpStartConfig that can be
291-
optionally applied to the model and override corresponding fields.
291+
config_name (Optional[str]): The name of the JumpStart config that can be
292+
optionally applied to the model.
293+
additional_model_data_sources (Optional[Dict[str, Any]]): Additional location
294+
of SageMaker model data (default: None).
292295
Raises:
293296
ValueError: If the model ID is not recognized by JumpStart.
294297
"""
@@ -339,6 +342,7 @@ def _validate_model_id_and_type():
339342
model_package_arn=model_package_arn,
340343
resources=resources,
341344
config_name=config_name,
345+
additional_model_data_sources=additional_model_data_sources,
342346
)
343347

344348
self.orig_predictor_cls = predictor_cls
@@ -352,6 +356,7 @@ def _validate_model_id_and_type():
352356
self.region = model_init_kwargs.region
353357
self.sagemaker_session = model_init_kwargs.sagemaker_session
354358
self.config_name = model_init_kwargs.config_name
359+
self.additional_model_data_sources = model_init_kwargs.additional_model_data_sources
355360

356361
if self.model_type == JumpStartModelType.PROPRIETARY:
357362
self.log_subscription_warning()
@@ -369,14 +374,6 @@ def _validate_model_id_and_type():
369374
model_type=self.model_type,
370375
)
371376

372-
self.additional_model_data_sources = (
373-
self._metadata_configs.get(self.config_name).resolved_config.get(
374-
"hosting_additional_data_sources"
375-
)
376-
if self._metadata_configs.get(self.config_name)
377-
else None
378-
)
379-
380377
def log_subscription_warning(self) -> None:
381378
"""Log message prompting the customer to subscribe to the proprietary model."""
382379
subscription_link = verify_model_region_and_return_specs(

src/sagemaker/jumpstart/types.py

+17-9
Original file line numberDiff line numberDiff line change
@@ -844,13 +844,15 @@ def to_json(self) -> Dict[str, Any]:
844844
cur_val = getattr(self, att)
845845
if issubclass(type(cur_val), JumpStartDataHolderType):
846846
json_obj[att] = cur_val.to_json()
847-
else:
847+
elif cur_val:
848848
json_obj[att] = cur_val
849849
return json_obj
850850

851851

852852
class AdditionalModelDataSource(JumpStartDataHolderType):
853-
"""Data class of additional model data source mirrors Hosting API."""
853+
"""Data class of additional model data source mirrors CreateModel API."""
854+
855+
SERIALIZATION_EXCLUSION_SET: Set[str] = set()
854856

855857
__slots__ = ["channel_name", "s3_data_source"]
856858

@@ -871,23 +873,26 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
871873
self.channel_name: str = json_obj["channel_name"]
872874
self.s3_data_source: S3DataSource = S3DataSource(json_obj["s3_data_source"])
873875

874-
def to_json(self) -> Dict[str, Any]:
876+
def to_json(self, exclude_keys=True) -> Dict[str, Any]:
875877
"""Returns json representation of AdditionalModelDataSource object."""
876878
json_obj = {}
877879
for att in self.__slots__:
878880
if hasattr(self, att):
879-
cur_val = getattr(self, att)
880-
if issubclass(type(cur_val), JumpStartDataHolderType):
881-
json_obj[att] = cur_val.to_json()
882-
else:
883-
json_obj[att] = cur_val
881+
if exclude_keys and att not in self.SERIALIZATION_EXCLUSION_SET or not exclude_keys:
882+
cur_val = getattr(self, att)
883+
if issubclass(type(cur_val), JumpStartDataHolderType):
884+
json_obj[att] = cur_val.to_json()
885+
else:
886+
json_obj[att] = cur_val
884887
return json_obj
885888

886889

887890
class JumpStartModelDataSource(AdditionalModelDataSource):
888891
"""Data class JumpStart additional model data source."""
889892

890-
__slots__ = ["artifact_version"] + AdditionalModelDataSource.__slots__
893+
SERIALIZATION_EXCLUSION_SET = {"artifact_version"}
894+
895+
__slots__ = list(SERIALIZATION_EXCLUSION_SET) + AdditionalModelDataSource.__slots__
891896

892897
def from_json(self, json_obj: Dict[str, Any]) -> None:
893898
"""Sets fields in object based on json.
@@ -1761,6 +1766,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
17611766
"training_instance_type",
17621767
"resources",
17631768
"config_name",
1769+
"additional_model_data_sources",
17641770
]
17651771

17661772
SERIALIZATION_EXCLUSION_SET = {
@@ -1806,6 +1812,7 @@ def __init__(
18061812
training_instance_type: Optional[str] = None,
18071813
resources: Optional[ResourceRequirements] = None,
18081814
config_name: Optional[str] = None,
1815+
additional_model_data_sources: Optional[Dict[str, Any]] = None,
18091816
) -> None:
18101817
"""Instantiates JumpStartModelInitKwargs object."""
18111818

@@ -1837,6 +1844,7 @@ def __init__(
18371844
self.training_instance_type = training_instance_type
18381845
self.resources = resources
18391846
self.config_name = config_name
1847+
self.additional_model_data_sources = additional_model_data_sources
18401848

18411849

18421850
class JumpStartModelDeployKwargs(JumpStartKwargs):

src/sagemaker/utils.py

+30
Original file line numberDiff line numberDiff line change
@@ -1798,3 +1798,33 @@ def extract_instance_rate_per_hour(price_data: Dict[str, Any]) -> Optional[Dict[
17981798
"name": "Instance Rate",
17991799
}
18001800
return None
1801+
1802+
1803+
def camel_case_to_pascal_case(data: Dict[str, Any]) -> Dict[str, Any]:
1804+
"""Iteratively updates a dictionary to convert all keys from snake_case to PascalCase.
1805+
1806+
Args:
1807+
data (dict): The dictionary to be updated.
1808+
1809+
Returns:
1810+
dict: The updated dictionary with keys in PascalCase.
1811+
"""
1812+
result = {}
1813+
1814+
def convert_key(key):
1815+
"""Converts a snake_case key to PascalCase."""
1816+
return "".join(part.capitalize() for part in key.split("_"))
1817+
1818+
def convert_value(value):
1819+
"""Recursively processes the value of a key-value pair."""
1820+
if isinstance(value, dict):
1821+
return camel_case_to_pascal_case(value)
1822+
if isinstance(value, list):
1823+
return [convert_value(item) for item in value]
1824+
1825+
return value
1826+
1827+
for key, value in data.items():
1828+
result[convert_key(key)] = convert_value(value)
1829+
1830+
return result

tests/unit/sagemaker/jumpstart/model/test_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -759,7 +759,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self):
759759
Please add the new argument to the skip set below,
760760
and reach out to JumpStart team."""
761761

762-
init_args_to_skip: Set[str] = set(["additional_model_data_sources"])
762+
init_args_to_skip: Set[str] = set([])
763763
deploy_args_to_skip: Set[str] = set(["kwargs"])
764764

765765
parent_class_init = Model.__init__

tests/unit/test_utils.py

+40
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from sagemaker.experiments._run_context import _RunContext
3535
from sagemaker.session_settings import SessionSettings
3636
from sagemaker.utils import (
37+
camel_case_to_pascal_case,
3738
deep_override_dict,
3839
flatten_dict,
3940
get_instance_type_family,
@@ -2055,3 +2056,42 @@ def test_resolve_routing_config(routing_config, expected):
20552056

20562057
def test_resolve_routing_config_ex():
20572058
pytest.raises(ValueError, lambda: _resolve_routing_config({"RoutingStrategy": "Invalid"}))
2059+
2060+
2061+
class TestConvertToPascalCase(TestCase):
2062+
def test_simple_dict(self):
2063+
input_dict = {"first_name": "John", "last_name": "Doe"}
2064+
expected_output = {"FirstName": "John", "LastName": "Doe"}
2065+
self.assertEqual(camel_case_to_pascal_case(input_dict), expected_output)
2066+
2067+
def camel_case_to_pascal_case_nested(self):
2068+
input_dict = {
2069+
"model_name": "my-model",
2070+
"primary_container": {
2071+
"image": "my-docker-image:latest",
2072+
"model_data_url": "s3://my-bucket/model.tar.gz",
2073+
"environment": {"env_var_1": "value1", "env_var_2": "value2"},
2074+
},
2075+
"execution_role_arn": "arn:aws:iam::123456789012:role/my-sagemaker-role",
2076+
"tags": [
2077+
{"key": "project", "value": "my-project"},
2078+
{"key": "environment", "value": "development"},
2079+
],
2080+
}
2081+
expected_output = {
2082+
"ModelName": "my-model",
2083+
"PrimaryContainer": {
2084+
"Image": "my-docker-image:latest",
2085+
"ModelDataUrl": "s3://my-bucket/model.tar.gz",
2086+
"Environment": {"EnvVar1": "value1", "EnvVar2": "value2"},
2087+
},
2088+
"ExecutionRoleArn": "arn:aws:iam::123456789012:role/my-sagemaker-role",
2089+
"Tags": [
2090+
{"Key": "project", "Value": "my-project"},
2091+
{"Key": "environment", "Value": "development"},
2092+
],
2093+
}
2094+
self.assertEqual(camel_case_to_pascal_case(input_dict), expected_output)
2095+
2096+
def test_empty_input(self):
2097+
self.assertEqual(camel_case_to_pascal_case({}), {})

0 commit comments

Comments
 (0)