Skip to content

Commit 2c7388c

Browse files
authored
Merge branch 'master' into master
2 parents 867743c + 2283102 commit 2c7388c

File tree

4 files changed

+60
-14
lines changed

4 files changed

+60
-14
lines changed

doc/amazon_sagemaker_model_building_pipeline.rst

+21-3
Original file line numberDiff line numberDiff line change
@@ -408,21 +408,39 @@ Example:
408408
step_args=step_args_register_model,
409409
)
410410
411-
CreateModelStep
411+
ModelStep
412412
````````````````
413413
Referable Property List:
414414

415415
- `DescribeModel`_
416416

417+
OR
418+
- `DescribeModelPackage`_
419+
417420
.. _DescribeModel: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DescribeModel.html#API_DescribeModel_ResponseSyntax
421+
.. _DescribeModelPackage: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DescribeModelPackage.html#API_DescribeModelPackage_ResponseSyntax
418422

419423
Example:
420424

425+
For model creation usecase:
426+
421427
.. code-block:: python
422428
423-
step_model = CreateModelStep(...)
424-
model_data = step_model.PrimaryContainer.ModelDataUrl
429+
create_model_step = ModelStep(
430+
name="MyModelCreationStep",
431+
step_args = model.create(...)
432+
)
433+
model_data = create_model_step.properties.PrimaryContainer.ModelDataUrl
434+
435+
For model registration usercase:
436+
437+
.. code-block:: python
425438
439+
register_model_step = ModelStep(
440+
name="MyModelRegistrationStep",
441+
step_args=model.register(...)
442+
)
443+
approval_status=register_model_step.properties.ModelApprovalStatus
426444
427445
LambdaStep
428446
`````````````

src/sagemaker/jumpstart/accessors.py

+5
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from sagemaker.jumpstart.hub.utils import (
2626
construct_hub_model_arn_from_inputs,
2727
construct_hub_model_reference_arn_from_inputs,
28+
generate_hub_arn_for_init_kwargs,
2829
)
2930
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
3031
from sagemaker.session import Session
@@ -291,6 +292,10 @@ def get_model_specs(
291292
# Users only input model id, not contentType, so first try to describe with ModelReference, then with Model
292293
if hub_arn:
293294
try:
295+
hub_arn = generate_hub_arn_for_init_kwargs(
296+
hub_name=hub_arn, region=region, session=sagemaker_session
297+
)
298+
294299
hub_model_arn = construct_hub_model_reference_arn_from_inputs(
295300
hub_arn=hub_arn, model_name=model_id, version=version
296301
)

src/sagemaker/jumpstart/estimator.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
validate_model_id_and_get_type,
4242
resolve_model_sagemaker_config_field,
4343
verify_model_region_and_return_specs,
44-
remove_env_var_from_estimator_kwargs_if_accept_eula_present,
44+
remove_env_var_from_estimator_kwargs_if_model_access_config_present,
4545
get_model_access_config,
4646
get_hub_access_config,
4747
)
@@ -616,6 +616,7 @@ def _validate_model_id_and_get_type_hook():
616616
self.tolerate_vulnerable_model = estimator_init_kwargs.tolerate_vulnerable_model
617617
self.instance_count = estimator_init_kwargs.instance_count
618618
self.region = estimator_init_kwargs.region
619+
self.environment = estimator_init_kwargs.environment
619620
self.orig_predictor_cls = None
620621
self.role = estimator_init_kwargs.role
621622
self.sagemaker_session = estimator_init_kwargs.sagemaker_session
@@ -693,7 +694,7 @@ def fit(
693694
accept the end-user license agreement (EULA) that some
694695
models require. (Default: None).
695696
"""
696-
self.model_access_config = get_model_access_config(accept_eula)
697+
self.model_access_config = get_model_access_config(accept_eula, self.environment)
697698
self.hub_access_config = get_hub_access_config(
698699
hub_content_arn=self.init_kwargs.get("model_reference_arn", None)
699700
)
@@ -713,7 +714,9 @@ def fit(
713714
config_name=self.config_name,
714715
hub_access_config=self.hub_access_config,
715716
)
716-
remove_env_var_from_estimator_kwargs_if_accept_eula_present(self.init_kwargs, accept_eula)
717+
remove_env_var_from_estimator_kwargs_if_model_access_config_present(
718+
self.init_kwargs, self.model_access_config
719+
)
717720

718721
return super(JumpStartEstimator, self).fit(**estimator_fit_kwargs.to_kwargs_dict())
719722

src/sagemaker/jumpstart/utils.py

+28-8
Original file line numberDiff line numberDiff line change
@@ -1632,17 +1632,29 @@ def get_draft_model_content_bucket(provider: Dict, region: str) -> str:
16321632
return neo_bucket
16331633

16341634

1635-
def remove_env_var_from_estimator_kwargs_if_accept_eula_present(
1636-
init_kwargs: dict, accept_eula: Optional[bool]
1635+
def remove_env_var_from_estimator_kwargs_if_model_access_config_present(
1636+
init_kwargs: dict, model_access_config: Optional[dict]
16371637
):
1638-
"""Remove env vars if access configs are used
1638+
"""Remove env vars if ModelAccessConfig is used
16391639
16401640
Args:
16411641
init_kwargs (dict): Dictionary of kwargs when Estimator is instantiated.
16421642
accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally passed in to Estimator.fit().
16431643
"""
1644-
if accept_eula is not None and init_kwargs["environment"]:
1645-
del init_kwargs["environment"][constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY]
1644+
if (
1645+
model_access_config is not None
1646+
and init_kwargs.get("environment") is not None
1647+
and init_kwargs.get("model_uri") is not None
1648+
):
1649+
if (
1650+
constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY
1651+
in init_kwargs["environment"]
1652+
):
1653+
del init_kwargs["environment"][
1654+
constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY
1655+
]
1656+
if "accept_eula" in init_kwargs["environment"]:
1657+
del init_kwargs["environment"]["accept_eula"]
16461658

16471659

16481660
def get_hub_access_config(hub_content_arn: Optional[str]):
@@ -1659,16 +1671,24 @@ def get_hub_access_config(hub_content_arn: Optional[str]):
16591671
return hub_access_config
16601672

16611673

1662-
def get_model_access_config(accept_eula: Optional[bool]):
1674+
def get_model_access_config(accept_eula: Optional[bool], environment: Optional[dict]):
16631675
"""Get access configs
16641676
16651677
Args:
16661678
accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally passed in to Estimator.fit().
16671679
"""
1680+
env_var_eula = environment.get("accept_eula") if environment else None
1681+
if env_var_eula is not None and accept_eula is not None:
1682+
raise ValueError(
1683+
"Cannot pass in both accept_eula and environment variables. "
1684+
"Please remove the environment variable and pass in the accept_eula parameter."
1685+
)
1686+
1687+
model_access_config = None
1688+
if env_var_eula is not None:
1689+
model_access_config = {"AcceptEula": env_var_eula == "true"}
16681690
if accept_eula is not None:
16691691
model_access_config = {"AcceptEula": accept_eula}
1670-
else:
1671-
model_access_config = None
16721692

16731693
return model_access_config
16741694

0 commit comments

Comments
 (0)