From a93fbc7c78831b30d9c917e4fa49f0af2da211d5 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Thu, 21 Dec 2023 20:30:22 +0000 Subject: [PATCH 01/12] feat: parallelize notebook search utils --- src/sagemaker/jumpstart/notebook_utils.py | 219 ++++++++---------- .../jumpstart/test_notebook_utils.py | 139 +++++------ 2 files changed, 159 insertions(+), 199 deletions(-) diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 732dbf4b83..78d1391dd7 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -14,13 +14,15 @@ from __future__ import absolute_import import copy +from concurrent.futures import ThreadPoolExecutor, as_completed + from functools import cmp_to_key -import os +import json from typing import Any, Generator, List, Optional, Tuple, Union, Set, Dict from packaging.version import Version from sagemaker.jumpstart import accessors from sagemaker.jumpstart.constants import ( - ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING, + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import JumpStartScriptScope @@ -31,7 +33,8 @@ SpecialSupportedFilterKeys, ) from sagemaker.jumpstart.filters import Constant, ModelFilter, Operator, evaluate_filter_expression -from sagemaker.jumpstart.utils import get_sagemaker_version +from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs +from sagemaker.jumpstart.utils import get_jumpstart_content_bucket, get_sagemaker_version def _compare_model_version_tuples( # pylint: disable=too-many-return-statements @@ -285,160 +288,130 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin results. (Default: False). """ - class _ModelSearchContext: - """Context manager for conducting model searches.""" - - def __init__(self): - """Initialize context manager.""" - - self.old_disable_js_logging_env_var_value = os.environ.get( - ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING - ) - - def __enter__(self, *args, **kwargs): - """Enter context. - - Disable JumpStart logs to avoid excessive logging. - """ + models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(region=region) - os.environ[ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING] = "true" + if isinstance(filter, str): + filter = Identity(filter) - def __exit__(self, *args, **kwargs): - """Exit context. + manifest_keys = set(models_manifest_list[0].__slots__) - Restore JumpStart logging settings, and reset cache so - new logs would appear for models previously searched. - """ + all_keys: Set[str] = set() - if self.old_disable_js_logging_env_var_value: - os.environ[ - ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING - ] = self.old_disable_js_logging_env_var_value - else: - os.environ.pop(ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING, None) - accessors.JumpStartModelsAccessor.reset_cache() + model_filters: Set[ModelFilter] = set() - with _ModelSearchContext(): - - if isinstance(filter, str): - filter = Identity(filter) - - models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(region=region) - manifest_keys = set(models_manifest_list[0].__slots__) + for operator in _model_filter_in_operator_generator(filter): + model_filter = operator.unresolved_value + key = model_filter.key + all_keys.add(key) + model_filters.add(model_filter) - all_keys: Set[str] = set() + for key in all_keys: + if "." in key: + raise NotImplementedError(f"No support for multiple level metadata indexing ('{key}').") - model_filters: Set[ModelFilter] = set() + metadata_filter_keys = all_keys - SPECIAL_SUPPORTED_FILTER_KEYS - for operator in _model_filter_in_operator_generator(filter): - model_filter = operator.unresolved_value - key = model_filter.key - all_keys.add(key) - model_filters.add(model_filter) + required_manifest_keys = manifest_keys.intersection(metadata_filter_keys) + possible_spec_keys = metadata_filter_keys - manifest_keys - for key in all_keys: - if "." in key: - raise NotImplementedError( - f"No support for multiple level metadata indexing ('{key}')." - ) + is_task_filter = SpecialSupportedFilterKeys.TASK in all_keys + is_framework_filter = SpecialSupportedFilterKeys.FRAMEWORK in all_keys - metadata_filter_keys = all_keys - SPECIAL_SUPPORTED_FILTER_KEYS + def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str, str]]: - required_manifest_keys = manifest_keys.intersection(metadata_filter_keys) - possible_spec_keys = metadata_filter_keys - manifest_keys + copied_filter = copy.deepcopy(filter) - unrecognized_keys: Set[str] = set() + manifest_specs_cached_values: Dict[str, Union[bool, int, float, str, dict, list]] = {} - is_task_filter = SpecialSupportedFilterKeys.TASK in all_keys - is_framework_filter = SpecialSupportedFilterKeys.FRAMEWORK in all_keys + model_filters_to_resolved_values: Dict[ModelFilter, BooleanValues] = {} - for model_manifest in models_manifest_list: + for val in required_manifest_keys: + manifest_specs_cached_values[val] = getattr(model_manifest, val) - copied_filter = copy.deepcopy(filter) + if is_task_filter: + manifest_specs_cached_values[ + SpecialSupportedFilterKeys.TASK + ] = extract_framework_task_model(model_manifest.model_id)[1] - manifest_specs_cached_values: Dict[str, Union[bool, int, float, str, dict, list]] = {} + if is_framework_filter: + manifest_specs_cached_values[ + SpecialSupportedFilterKeys.FRAMEWORK + ] = extract_framework_task_model(model_manifest.model_id)[0] - model_filters_to_resolved_values: Dict[ModelFilter, BooleanValues] = {} + if Version(model_manifest.min_version) > Version(get_sagemaker_version()): + return None - for val in required_manifest_keys: - manifest_specs_cached_values[val] = getattr(model_manifest, val) + _populate_model_filters_to_resolved_values( + manifest_specs_cached_values, + model_filters_to_resolved_values, + model_filters, + ) - if is_task_filter: - manifest_specs_cached_values[ - SpecialSupportedFilterKeys.TASK - ] = extract_framework_task_model(model_manifest.model_id)[1] + _put_resolved_booleans_into_filter(copied_filter, model_filters_to_resolved_values) - if is_framework_filter: - manifest_specs_cached_values[ - SpecialSupportedFilterKeys.FRAMEWORK - ] = extract_framework_task_model(model_manifest.model_id)[0] + copied_filter.eval() - if Version(model_manifest.min_version) > Version(get_sagemaker_version()): - continue + if copied_filter.resolved_value in [BooleanValues.TRUE, BooleanValues.FALSE]: + if copied_filter.resolved_value == BooleanValues.TRUE: + return (model_manifest.model_id, model_manifest.version) + return None - _populate_model_filters_to_resolved_values( - manifest_specs_cached_values, - model_filters_to_resolved_values, - model_filters, + if copied_filter.resolved_value == BooleanValues.UNEVALUATED: + raise RuntimeError( + "Filter expression in unevaluated state after using " + "values from model manifest. Model ID and version that " + f"is failing: {(model_manifest.model_id, model_manifest.version)}." ) + copied_filter_2 = copy.deepcopy(filter) - _put_resolved_booleans_into_filter(copied_filter, model_filters_to_resolved_values) - - copied_filter.eval() - - if copied_filter.resolved_value in [BooleanValues.TRUE, BooleanValues.FALSE]: - if copied_filter.resolved_value == BooleanValues.TRUE: - yield (model_manifest.model_id, model_manifest.version) - continue - - if copied_filter.resolved_value == BooleanValues.UNEVALUATED: - raise RuntimeError( - "Filter expression in unevaluated state after using " - "values from model manifest. Model ID and version that " - f"is failing: {(model_manifest.model_id, model_manifest.version)}." + model_specs = JumpStartModelSpecs( + json.loads( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file( + get_jumpstart_content_bucket(), model_manifest.spec_key ) - copied_filter_2 = copy.deepcopy(filter) - - model_specs = accessors.JumpStartModelsAccessor.get_model_specs( - region=region, - model_id=model_manifest.model_id, - version=model_manifest.version, ) + ) - model_specs_keys = set(model_specs.__slots__) + for val in possible_spec_keys: + if hasattr(model_specs, val): + manifest_specs_cached_values[val] = getattr(model_specs, val) - unrecognized_keys -= model_specs_keys - unrecognized_keys_for_single_spec = possible_spec_keys - model_specs_keys - unrecognized_keys.update(unrecognized_keys_for_single_spec) + _populate_model_filters_to_resolved_values( + manifest_specs_cached_values, + model_filters_to_resolved_values, + model_filters, + ) + _put_resolved_booleans_into_filter(copied_filter_2, model_filters_to_resolved_values) - for val in possible_spec_keys: - if hasattr(model_specs, val): - manifest_specs_cached_values[val] = getattr(model_specs, val) + copied_filter_2.eval() - _populate_model_filters_to_resolved_values( - manifest_specs_cached_values, - model_filters_to_resolved_values, - model_filters, - ) - _put_resolved_booleans_into_filter(copied_filter_2, model_filters_to_resolved_values) + if copied_filter_2.resolved_value != BooleanValues.UNEVALUATED: + if copied_filter_2.resolved_value == BooleanValues.TRUE or ( + BooleanValues.UNKNOWN and list_incomplete_models + ): + return (model_manifest.model_id, model_manifest.version) + return None - copied_filter_2.eval() + raise RuntimeError( + "Filter expression in unevaluated state after using values from model specs. " + "Model ID and version that is failing: " + f"{(model_manifest.model_id, model_manifest.version)}." + ) - if copied_filter_2.resolved_value != BooleanValues.UNEVALUATED: - if copied_filter_2.resolved_value == BooleanValues.TRUE or ( - BooleanValues.UNKNOWN and list_incomplete_models - ): - yield (model_manifest.model_id, model_manifest.version) - continue + max_memory = int(100 * 1e6) + average_memory_per_thread = int(25 * 1e3) + max_workers = int(max_memory / average_memory_per_thread) - raise RuntimeError( - "Filter expression in unevaluated state after using values from model specs. " - "Model ID and version that is failing: " - f"{(model_manifest.model_id, model_manifest.version)}." - ) + executor = ThreadPoolExecutor(max_workers=max_workers) + + futures = [] + for header in models_manifest_list: + futures.append(executor.submit(evaluate_model, header)) - if len(unrecognized_keys) > 0: - raise RuntimeError(f"Unrecognized keys: {str(unrecognized_keys)}") + for future in as_completed(futures): + result = future.result() + if result: + yield result def get_model_url( diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index 181310a507..c92fd0258b 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -1,4 +1,5 @@ from __future__ import absolute_import +import json from unittest import TestCase from unittest.mock import Mock, patch @@ -22,6 +23,7 @@ ) +@patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.notebook_utils.accessors.JumpStartModelsAccessor.get_model_specs") @patch("sagemaker.jumpstart.notebook_utils._generate_jumpstart_model_versions") @@ -29,10 +31,14 @@ def test_list_jumpstart_scripts( patched_generate_jumpstart_models: Mock, patched_get_model_specs: Mock, patched_get_manifest: Mock, + patched_read_s3_file: Mock, ): patched_get_model_specs.side_effect = get_prototype_model_spec patched_get_manifest.side_effect = get_prototype_manifest patched_generate_jumpstart_models.side_effect = _generate_jumpstart_model_versions + patched_read_s3_file.side_effect = lambda *args, **kwargs: json.dumps( + get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased").to_json() + ) assert list_jumpstart_scripts() == sorted(["inference", "training"]) patched_get_model_specs.assert_not_called() @@ -63,7 +69,7 @@ def test_list_jumpstart_scripts( assert list_jumpstart_scripts(**kwargs) == [] patched_generate_jumpstart_models.assert_called_once_with(**kwargs) patched_get_manifest.assert_called_once() - assert patched_get_model_specs.call_count == len(PROTOTYPICAL_MODEL_SPECS_DICT) + assert patched_read_s3_file.call_count == len(PROTOTYPICAL_MODEL_SPECS_DICT) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @@ -183,11 +189,13 @@ def test_list_jumpstart_models_simple_case( patched_get_model_specs.assert_not_called() @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file") def test_list_jumpstart_models_script_filter( - self, patched_get_model_specs: Mock, patched_get_manifest: Mock + self, patched_read_s3_file: Mock, patched_get_manifest: Mock ): - patched_get_model_specs.side_effect = get_prototype_model_spec + patched_read_s3_file.side_effect = lambda *args, **kwargs: json.dumps( + get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased").to_json() + ) patched_get_manifest.side_effect = get_prototype_manifest manifest_length = len(get_prototype_manifest()) @@ -195,19 +203,19 @@ def test_list_jumpstart_models_script_filter( for val in vals: kwargs = {"filter": f"training_supported == {val}"} list_jumpstart_models(**kwargs) - assert patched_get_model_specs.call_count == manifest_length + assert patched_read_s3_file.call_count == manifest_length patched_get_manifest.assert_called_once() patched_get_manifest.reset_mock() - patched_get_model_specs.reset_mock() + patched_read_s3_file.reset_mock() kwargs = {"filter": f"training_supported != {val}"} list_jumpstart_models(**kwargs) - assert patched_get_model_specs.call_count == manifest_length + assert patched_read_s3_file.call_count == manifest_length patched_get_manifest.assert_called_once() patched_get_manifest.reset_mock() - patched_get_model_specs.reset_mock() + patched_read_s3_file.reset_mock() kwargs = {"filter": f"training_supported in {vals}", "list_versions": True} assert list_jumpstart_models(**kwargs) == [ @@ -220,16 +228,16 @@ def test_list_jumpstart_models_script_filter( ("tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", "1.0.0"), ("xgboost-classification-model", "1.0.0"), ] - assert patched_get_model_specs.call_count == manifest_length + assert patched_read_s3_file.call_count == manifest_length patched_get_manifest.assert_called_once() patched_get_manifest.reset_mock() - patched_get_model_specs.reset_mock() + patched_read_s3_file.reset_mock() kwargs = {"filter": f"training_supported not in {vals}"} models = list_jumpstart_models(**kwargs) assert [] == models - assert patched_get_model_specs.call_count == manifest_length + assert patched_read_s3_file.call_count == manifest_length patched_get_manifest.assert_called_once() @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @@ -288,11 +296,13 @@ def test_list_jumpstart_models_task_filter( patched_get_manifest.assert_called_once() @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file") def test_list_jumpstart_models_framework_filter( - self, patched_get_model_specs: Mock, patched_get_manifest: Mock + self, patched_read_s3_file: Mock, patched_get_manifest: Mock ): - patched_get_model_specs.side_effect = get_prototype_model_spec + patched_read_s3_file.side_effect = lambda *args, **kwargs: json.dumps( + get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased").to_json() + ) patched_get_manifest.side_effect = get_prototype_manifest vals = [ @@ -307,19 +317,19 @@ def test_list_jumpstart_models_framework_filter( for val in vals: kwargs = {"filter": f"framework == {val}"} list_jumpstart_models(**kwargs) - patched_get_model_specs.assert_not_called() + patched_read_s3_file.assert_not_called() patched_get_manifest.assert_called_once() patched_get_manifest.reset_mock() - patched_get_model_specs.reset_mock() + patched_read_s3_file.reset_mock() kwargs = {"filter": f"framework != {val}"} list_jumpstart_models(**kwargs) - patched_get_model_specs.assert_not_called() + patched_read_s3_file.assert_not_called() patched_get_manifest.assert_called_once() patched_get_manifest.reset_mock() - patched_get_model_specs.reset_mock() + patched_read_s3_file.reset_mock() kwargs = {"filter": f"framework in {vals}", "list_versions": True} assert list_jumpstart_models(**kwargs) == [ @@ -331,18 +341,18 @@ def test_list_jumpstart_models_framework_filter( ("sklearn-classification-linear", "1.0.0"), ("xgboost-classification-model", "1.0.0"), ] - patched_get_model_specs.assert_not_called() + patched_read_s3_file.assert_not_called() patched_get_manifest.assert_called_once() patched_get_manifest.reset_mock() - patched_get_model_specs.reset_mock() + patched_read_s3_file.reset_mock() kwargs = {"filter": f"framework not in {vals}", "list_versions": True} models = list_jumpstart_models(**kwargs) assert [("tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", "1.0.0")] == models patched_get_manifest.reset_mock() - patched_get_model_specs.reset_mock() + patched_read_s3_file.reset_mock() kwargs = { "filter": And(f"framework not in {vals}", "training_supported is True"), @@ -350,11 +360,11 @@ def test_list_jumpstart_models_framework_filter( } models = list_jumpstart_models(**kwargs) assert [("tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", "1.0.0")] == models - patched_get_model_specs.assert_called_once() + patched_read_s3_file.assert_called_once() patched_get_manifest.assert_called_once() patched_get_manifest.reset_mock() - patched_get_model_specs.reset_mock() + patched_read_s3_file.reset_mock() kwargs = { "filter": And( @@ -364,7 +374,7 @@ def test_list_jumpstart_models_framework_filter( } models = list_jumpstart_models(**kwargs) assert [] == models - patched_get_model_specs.assert_not_called() + patched_read_s3_file.assert_not_called() patched_get_manifest.assert_called_once() @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @@ -380,31 +390,6 @@ def test_list_jumpstart_models_region( patched_get_manifest.assert_called_once_with(region="some-region") - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") - @patch("sagemaker.jumpstart.notebook_utils.get_sagemaker_version") - @patch("sagemaker.jumpstart.notebook_utils.accessors.JumpStartModelsAccessor.reset_cache") - @patch.dict("os.environ", {}) - @patch("logging.StreamHandler.emit") - @patch("sagemaker.jumpstart.constants.JUMPSTART_LOGGER.propagate", False) - def test_list_jumpstart_models_disables_logging_resets_cache( - self, - patched_emit: Mock, - patched_reset_cache: Mock, - patched_get_sagemaker_version: Mock, - patched_get_model_specs: Mock, - patched_get_manifest: Mock, - ): - patched_get_model_specs.side_effect = get_prototype_model_spec - patched_get_manifest.side_effect = get_prototype_manifest - - patched_get_sagemaker_version.return_value = "3.0.0" - - list_jumpstart_models("deprecate_warn_message is blah") - - patched_emit.assert_not_called() - patched_reset_cache.assert_called_once() - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_list_jumpstart_models_old_models( @@ -477,83 +462,83 @@ def get_manifest_more_versions(region: str = JUMPSTART_DEFAULT_REGION_NAME): ) == list_jumpstart_models(list_versions=True) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file") def test_list_jumpstart_models_vulnerable_models( self, - patched_get_model_specs: Mock, + patched_read_s3_file: Mock, patched_get_manifest: Mock, ): patched_get_manifest.side_effect = get_prototype_manifest - def vulnerable_inference_model_spec(*args, **kwargs): - spec = get_prototype_model_spec(*args, **kwargs) + def vulnerable_inference_model_spec(bucket, key) -> str: + spec = get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased") spec.inference_vulnerable = True - return spec + return json.dumps(spec.to_json()) - def vulnerable_training_model_spec(*args, **kwargs): - spec = get_prototype_model_spec(*args, **kwargs) + def vulnerable_training_model_spec(bucket, key): + spec = get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased") spec.training_vulnerable = True - return spec + return json.dumps(spec.to_json()) - patched_get_model_specs.side_effect = vulnerable_inference_model_spec + patched_read_s3_file.side_effect = vulnerable_inference_model_spec num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT) assert [] == list_jumpstart_models( And("inference_vulnerable is false", "training_vulnerable is false") ) - assert patched_get_model_specs.call_count == num_specs + assert patched_read_s3_file.call_count == num_specs patched_get_manifest.assert_called_once() patched_get_manifest.reset_mock() - patched_get_model_specs.reset_mock() + patched_read_s3_file.reset_mock() - patched_get_model_specs.side_effect = vulnerable_training_model_spec + patched_read_s3_file.side_effect = vulnerable_training_model_spec assert [] == list_jumpstart_models( And("inference_vulnerable is false", "training_vulnerable is false") ) - assert patched_get_model_specs.call_count == num_specs + assert patched_read_s3_file.call_count == num_specs patched_get_manifest.assert_called_once() patched_get_manifest.reset_mock() - patched_get_model_specs.reset_mock() + patched_read_s3_file.reset_mock() assert [] != list_jumpstart_models() - assert patched_get_model_specs.call_count == 0 + assert patched_read_s3_file.call_count == 0 @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file") def test_list_jumpstart_models_deprecated_models( self, - patched_get_model_specs: Mock, + patched_read_s3_file: Mock, patched_get_manifest: Mock, ): patched_get_manifest.side_effect = get_prototype_manifest - def deprecated_model_spec(*args, **kwargs): - spec = get_prototype_model_spec(*args, **kwargs) + def deprecated_model_spec(bucket, key) -> str: + spec = get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased") spec.deprecated = True - return spec + return json.dumps(spec.to_json()) - patched_get_model_specs.side_effect = deprecated_model_spec + patched_read_s3_file.side_effect = deprecated_model_spec num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT) assert [] == list_jumpstart_models("deprecated equals false") - assert patched_get_model_specs.call_count == num_specs + assert patched_read_s3_file.call_count == num_specs patched_get_manifest.assert_called_once() patched_get_manifest.reset_mock() - patched_get_model_specs.reset_mock() + patched_read_s3_file.reset_mock() assert [] != list_jumpstart_models() - assert patched_get_model_specs.call_count == 0 + assert patched_read_s3_file.call_count == 0 @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -581,13 +566,15 @@ def test_list_jumpstart_models_no_versions( assert list_jumpstart_models(list_versions=False) == all_model_ids @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file") def test_list_jumpstart_models_complex_queries( self, - patched_get_model_specs: Mock, + patched_read_s3_file: Mock, patched_get_manifest: Mock, ): - patched_get_model_specs.side_effect = get_prototype_model_spec + patched_read_s3_file.side_effect = lambda *args, **kwargs: json.dumps( + get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased").to_json() + ) patched_get_manifest.side_effect = get_prototype_manifest assert list_jumpstart_models( From 448081c9e0add105a1df7340a221773d16fecf8b Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Thu, 21 Dec 2023 21:16:46 +0000 Subject: [PATCH 02/12] chore: raise exception in notebook utils if thread has error --- src/sagemaker/jumpstart/notebook_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 78d1391dd7..2a0b9e3385 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -409,6 +409,9 @@ def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str, futures.append(executor.submit(evaluate_model, header)) for future in as_completed(futures): + error = future.exception() + if error: + raise error result = future.result() if result: yield result From f6fe6bec17db1eedd90d36e7d5b28c0f46328805 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Thu, 21 Dec 2023 21:18:10 +0000 Subject: [PATCH 03/12] chore: improve variable name --- src/sagemaker/jumpstart/notebook_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 2a0b9e3385..5fd4c3a638 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -398,9 +398,9 @@ def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str, f"{(model_manifest.model_id, model_manifest.version)}." ) - max_memory = int(100 * 1e6) - average_memory_per_thread = int(25 * 1e3) - max_workers = int(max_memory / average_memory_per_thread) + max_memory_bytes = int(100 * 1e6) + average_memory_bytes_per_thread = int(25 * 1e3) + max_workers = int(max_memory_bytes / average_memory_bytes_per_thread) executor = ThreadPoolExecutor(max_workers=max_workers) From 1f808086fa1985a76319537ec10aa0f6434edb95 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Fri, 22 Dec 2023 16:27:31 +0000 Subject: [PATCH 04/12] fix: not passing region to get jumpstart bucket --- src/sagemaker/jumpstart/notebook_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 5fd4c3a638..a3480b5e8a 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -367,7 +367,7 @@ def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str, model_specs = JumpStartModelSpecs( json.loads( DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file( - get_jumpstart_content_bucket(), model_manifest.spec_key + get_jumpstart_content_bucket(region), model_manifest.spec_key ) ) ) From 50dd33cea70e2ef73d155662055e6168246445ae Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Fri, 22 Dec 2023 19:48:06 +0000 Subject: [PATCH 05/12] chore: add sagemaker session to notebook utils --- src/sagemaker/jumpstart/notebook_utils.py | 52 +++++++++-- .../jumpstart/test_notebook_utils.py | 92 ++++++++++++++----- 2 files changed, 112 insertions(+), 32 deletions(-) diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index a3480b5e8a..015859801e 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -35,6 +35,7 @@ from sagemaker.jumpstart.filters import Constant, ModelFilter, Operator, evaluate_filter_expression from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs from sagemaker.jumpstart.utils import get_jumpstart_content_bucket, get_sagemaker_version +from sagemaker.session import Session def _compare_model_version_tuples( # pylint: disable=too-many-return-statements @@ -137,6 +138,7 @@ def extract_framework_task_model(model_id: str) -> Tuple[str, str, str]: def list_jumpstart_tasks( # pylint: disable=redefined-builtin filter: Union[Operator, str] = Constant(BooleanValues.TRUE), region: str = JUMPSTART_DEFAULT_REGION_NAME, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[str]: """List tasks for JumpStart, and optionally apply filters to result. @@ -148,10 +150,14 @@ def list_jumpstart_tasks( # pylint: disable=redefined-builtin (Default: Constant(BooleanValues.TRUE)). region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding models. (Default: JUMPSTART_DEFAULT_REGION_NAME). + sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to + use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). """ tasks: Set[str] = set() - for model_id, _ in _generate_jumpstart_model_versions(filter=filter, region=region): + for model_id, _ in _generate_jumpstart_model_versions( + filter=filter, region=region, sagemaker_session=sagemaker_session + ): _, task, _ = extract_framework_task_model(model_id) tasks.add(task) return sorted(list(tasks)) @@ -160,6 +166,7 @@ def list_jumpstart_tasks( # pylint: disable=redefined-builtin def list_jumpstart_frameworks( # pylint: disable=redefined-builtin filter: Union[Operator, str] = Constant(BooleanValues.TRUE), region: str = JUMPSTART_DEFAULT_REGION_NAME, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[str]: """List frameworks for JumpStart, and optionally apply filters to result. @@ -171,10 +178,14 @@ def list_jumpstart_frameworks( # pylint: disable=redefined-builtin (Default: Constant(BooleanValues.TRUE)). region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding models. (Default: JUMPSTART_DEFAULT_REGION_NAME). + sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session + to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). """ frameworks: Set[str] = set() - for model_id, _ in _generate_jumpstart_model_versions(filter=filter, region=region): + for model_id, _ in _generate_jumpstart_model_versions( + filter=filter, region=region, sagemaker_session=sagemaker_session + ): framework, _, _ = extract_framework_task_model(model_id) frameworks.add(framework) return sorted(list(frameworks)) @@ -183,6 +194,7 @@ def list_jumpstart_frameworks( # pylint: disable=redefined-builtin def list_jumpstart_scripts( # pylint: disable=redefined-builtin filter: Union[Operator, str] = Constant(BooleanValues.TRUE), region: str = JUMPSTART_DEFAULT_REGION_NAME, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[str]: """List scripts for JumpStart, and optionally apply filters to result. @@ -194,6 +206,8 @@ def list_jumpstart_scripts( # pylint: disable=redefined-builtin (Default: Constant(BooleanValues.TRUE)). region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding models. (Default: JUMPSTART_DEFAULT_REGION_NAME). + sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to + use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). """ if (isinstance(filter, Constant) and filter.resolved_value == BooleanValues.TRUE) or ( isinstance(filter, str) and filter.lower() == BooleanValues.TRUE.lower() @@ -201,12 +215,15 @@ def list_jumpstart_scripts( # pylint: disable=redefined-builtin return sorted([e.value for e in JumpStartScriptScope]) scripts: Set[str] = set() - for model_id, version in _generate_jumpstart_model_versions(filter=filter, region=region): + for model_id, version in _generate_jumpstart_model_versions( + filter=filter, region=region, sagemaker_session=sagemaker_session + ): scripts.add(JumpStartScriptScope.INFERENCE) model_specs = accessors.JumpStartModelsAccessor.get_model_specs( region=region, model_id=model_id, version=version, + s3_client=sagemaker_session.s3_client, ) if model_specs.training_supported: scripts.add(JumpStartScriptScope.TRAINING) @@ -222,6 +239,7 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin list_incomplete_models: bool = False, list_old_models: bool = False, list_versions: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[Union[Tuple[str], Tuple[str, str]]]: """List models for JumpStart, and optionally apply filters to result. @@ -241,11 +259,16 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin versions should be included in the returned result. (Default: False). list_versions (bool): Optional. True if versions for models should be returned in addition to the id of the model. (Default: False). + sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use + to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). """ model_id_version_dict: Dict[str, List[str]] = dict() for model_id, version in _generate_jumpstart_model_versions( - filter=filter, region=region, list_incomplete_models=list_incomplete_models + filter=filter, + region=region, + list_incomplete_models=list_incomplete_models, + sagemaker_session=sagemaker_session, ): if model_id not in model_id_version_dict: model_id_version_dict[model_id] = list() @@ -271,6 +294,7 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin filter: Union[Operator, str] = Constant(BooleanValues.TRUE), region: str = JUMPSTART_DEFAULT_REGION_NAME, list_incomplete_models: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> Generator: """Generate models for JumpStart, and optionally apply filters to result. @@ -286,9 +310,13 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin requested by the filter, and the filter cannot be resolved to a include/not include, whether the model should be included. By default, these models are omitted from results. (Default: False). + sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session + to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). """ - models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(region=region) + models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( + region=region, s3_client=sagemaker_session.s3_client + ) if isinstance(filter, str): filter = Identity(filter) @@ -366,7 +394,7 @@ def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str, model_specs = JumpStartModelSpecs( json.loads( - DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file( + sagemaker_session.read_s3_file( get_jumpstart_content_bucket(region), model_manifest.spec_key ) ) @@ -418,7 +446,10 @@ def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str, def get_model_url( - model_id: str, model_version: str, region: str = JUMPSTART_DEFAULT_REGION_NAME + model_id: str, + model_version: str, + region: str = JUMPSTART_DEFAULT_REGION_NAME, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieve web url describing pretrained model. @@ -427,9 +458,14 @@ def get_model_url( model_version (str): The model version for which to retrieve the url. region (str): Optional. The region from which to retrieve metadata. (Default: JUMPSTART_DEFAULT_REGION_NAME) + sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use + to retrieve the model url. """ model_specs = accessors.JumpStartModelsAccessor.get_model_specs( - region=region, model_id=model_id, version=model_version + region=region, + model_id=model_id, + version=model_version, + s3_client=sagemaker_session.s3_client, ) return model_specs.url diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index c92fd0258b..8aae4c36a8 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -5,7 +5,10 @@ from unittest.mock import Mock, patch import pytest -from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME +from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + JUMPSTART_DEFAULT_REGION_NAME, +) from sagemaker.jumpstart.filters import And, Identity, Not, Or from tests.unit.sagemaker.jumpstart.constants import PROTOTYPICAL_MODEL_SPECS_DICT from tests.unit.sagemaker.jumpstart.utils import ( @@ -34,7 +37,9 @@ def test_list_jumpstart_scripts( patched_read_s3_file: Mock, ): patched_get_model_specs.side_effect = get_prototype_model_spec - patched_get_manifest.side_effect = get_prototype_manifest + patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest( + region + ) patched_generate_jumpstart_models.side_effect = _generate_jumpstart_model_versions patched_read_s3_file.side_effect = lambda *args, **kwargs: json.dumps( get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased").to_json() @@ -54,7 +59,9 @@ def test_list_jumpstart_scripts( "region": "sa-east-1", } assert list_jumpstart_scripts(**kwargs) == sorted(["inference", "training"]) - patched_generate_jumpstart_models.assert_called_once_with(**kwargs) + patched_generate_jumpstart_models.assert_called_once_with( + **kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION + ) patched_get_manifest.assert_called_once() assert patched_get_model_specs.call_count == 1 @@ -67,7 +74,9 @@ def test_list_jumpstart_scripts( "region": "sa-east-1", } assert list_jumpstart_scripts(**kwargs) == [] - patched_generate_jumpstart_models.assert_called_once_with(**kwargs) + patched_generate_jumpstart_models.assert_called_once_with( + **kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION + ) patched_get_manifest.assert_called_once() assert patched_read_s3_file.call_count == len(PROTOTYPICAL_MODEL_SPECS_DICT) @@ -81,7 +90,9 @@ def test_list_jumpstart_tasks( patched_get_manifest: Mock, ): patched_get_model_specs.side_effect = get_prototype_model_spec - patched_get_manifest.side_effect = get_prototype_manifest + patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest( + region + ) patched_generate_jumpstart_models.side_effect = _generate_jumpstart_model_versions assert list_jumpstart_tasks() == sorted( @@ -107,7 +118,9 @@ def test_list_jumpstart_tasks( "region": "sa-east-1", } assert list_jumpstart_tasks(**kwargs) == ["ic"] - patched_generate_jumpstart_models.assert_called_once_with(**kwargs) + patched_generate_jumpstart_models.assert_called_once_with( + **kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION + ) patched_get_manifest.assert_called_once() patched_get_model_specs.assert_not_called() @@ -121,7 +134,9 @@ def test_list_jumpstart_frameworks( patched_get_manifest: Mock, ): patched_get_model_specs.side_effect = get_prototype_model_spec - patched_get_manifest.side_effect = get_prototype_manifest + patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest( + region + ) patched_generate_jumpstart_models.side_effect = _generate_jumpstart_model_versions assert list_jumpstart_frameworks() == sorted( @@ -161,7 +176,9 @@ def test_list_jumpstart_frameworks( ] ) - patched_generate_jumpstart_models.assert_called_once_with(**kwargs) + patched_generate_jumpstart_models.assert_called_once_with( + **kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION + ) patched_get_manifest.assert_called_once() patched_get_model_specs.assert_not_called() @@ -173,7 +190,9 @@ def test_list_jumpstart_models_simple_case( self, patched_get_model_specs: Mock, patched_get_manifest: Mock ): patched_get_model_specs.side_effect = get_prototype_model_spec - patched_get_manifest.side_effect = get_prototype_manifest + patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest( + region + ) assert list_jumpstart_models(list_versions=True) == [ ("catboost-classification-model", "1.0.0"), ("huggingface-spc-bert-base-cased", "1.0.0"), @@ -196,7 +215,9 @@ def test_list_jumpstart_models_script_filter( patched_read_s3_file.side_effect = lambda *args, **kwargs: json.dumps( get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased").to_json() ) - patched_get_manifest.side_effect = get_prototype_manifest + patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest( + region + ) manifest_length = len(get_prototype_manifest()) vals = [True, False] @@ -246,7 +267,9 @@ def test_list_jumpstart_models_task_filter( self, patched_get_model_specs: Mock, patched_get_manifest: Mock ): patched_get_model_specs.side_effect = get_prototype_model_spec - patched_get_manifest.side_effect = get_prototype_manifest + patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest( + region + ) vals = [ "classification", @@ -303,7 +326,9 @@ def test_list_jumpstart_models_framework_filter( patched_read_s3_file.side_effect = lambda *args, **kwargs: json.dumps( get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased").to_json() ) - patched_get_manifest.side_effect = get_prototype_manifest + patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest( + region + ) vals = [ "catboost", @@ -384,11 +409,15 @@ def test_list_jumpstart_models_region( ): patched_get_model_specs.side_effect = get_prototype_model_spec - patched_get_manifest.side_effect = lambda region: get_prototype_manifest(region="us-west-2") + patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest( + region="us-west-2" + ) list_jumpstart_models(region="some-region") - patched_get_manifest.assert_called_once_with(region="some-region") + patched_get_manifest.assert_called_once_with( + region="some-region", s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client + ) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -404,7 +433,9 @@ def get_manifest_more_versions(region: str = JUMPSTART_DEFAULT_REGION_NAME): for version in ["2.400.0", "1.4.0", "2.5.1", "1.300.0"] ] - patched_get_manifest.side_effect = get_manifest_more_versions + patched_get_manifest.side_effect = ( + lambda region, *args, **kwargs: get_manifest_more_versions(region) + ) assert [ ("catboost-classification-model", "2.400.0"), @@ -469,14 +500,16 @@ def test_list_jumpstart_models_vulnerable_models( patched_get_manifest: Mock, ): - patched_get_manifest.side_effect = get_prototype_manifest + patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest( + region + ) - def vulnerable_inference_model_spec(bucket, key) -> str: + def vulnerable_inference_model_spec(bucket, key, *args, **kwargs) -> str: spec = get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased") spec.inference_vulnerable = True return json.dumps(spec.to_json()) - def vulnerable_training_model_spec(bucket, key): + def vulnerable_training_model_spec(bucket, key, *args, **kwargs): spec = get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased") spec.training_vulnerable = True return json.dumps(spec.to_json()) @@ -518,9 +551,11 @@ def test_list_jumpstart_models_deprecated_models( patched_get_manifest: Mock, ): - patched_get_manifest.side_effect = get_prototype_manifest + patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest( + region + ) - def deprecated_model_spec(bucket, key) -> str: + def deprecated_model_spec(bucket, key, *args, **kwargs) -> str: spec = get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased") spec.deprecated = True return json.dumps(spec.to_json()) @@ -548,7 +583,9 @@ def test_list_jumpstart_models_no_versions( patched_get_manifest: Mock, ): patched_get_model_specs.side_effect = get_prototype_model_spec - patched_get_manifest.side_effect = get_prototype_manifest + patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest( + region + ) all_model_ids = [ "catboost-classification-model", @@ -575,7 +612,9 @@ def test_list_jumpstart_models_complex_queries( patched_read_s3_file.side_effect = lambda *args, **kwargs: json.dumps( get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased").to_json() ) - patched_get_manifest.side_effect = get_prototype_manifest + patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest( + region + ) assert list_jumpstart_models( Or( @@ -618,7 +657,9 @@ def test_list_jumpstart_models_multiple_level_index( patched_get_manifest: Mock, ): patched_get_model_specs.side_effect = get_prototype_model_spec - patched_get_manifest.side_effect = get_prototype_manifest + patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest( + region + ) with pytest.raises(NotImplementedError): list_jumpstart_models("hosting_ecr_specs.py_version == py3") @@ -652,5 +693,8 @@ def test_get_model_url( get_model_url(model_id, version, region=region) patched_get_model_specs.assert_called_once_with( - model_id=model_id, version=version, region=region + model_id=model_id, + version=version, + region=region, + s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, ) From 751fee3d9a2149996381102b70183880fdd52729 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Wed, 27 Dec 2023 16:15:55 +0000 Subject: [PATCH 06/12] chore: address PR comments --- src/sagemaker/jumpstart/notebook_utils.py | 34 +++++++++++------------ 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 015859801e..ecbef8cb3f 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -37,6 +37,8 @@ from sagemaker.jumpstart.utils import get_jumpstart_content_bucket, get_sagemaker_version from sagemaker.session import Session +MAX_SEARCH_WORKERS = int(100 * 1e6 / 25 * 1e3) # max 100MB total memory, 25kB per thread) + def _compare_model_version_tuples( # pylint: disable=too-many-return-statements model_version_1: Optional[Tuple[str, str]] = None, @@ -392,6 +394,9 @@ def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str, ) copied_filter_2 = copy.deepcopy(filter) + # spec is downloaded to thread's memory. since each thread + # accesses a unique s3 spec, there is no need to use the JS caching utils. + # spec only stays in memory for lifecycle of thread. model_specs = JumpStartModelSpecs( json.loads( sagemaker_session.read_s3_file( @@ -426,23 +431,18 @@ def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str, f"{(model_manifest.model_id, model_manifest.version)}." ) - max_memory_bytes = int(100 * 1e6) - average_memory_bytes_per_thread = int(25 * 1e3) - max_workers = int(max_memory_bytes / average_memory_bytes_per_thread) - - executor = ThreadPoolExecutor(max_workers=max_workers) - - futures = [] - for header in models_manifest_list: - futures.append(executor.submit(evaluate_model, header)) - - for future in as_completed(futures): - error = future.exception() - if error: - raise error - result = future.result() - if result: - yield result + with ThreadPoolExecutor(max_workers=MAX_SEARCH_WORKERS) as executor: + futures = [] + for header in models_manifest_list: + futures.append(executor.submit(evaluate_model, header)) + + for future in as_completed(futures): + error = future.exception() + if error: + raise error + result = future.result() + if result: + yield result def get_model_url( From 0f50da5bea42b25ef427ba856fd5b05792828053 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Wed, 27 Dec 2023 22:47:17 +0000 Subject: [PATCH 07/12] feat: add support for includes, begins with, ends with --- src/sagemaker/jumpstart/filters.py | 158 ++++++++++++++---- .../unit/sagemaker/jumpstart/test_filters.py | 46 +++++ 2 files changed, 172 insertions(+), 32 deletions(-) diff --git a/src/sagemaker/jumpstart/filters.py b/src/sagemaker/jumpstart/filters.py index 56ef12a148..dc19157ffc 100644 --- a/src/sagemaker/jumpstart/filters.py +++ b/src/sagemaker/jumpstart/filters.py @@ -14,7 +14,7 @@ from __future__ import absolute_import from ast import literal_eval from enum import Enum -from typing import Dict, List, Union, Any +from typing import Dict, List, Optional, Union, Any from sagemaker.jumpstart.types import JumpStartDataHolderType @@ -38,6 +38,10 @@ class FilterOperators(str, Enum): NOT_EQUALS = "not_equals" IN = "in" NOT_IN = "not_in" + INCLUDES = "includes" + NOT_INCLUDES = "not_includes" + BEGINS_WITH = "begins_with" + ENDS_WITH = "ends_with" class SpecialSupportedFilterKeys(str, Enum): @@ -52,6 +56,10 @@ class SpecialSupportedFilterKeys(str, Enum): FilterOperators.NOT_EQUALS: ["!==", "!=", "not equals", "is not"], FilterOperators.IN: ["in"], FilterOperators.NOT_IN: ["not in"], + FilterOperators.INCLUDES: ["includes", "contains"], + FilterOperators.NOT_INCLUDES: ["not includes", "not contains"], + FilterOperators.BEGINS_WITH: ["begins with", "starts with"], + FilterOperators.ENDS_WITH: ["ends with"], } @@ -62,7 +70,19 @@ class SpecialSupportedFilterKeys(str, Enum): ) ACCEPTABLE_OPERATORS_IN_PARSE_ORDER = ( - list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_EQUALS])) + list( + map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.BEGINS_WITH]) + ) + + list( + map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.ENDS_WITH]) + ) + + list( + map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_INCLUDES]) + ) + + list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.INCLUDES])) + + list( + map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_EQUALS]) + ) + list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_IN])) + list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.EQUALS])) + list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.IN])) @@ -428,9 +448,90 @@ def parse_filter_string(filter_string: str) -> ModelFilter: raise ValueError(f"Cannot parse filter string: {filter_string}") +def _negate_boolean(boolean: BooleanValues) -> BooleanValues: + if boolean == BooleanValues.TRUE: + return BooleanValues.FALSE + if boolean == BooleanValues.FALSE: + return BooleanValues.TRUE + return boolean + + +def _evaluate_filter_expression_equals( + model_filter: ModelFilter, + cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]], +) -> BooleanValues: + if cached_model_value is None: + return BooleanValues.FALSE + model_filter_value = model_filter.value + if isinstance(cached_model_value, bool): + cached_model_value = str(cached_model_value).lower() + model_filter_value = model_filter.value.lower() + if str(model_filter_value) == str(cached_model_value): + return BooleanValues.TRUE + return BooleanValues.FALSE + + +def _evaluate_filter_expression_in( + model_filter: ModelFilter, + cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]], +) -> BooleanValues: + if cached_model_value is None: + return BooleanValues.FALSE + py_obj = model_filter.value + try: + py_obj = literal_eval(py_obj) + try: + iter(py_obj) + except TypeError: + return BooleanValues.FALSE + except Exception: + pass + if isinstance(cached_model_value, list): + return BooleanValues.FALSE + if cached_model_value in py_obj: + return BooleanValues.TRUE + return BooleanValues.FALSE + + +def _evaluate_filter_expression_includes( + model_filter: ModelFilter, + cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]], +) -> BooleanValues: + if cached_model_value is None: + return BooleanValues.FALSE + filter_value = str(model_filter.value) + if filter_value in cached_model_value: + return BooleanValues.TRUE + return BooleanValues.FALSE + + +def _evaluate_filter_expression_begins_with( + model_filter: ModelFilter, + cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]], +) -> BooleanValues: + if cached_model_value is None: + return BooleanValues.FALSE + filter_value = str(model_filter.value) + if cached_model_value.startswith(filter_value): + return BooleanValues.TRUE + return BooleanValues.FALSE + + +def _evaluate_filter_expression_ends_with( + model_filter: ModelFilter, + cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]], +) -> BooleanValues: + if cached_model_value is None: + return BooleanValues.FALSE + filter_value = str(model_filter.value) + if cached_model_value.endswith(filter_value): + return BooleanValues.TRUE + return BooleanValues.FALSE + + def evaluate_filter_expression( # pylint: disable=too-many-return-statements model_filter: ModelFilter, - cached_model_value: Union[str, bool, int, float, Dict[str, Any], List[Any]], + cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]], ) -> BooleanValues: """Evaluates model filter with cached model spec value, returns boolean. @@ -440,36 +541,29 @@ def evaluate_filter_expression( # pylint: disable=too-many-return-statements evaluate the filter. """ if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.EQUALS]: - model_filter_value = model_filter.value - if isinstance(cached_model_value, bool): - cached_model_value = str(cached_model_value).lower() - model_filter_value = model_filter.value.lower() - if str(model_filter_value) == str(cached_model_value): - return BooleanValues.TRUE - return BooleanValues.FALSE + return _evaluate_filter_expression_equals(model_filter, cached_model_value) + if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_EQUALS]: - if isinstance(cached_model_value, bool): - cached_model_value = str(cached_model_value).lower() - model_filter.value = model_filter.value.lower() - if str(model_filter.value) == str(cached_model_value): - return BooleanValues.FALSE - return BooleanValues.TRUE + return _negate_boolean(_evaluate_filter_expression_equals(model_filter, cached_model_value)) + if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.IN]: - py_obj = literal_eval(model_filter.value) - try: - iter(py_obj) - except TypeError: - return BooleanValues.FALSE - if cached_model_value in py_obj: - return BooleanValues.TRUE - return BooleanValues.FALSE + return _evaluate_filter_expression_in(model_filter, cached_model_value) + if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_IN]: - py_obj = literal_eval(model_filter.value) - try: - iter(py_obj) - except TypeError: - return BooleanValues.TRUE - if cached_model_value in py_obj: - return BooleanValues.FALSE - return BooleanValues.TRUE + return _negate_boolean(_evaluate_filter_expression_in(model_filter, cached_model_value)) + + if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.INCLUDES]: + return _evaluate_filter_expression_includes(model_filter, cached_model_value) + + if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_INCLUDES]: + return _negate_boolean( + _evaluate_filter_expression_includes(model_filter, cached_model_value) + ) + + if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.BEGINS_WITH]: + return _evaluate_filter_expression_begins_with(model_filter, cached_model_value) + + if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.ENDS_WITH]: + return _evaluate_filter_expression_ends_with(model_filter, cached_model_value) + raise RuntimeError(f"Bad operator: {model_filter.operator}") diff --git a/tests/unit/sagemaker/jumpstart/test_filters.py b/tests/unit/sagemaker/jumpstart/test_filters.py index 31055745b1..a984509f9b 100644 --- a/tests/unit/sagemaker/jumpstart/test_filters.py +++ b/tests/unit/sagemaker/jumpstart/test_filters.py @@ -143,6 +143,10 @@ def test_not_equals(self): def test_in(self): + assert BooleanValues.TRUE == evaluate_filter_expression( + ModelFilter(key="hello", operator="in", value="daddy"), "dad" + ) + assert BooleanValues.TRUE == evaluate_filter_expression( ModelFilter(key="hello", operator="in", value='["mom", "dad"]'), "dad" ) @@ -169,6 +173,10 @@ def test_in(self): def test_not_in(self): + assert BooleanValues.FALSE == evaluate_filter_expression( + ModelFilter(key="hello", operator="not in", value="daddy"), "dad" + ) + assert BooleanValues.FALSE == evaluate_filter_expression( ModelFilter(key="hello", operator="not in", value='["mom", "dad"]'), "dad" ) @@ -193,6 +201,44 @@ def test_not_in(self): ModelFilter(key="hello", operator="not in", value='["mom", "fsdfdsfsd"]'), False ) + def test_includes(self): + + assert BooleanValues.TRUE == evaluate_filter_expression( + ModelFilter(key="hello", operator="includes", value="dad"), "daddy" + ) + + assert BooleanValues.TRUE == evaluate_filter_expression( + ModelFilter(key="hello", operator="includes", value="dad"), ["dad"] + ) + + def test_not_includes(self): + + assert BooleanValues.FALSE == evaluate_filter_expression( + ModelFilter(key="hello", operator="not includes", value="dad"), "daddy" + ) + + assert BooleanValues.FALSE == evaluate_filter_expression( + ModelFilter(key="hello", operator="not includes", value="dad"), ["dad"] + ) + + def test_begins_with(self): + assert BooleanValues.TRUE == evaluate_filter_expression( + ModelFilter(key="hello", operator="begins with", value="dad"), "daddy" + ) + + assert BooleanValues.FALSE == evaluate_filter_expression( + ModelFilter(key="hello", operator="begins with", value="mm"), "mommy" + ) + + def test_ends_with(self): + assert BooleanValues.TRUE == evaluate_filter_expression( + ModelFilter(key="hello", operator="ends with", value="car"), "racecar" + ) + + assert BooleanValues.FALSE == evaluate_filter_expression( + ModelFilter(key="hello", operator="begins with", value="ace"), "racecar" + ) + def test_parse_filter_string(): From 7f7fa949a0cb4dbe8b4f98b7b2e2af757fe39937 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Thu, 28 Dec 2023 01:03:43 +0000 Subject: [PATCH 08/12] fix: pylint --- src/sagemaker/jumpstart/filters.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/filters.py b/src/sagemaker/jumpstart/filters.py index dc19157ffc..b045435ed0 100644 --- a/src/sagemaker/jumpstart/filters.py +++ b/src/sagemaker/jumpstart/filters.py @@ -449,6 +449,7 @@ def parse_filter_string(filter_string: str) -> ModelFilter: def _negate_boolean(boolean: BooleanValues) -> BooleanValues: + """Negates boolean expression (False -> True, True -> False).""" if boolean == BooleanValues.TRUE: return BooleanValues.FALSE if boolean == BooleanValues.FALSE: @@ -460,6 +461,7 @@ def _evaluate_filter_expression_equals( model_filter: ModelFilter, cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]], ) -> BooleanValues: + """Evaluates filter expressions for equals.""" if cached_model_value is None: return BooleanValues.FALSE model_filter_value = model_filter.value @@ -475,6 +477,7 @@ def _evaluate_filter_expression_in( model_filter: ModelFilter, cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]], ) -> BooleanValues: + """Evaluates filter expressions for string/list in.""" if cached_model_value is None: return BooleanValues.FALSE py_obj = model_filter.value @@ -484,7 +487,7 @@ def _evaluate_filter_expression_in( iter(py_obj) except TypeError: return BooleanValues.FALSE - except Exception: + except Exception: # pylint: disable=W0703 pass if isinstance(cached_model_value, list): return BooleanValues.FALSE @@ -497,6 +500,7 @@ def _evaluate_filter_expression_includes( model_filter: ModelFilter, cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]], ) -> BooleanValues: + """Evaluates filter expressions for string includes.""" if cached_model_value is None: return BooleanValues.FALSE filter_value = str(model_filter.value) @@ -509,6 +513,7 @@ def _evaluate_filter_expression_begins_with( model_filter: ModelFilter, cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]], ) -> BooleanValues: + """Evaluates filter expressions for string begins with.""" if cached_model_value is None: return BooleanValues.FALSE filter_value = str(model_filter.value) @@ -521,6 +526,7 @@ def _evaluate_filter_expression_ends_with( model_filter: ModelFilter, cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]], ) -> BooleanValues: + """Evaluates filter expressions for string ends with.""" if cached_model_value is None: return BooleanValues.FALSE filter_value = str(model_filter.value) From e2daefc839acb1e7bb1fe1c59347c4acb1595e33 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Fri, 29 Dec 2023 02:16:59 +0000 Subject: [PATCH 09/12] feat: private util for model eula key --- src/sagemaker/jumpstart/notebook_utils.py | 26 ++++++++++++++ .../jumpstart/test_notebook_utils.py | 34 +++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index ecbef8cb3f..7cbdae410b 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -469,3 +469,29 @@ def get_model_url( s3_client=sagemaker_session.s3_client, ) return model_specs.url + + +def _get_model_eula_key( + model_id: str, + model_version: str, + region: str = JUMPSTART_DEFAULT_REGION_NAME, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> str: + """Retrieve S3 key for EULA text for gated models, or None for non-gated models. + + Args: + model_id (str): The model ID for which to retrieve the EULA S3 key. + model_version (str): The model version for which to retrieve the EULA S3 key. + region (str): Optional. The region from which to retrieve metadata. + (Default: JUMPSTART_DEFAULT_REGION_NAME) + sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use + to retrieve the EULA S3 key. + """ + + model_specs = accessors.JumpStartModelsAccessor.get_model_specs( + region=region, + model_id=model_id, + version=model_version, + s3_client=sagemaker_session.s3_client, + ) + return model_specs.hosting_eula_key diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index 8aae4c36a8..202f58ed69 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -15,9 +15,11 @@ get_header_from_base_header, get_prototype_manifest, get_prototype_model_spec, + get_special_model_spec, ) from sagemaker.jumpstart.notebook_utils import ( _generate_jumpstart_model_versions, + _get_model_eula_key, get_model_url, list_jumpstart_frameworks, list_jumpstart_models, @@ -698,3 +700,35 @@ def test_get_model_url( region=region, s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, ) + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test__get_model_eula_key( + patched_get_model_specs: Mock, +): + + patched_get_model_specs.side_effect = get_special_model_spec + + model_id, version = "gated_llama_neuron_model", "*" + assert "fmhMetadata/eula/llamaEula.txt" == _get_model_eula_key(model_id, version) + + model_id, version = "variant-model", "1.0.0" + assert None == _get_model_eula_key(model_id, version) + + region = "fake-region" + + patched_get_model_specs.reset_mock() + patched_get_model_specs.side_effect = lambda *largs, **kwargs: get_special_model_spec( + *largs, + region="us-west-2", + **{key: value for key, value in kwargs.items() if key != "region"}, + ) + + _get_model_eula_key(model_id, version, region=region) + + patched_get_model_specs.assert_called_once_with( + model_id=model_id, + version=version, + region=region, + s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, + ) From 430a517575a84d152a7fe6f3bfb50d170ca574f6 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Fri, 29 Dec 2023 18:58:46 +0000 Subject: [PATCH 10/12] fix: unit tests, use verify_model_region_and_return_specs in notebook utils --- src/sagemaker/jumpstart/notebook_utils.py | 21 ++++++++++++------- .../jumpstart/test_notebook_utils.py | 2 +- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 7cbdae410b..0237d99dcf 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -34,7 +34,11 @@ ) from sagemaker.jumpstart.filters import Constant, ModelFilter, Operator, evaluate_filter_expression from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs -from sagemaker.jumpstart.utils import get_jumpstart_content_bucket, get_sagemaker_version +from sagemaker.jumpstart.utils import ( + get_jumpstart_content_bucket, + get_sagemaker_version, + verify_model_region_and_return_specs, +) from sagemaker.session import Session MAX_SEARCH_WORKERS = int(100 * 1e6 / 25 * 1e3) # max 100MB total memory, 25kB per thread) @@ -221,11 +225,12 @@ def list_jumpstart_scripts( # pylint: disable=redefined-builtin filter=filter, region=region, sagemaker_session=sagemaker_session ): scripts.add(JumpStartScriptScope.INFERENCE) - model_specs = accessors.JumpStartModelsAccessor.get_model_specs( + model_specs = verify_model_region_and_return_specs( region=region, model_id=model_id, version=version, - s3_client=sagemaker_session.s3_client, + sagemaker_session=sagemaker_session, + scope=JumpStartScriptScope.INFERENCE, ) if model_specs.training_supported: scripts.add(JumpStartScriptScope.TRAINING) @@ -462,11 +467,12 @@ def get_model_url( to retrieve the model url. """ - model_specs = accessors.JumpStartModelsAccessor.get_model_specs( + model_specs = verify_model_region_and_return_specs( region=region, model_id=model_id, version=model_version, - s3_client=sagemaker_session.s3_client, + sagemaker_session=sagemaker_session, + scope=JumpStartScriptScope.INFERENCE, ) return model_specs.url @@ -488,10 +494,11 @@ def _get_model_eula_key( to retrieve the EULA S3 key. """ - model_specs = accessors.JumpStartModelsAccessor.get_model_specs( + model_specs = verify_model_region_and_return_specs( region=region, model_id=model_id, version=model_version, - s3_client=sagemaker_session.s3_client, + sagemaker_session=sagemaker_session, + scope=JumpStartScriptScope.INFERENCE, ) return model_specs.hosting_eula_key diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index 202f58ed69..95ad4a5b67 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -713,7 +713,7 @@ def test__get_model_eula_key( assert "fmhMetadata/eula/llamaEula.txt" == _get_model_eula_key(model_id, version) model_id, version = "variant-model", "1.0.0" - assert None == _get_model_eula_key(model_id, version) + assert None is _get_model_eula_key(model_id, version) region = "fake-region" From ad9016a6ece55ee75b8ae31d469762724cd9ffd4 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Tue, 2 Jan 2024 14:13:51 +0000 Subject: [PATCH 11/12] Revert "feat: private util for model eula key" This reverts commit e2daefc839acb1e7bb1fe1c59347c4acb1595e33. --- src/sagemaker/jumpstart/notebook_utils.py | 27 --------------- .../jumpstart/test_notebook_utils.py | 34 ------------------- 2 files changed, 61 deletions(-) diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 0237d99dcf..1554025995 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -475,30 +475,3 @@ def get_model_url( scope=JumpStartScriptScope.INFERENCE, ) return model_specs.url - - -def _get_model_eula_key( - model_id: str, - model_version: str, - region: str = JUMPSTART_DEFAULT_REGION_NAME, - sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> str: - """Retrieve S3 key for EULA text for gated models, or None for non-gated models. - - Args: - model_id (str): The model ID for which to retrieve the EULA S3 key. - model_version (str): The model version for which to retrieve the EULA S3 key. - region (str): Optional. The region from which to retrieve metadata. - (Default: JUMPSTART_DEFAULT_REGION_NAME) - sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use - to retrieve the EULA S3 key. - """ - - model_specs = verify_model_region_and_return_specs( - region=region, - model_id=model_id, - version=model_version, - sagemaker_session=sagemaker_session, - scope=JumpStartScriptScope.INFERENCE, - ) - return model_specs.hosting_eula_key diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index 95ad4a5b67..8aae4c36a8 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -15,11 +15,9 @@ get_header_from_base_header, get_prototype_manifest, get_prototype_model_spec, - get_special_model_spec, ) from sagemaker.jumpstart.notebook_utils import ( _generate_jumpstart_model_versions, - _get_model_eula_key, get_model_url, list_jumpstart_frameworks, list_jumpstart_models, @@ -700,35 +698,3 @@ def test_get_model_url( region=region, s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, ) - - -@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test__get_model_eula_key( - patched_get_model_specs: Mock, -): - - patched_get_model_specs.side_effect = get_special_model_spec - - model_id, version = "gated_llama_neuron_model", "*" - assert "fmhMetadata/eula/llamaEula.txt" == _get_model_eula_key(model_id, version) - - model_id, version = "variant-model", "1.0.0" - assert None is _get_model_eula_key(model_id, version) - - region = "fake-region" - - patched_get_model_specs.reset_mock() - patched_get_model_specs.side_effect = lambda *largs, **kwargs: get_special_model_spec( - *largs, - region="us-west-2", - **{key: value for key, value in kwargs.items() if key != "region"}, - ) - - _get_model_eula_key(model_id, version, region=region) - - patched_get_model_specs.assert_called_once_with( - model_id=model_id, - version=version, - region=region, - s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, - ) From a0853f8b4f062d272d65cf42ac95453f2412d68b Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Fri, 5 Jan 2024 19:36:07 +0000 Subject: [PATCH 12/12] chore: add search keywords to header --- src/sagemaker/jumpstart/types.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 21b624d7a4..49d3e295c5 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -130,7 +130,7 @@ def __init__( class JumpStartModelHeader(JumpStartDataHolderType): """Data class JumpStart model header.""" - __slots__ = ["model_id", "version", "min_version", "spec_key"] + __slots__ = ["model_id", "version", "min_version", "spec_key", "search_keywords"] def __init__(self, header: Dict[str, str]): """Initializes a JumpStartModelHeader object from its json representation. @@ -142,7 +142,11 @@ def __init__(self, header: Dict[str, str]): def to_json(self) -> Dict[str, str]: """Returns json representation of JumpStartModelHeader object.""" - json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + json_obj = { + att: getattr(self, att) + for att in self.__slots__ + if getattr(self, att, None) is not None + } return json_obj def from_json(self, json_obj: Dict[str, str]) -> None: @@ -155,6 +159,7 @@ def from_json(self, json_obj: Dict[str, str]) -> None: self.version: str = json_obj["version"] self.min_version: str = json_obj["min_version"] self.spec_key: str = json_obj["spec_key"] + self.search_keywords: Optional[List[str]] = json_obj.get("search_keywords") class JumpStartECRSpecs(JumpStartDataHolderType):