|
15 | 15 | get_header_from_base_header,
|
16 | 16 | get_prototype_manifest,
|
17 | 17 | get_prototype_model_spec,
|
| 18 | + get_special_model_spec, |
18 | 19 | )
|
19 | 20 | from sagemaker.jumpstart.notebook_utils import (
|
20 | 21 | _generate_jumpstart_model_versions,
|
| 22 | + _get_model_eula_key, |
21 | 23 | get_model_url,
|
22 | 24 | list_jumpstart_frameworks,
|
23 | 25 | list_jumpstart_models,
|
@@ -698,3 +700,35 @@ def test_get_model_url(
|
698 | 700 | region=region,
|
699 | 701 | s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client,
|
700 | 702 | )
|
| 703 | + |
| 704 | + |
| 705 | +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") |
| 706 | +def test__get_model_eula_key( |
| 707 | + patched_get_model_specs: Mock, |
| 708 | +): |
| 709 | + |
| 710 | + patched_get_model_specs.side_effect = get_special_model_spec |
| 711 | + |
| 712 | + model_id, version = "gated_llama_neuron_model", "*" |
| 713 | + assert "fmhMetadata/eula/llamaEula.txt" == _get_model_eula_key(model_id, version) |
| 714 | + |
| 715 | + model_id, version = "variant-model", "1.0.0" |
| 716 | + assert None == _get_model_eula_key(model_id, version) |
| 717 | + |
| 718 | + region = "fake-region" |
| 719 | + |
| 720 | + patched_get_model_specs.reset_mock() |
| 721 | + patched_get_model_specs.side_effect = lambda *largs, **kwargs: get_special_model_spec( |
| 722 | + *largs, |
| 723 | + region="us-west-2", |
| 724 | + **{key: value for key, value in kwargs.items() if key != "region"}, |
| 725 | + ) |
| 726 | + |
| 727 | + _get_model_eula_key(model_id, version, region=region) |
| 728 | + |
| 729 | + patched_get_model_specs.assert_called_once_with( |
| 730 | + model_id=model_id, |
| 731 | + version=version, |
| 732 | + region=region, |
| 733 | + s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, |
| 734 | + ) |
0 commit comments