Skip to content

Commit 73bf439

Browse files
makungaj1Jonathan Makunga
and
Jonathan Makunga
authored
feat: Model class to support AdditionalModelDataSources (aws#1469)
* Add support for AdditionalModelDataSources * Resolve PR comments * Resolve PR comments * Resolve PR comments * fix unit tests * Resolve PR comments --------- Co-authored-by: Jonathan Makunga <[email protected]>
1 parent 2c3d606 commit 73bf439

File tree

5 files changed

+21
-1
lines changed

5 files changed

+21
-1
lines changed

src/sagemaker/jumpstart/model.py

+8
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,14 @@ def _validate_model_id_and_type():
369369
model_type=self.model_type,
370370
)
371371

372+
self.additional_model_data_sources = (
373+
self._metadata_configs.get(self.config_name).resolved_config.get(
374+
"hosting_additional_data_sources"
375+
)
376+
if self._metadata_configs.get(self.config_name)
377+
else None
378+
)
379+
372380
def log_subscription_warning(self) -> None:
373381
"""Log message prompting the customer to subscribe to the proprietary model."""
374382
subscription_link = verify_model_region_and_return_specs(

src/sagemaker/model.py

+5
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def __init__(
160160
dependencies: Optional[List[str]] = None,
161161
git_config: Optional[Dict[str, str]] = None,
162162
resources: Optional[ResourceRequirements] = None,
163+
additional_model_data_sources: Optional[Dict[str, Any]] = None,
163164
):
164165
"""Initialize an SageMaker ``Model``.
165166
@@ -323,9 +324,12 @@ def __init__(
323324
for a model to be deployed to an endpoint. Only
324325
EndpointType.INFERENCE_COMPONENT_BASED supports this feature.
325326
(Default: None).
327+
additional_model_data_sources (Optional[Dict[str, Any]]): Additional location
328+
of SageMaker model data (default: None).
326329
327330
"""
328331
self.model_data = model_data
332+
self.additional_model_data_sources = additional_model_data_sources
329333
self.image_uri = image_uri
330334
self.predictor_cls = predictor_cls
331335
self.name = name
@@ -671,6 +675,7 @@ def prepare_container_def(
671675
accept_eula=(
672676
accept_eula if accept_eula is not None else getattr(self, "accept_eula", None)
673677
),
678+
additional_model_data_sources=self.additional_model_data_sources,
674679
)
675680

676681
def is_repack(self) -> bool:

src/sagemaker/session.py

+6
Original file line numberDiff line numberDiff line change
@@ -7137,6 +7137,7 @@ def container_def(
71377137
container_mode=None,
71387138
image_config=None,
71397139
accept_eula=None,
7140+
additional_model_data_sources=None,
71407141
):
71417142
"""Create a definition for executing a container as part of a SageMaker model.
71427143
@@ -7159,6 +7160,8 @@ def container_def(
71597160
The `accept_eula` value must be explicitly defined as `True` in order to
71607161
accept the end-user license agreement (EULA) that some
71617162
models require. (Default: None).
7163+
additional_model_data_sources (PipelineVariable or dict): Additional location
7164+
of SageMaker model data (default: None).
71627165
71637166
Returns:
71647167
dict[str, str]: A complete container definition object usable with the CreateModel API if
@@ -7168,6 +7171,9 @@ def container_def(
71687171
env = {}
71697172
c_def = {"Image": image_uri, "Environment": env}
71707173

7174+
if additional_model_data_sources:
7175+
c_def["AdditionalModelDataSources"] = additional_model_data_sources
7176+
71717177
if isinstance(model_data_url, str) and (
71727178
not (model_data_url.startswith("s3://") and model_data_url.endswith("tar.gz"))
71737179
or accept_eula is None

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

+1
Original file line numberDiff line numberDiff line change
@@ -1168,6 +1168,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self):
11681168
"inference_config_name"
11691169
} == model_class_init_args - {
11701170
"model_data",
1171+
"additional_model_data_sources",
11711172
"self",
11721173
"name",
11731174
"resources",

tests/unit/sagemaker/jumpstart/model/test_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -759,7 +759,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self):
759759
Please add the new argument to the skip set below,
760760
and reach out to JumpStart team."""
761761

762-
init_args_to_skip: Set[str] = set([])
762+
init_args_to_skip: Set[str] = set(["additional_model_data_sources"])
763763
deploy_args_to_skip: Set[str] = set(["kwargs"])
764764

765765
parent_class_init = Model.__init__

0 commit comments

Comments
 (0)