Skip to content

Commit e2daefc

Browse files
committed
feat: private util for model eula key
1 parent 7f7fa94 commit e2daefc

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

src/sagemaker/jumpstart/notebook_utils.py

+26
Original file line numberDiff line numberDiff line change
@@ -469,3 +469,29 @@ def get_model_url(
469469
s3_client=sagemaker_session.s3_client,
470470
)
471471
return model_specs.url
472+
473+
474+
def _get_model_eula_key(
475+
model_id: str,
476+
model_version: str,
477+
region: str = JUMPSTART_DEFAULT_REGION_NAME,
478+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
479+
) -> str:
480+
"""Retrieve S3 key for EULA text for gated models, or None for non-gated models.
481+
482+
Args:
483+
model_id (str): The model ID for which to retrieve the EULA S3 key.
484+
model_version (str): The model version for which to retrieve the EULA S3 key.
485+
region (str): Optional. The region from which to retrieve metadata.
486+
(Default: JUMPSTART_DEFAULT_REGION_NAME)
487+
sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use
488+
to retrieve the EULA S3 key.
489+
"""
490+
491+
model_specs = accessors.JumpStartModelsAccessor.get_model_specs(
492+
region=region,
493+
model_id=model_id,
494+
version=model_version,
495+
s3_client=sagemaker_session.s3_client,
496+
)
497+
return model_specs.hosting_eula_key

tests/unit/sagemaker/jumpstart/test_notebook_utils.py

+34
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
get_header_from_base_header,
1616
get_prototype_manifest,
1717
get_prototype_model_spec,
18+
get_special_model_spec,
1819
)
1920
from sagemaker.jumpstart.notebook_utils import (
2021
_generate_jumpstart_model_versions,
22+
_get_model_eula_key,
2123
get_model_url,
2224
list_jumpstart_frameworks,
2325
list_jumpstart_models,
@@ -698,3 +700,35 @@ def test_get_model_url(
698700
region=region,
699701
s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client,
700702
)
703+
704+
705+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
706+
def test__get_model_eula_key(
707+
patched_get_model_specs: Mock,
708+
):
709+
710+
patched_get_model_specs.side_effect = get_special_model_spec
711+
712+
model_id, version = "gated_llama_neuron_model", "*"
713+
assert "fmhMetadata/eula/llamaEula.txt" == _get_model_eula_key(model_id, version)
714+
715+
model_id, version = "variant-model", "1.0.0"
716+
assert None == _get_model_eula_key(model_id, version)
717+
718+
region = "fake-region"
719+
720+
patched_get_model_specs.reset_mock()
721+
patched_get_model_specs.side_effect = lambda *largs, **kwargs: get_special_model_spec(
722+
*largs,
723+
region="us-west-2",
724+
**{key: value for key, value in kwargs.items() if key != "region"},
725+
)
726+
727+
_get_model_eula_key(model_id, version, region=region)
728+
729+
patched_get_model_specs.assert_called_once_with(
730+
model_id=model_id,
731+
version=version,
732+
region=region,
733+
s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client,
734+
)

0 commit comments

Comments
 (0)