Skip to content

Commit 5db63f5

Browse files
committed
feat: Keynote1 - Gated Models (aws#1246)
Co-authored-by: evakravi <[email protected]> fix: jumpstart unit-test (aws#1265)
1 parent 8462f1a commit 5db63f5

File tree

8 files changed

+436
-7
lines changed

8 files changed

+436
-7
lines changed

src/sagemaker/jumpstart/factory/model.py

+2
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,7 @@ def get_deploy_kwargs(
489489
tolerate_vulnerable_model: Optional[bool] = None,
490490
tolerate_deprecated_model: Optional[bool] = None,
491491
sagemaker_session: Optional[Session] = None,
492+
accept_eula: Optional[bool] = None,
492493
) -> JumpStartModelDeployKwargs:
493494
"""Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object."""
494495

@@ -516,6 +517,7 @@ def get_deploy_kwargs(
516517
tolerate_deprecated_model=tolerate_deprecated_model,
517518
tolerate_vulnerable_model=tolerate_vulnerable_model,
518519
sagemaker_session=sagemaker_session,
520+
accept_eula=accept_eula,
519521
)
520522

521523
deploy_kwargs = _add_sagemaker_session_to_kwargs(kwargs=deploy_kwargs)

src/sagemaker/jumpstart/model.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,7 @@ def deploy(
448448
container_startup_health_check_timeout: Optional[int] = None,
449449
inference_recommendation_id: Optional[str] = None,
450450
explainer_config: Optional[ExplainerConfig] = None,
451+
accept_eula: Optional[bool] = None,
451452
) -> PredictorBase:
452453
"""Creates endpoint by calling base ``Model`` class `deploy` method.
453454
@@ -526,7 +527,11 @@ def deploy(
526527
(Default: None).
527528
explainer_config (Optional[sagemaker.explainer.ExplainerConfig]): Specifies online
528529
explainability configuration for use with Amazon SageMaker Clarify. (Default: None).
529-
530+
accept_eula (bool): For models that require a Model Access Config, specify True or
531+
False to indicate whether model terms of use have been accepted.
532+
The `accept_eula` value must be explicitly defined as `True` in order to
533+
accept the end-user license agreement (EULA) that some
534+
models require. (Default: None).
530535
"""
531536

532537
deploy_kwargs = get_deploy_kwargs(
@@ -553,6 +558,7 @@ def deploy(
553558
inference_recommendation_id=inference_recommendation_id,
554559
explainer_config=explainer_config,
555560
sagemaker_session=self.sagemaker_session,
561+
accept_eula=accept_eula,
556562
)
557563

558564
predictor = super(JumpStartModel, self).deploy(**deploy_kwargs.to_kwargs_dict())

src/sagemaker/jumpstart/types.py

+3
Original file line numberDiff line numberDiff line change
@@ -1120,6 +1120,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
11201120
"tolerate_deprecated_model",
11211121
"sagemaker_session",
11221122
"training_instance_type",
1123+
"accept_eula",
11231124
]
11241125

11251126
SERIALIZATION_EXCLUSION_SET = {
@@ -1158,6 +1159,7 @@ def __init__(
11581159
tolerate_vulnerable_model: Optional[bool] = None,
11591160
sagemaker_session: Optional[Session] = None,
11601161
training_instance_type: Optional[str] = None,
1162+
accept_eula: Optional[bool] = None,
11611163
) -> None:
11621164
"""Instantiates JumpStartModelDeployKwargs object."""
11631165

@@ -1185,6 +1187,7 @@ def __init__(
11851187
self.tolerate_deprecated_model = tolerate_deprecated_model
11861188
self.sagemaker_session = sagemaker_session
11871189
self.training_instance_type = training_instance_type
1190+
self.accept_eula = accept_eula
11881191

11891192

11901193
class JumpStartEstimatorInitKwargs(JumpStartKwargs):

src/sagemaker/model.py

+10
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,7 @@ def __init__(
379379
self.repacked_model_data = None
380380
self.content_types = None
381381
self.response_types = None
382+
self.accept_eula = None
382383

383384
@runnable_by_pipeline
384385
def register(
@@ -634,6 +635,7 @@ def prepare_container_def(
634635
self.repacked_model_data or self.model_data,
635636
deploy_env,
636637
image_config=self.image_config,
638+
accept_eula=getattr(self, "accept_eula", None),
637639
)
638640

639641
def is_repack(self) -> bool:
@@ -1260,6 +1262,7 @@ def deploy(
12601262
container_startup_health_check_timeout=None,
12611263
inference_recommendation_id=None,
12621264
explainer_config=None,
1265+
accept_eula: Optional[bool] = None,
12631266
**kwargs,
12641267
):
12651268
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
@@ -1342,6 +1345,11 @@ def deploy(
13421345
a list of ``RealtimeInferenceRecommendations`` within ``DeploymentRecommendation``
13431346
explainer_config (sagemaker.explainer.ExplainerConfig): Specifies online explainability
13441347
configuration for use with Amazon SageMaker Clarify. Default: None.
1348+
accept_eula (bool): For models that require a Model Access Config, specify True or
1349+
False to indicate whether model terms of use have been accepted.
1350+
The `accept_eula` value must be explicitly defined as `True` in order to
1351+
accept the end-user license agreement (EULA) that some
1352+
models require. (Default: None).
13451353
Raises:
13461354
ValueError: If arguments combination check failed in these circumstances:
13471355
- If no role is specified or
@@ -1355,6 +1363,8 @@ def deploy(
13551363
``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls``
13561364
is not None. Otherwise, return None.
13571365
"""
1366+
self.accept_eula = accept_eula
1367+
13581368
removed_kwargs("update_endpoint", kwargs)
13591369

13601370
self._init_sagemaker_session_if_does_not_exist(instance_type)

src/sagemaker/session.py

+35-4
Original file line numberDiff line numberDiff line change
@@ -6128,7 +6128,14 @@ def update_args(args: Dict[str, Any], **kwargs):
61286128
args.update({key: value})
61296129

61306130

6131-
def container_def(image_uri, model_data_url=None, env=None, container_mode=None, image_config=None):
6131+
def container_def(
6132+
image_uri,
6133+
model_data_url=None,
6134+
env=None,
6135+
container_mode=None,
6136+
image_config=None,
6137+
accept_eula=None,
6138+
):
61326139
"""Create a definition for executing a container as part of a SageMaker model.
61336140
61346141
Args:
@@ -6145,6 +6152,11 @@ def container_def(image_uri, model_data_url=None, env=None, container_mode=None,
61456152
image_config (dict[str, str]): Specifies whether the image of model container is pulled
61466153
from ECR, or private registry in your VPC. By default it is set to pull model
61476154
container image from ECR. (default: None).
6155+
accept_eula (bool): For models that require a Model Access Config, specify True or
6156+
False to indicate whether model terms of use have been accepted.
6157+
The `accept_eula` value must be explicitly defined as `True` in order to
6158+
accept the end-user license agreement (EULA) that some
6159+
models require. (Default: None).
61486160
61496161
Returns:
61506162
dict[str, str]: A complete container definition object usable with the CreateModel API if
@@ -6154,9 +6166,28 @@ def container_def(image_uri, model_data_url=None, env=None, container_mode=None,
61546166
env = {}
61556167
c_def = {"Image": image_uri, "Environment": env}
61566168

6157-
if isinstance(model_data_url, dict):
6158-
c_def["ModelDataSource"] = model_data_url
6159-
elif model_data_url:
6169+
if isinstance(model_data_url, str) and (
6170+
not (model_data_url.startswith("s3://") and model_data_url.endswith("tar.gz"))
6171+
or accept_eula is None
6172+
):
6173+
c_def["ModelDataUrl"] = model_data_url
6174+
6175+
elif isinstance(model_data_url, (dict, str)):
6176+
if isinstance(model_data_url, dict):
6177+
c_def["ModelDataSource"] = model_data_url
6178+
else:
6179+
c_def["ModelDataSource"] = {
6180+
"S3DataSource": {
6181+
"S3Uri": model_data_url,
6182+
"S3DataType": "S3Object",
6183+
"CompressionType": "Gzip",
6184+
}
6185+
}
6186+
if accept_eula is not None:
6187+
c_def["ModelDataSource"]["S3DataSource"]["ModelAccessConfig"] = {
6188+
"AcceptEula": accept_eula
6189+
}
6190+
elif model_data_url is not None:
61606191
c_def["ModelDataUrl"] = model_data_url
61616192

61626193
if container_mode:

0 commit comments

Comments
 (0)