From f7677d853dfc1753facaf34f1cf5148ef5a07925 Mon Sep 17 00:00:00 2001 From: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> Date: Mon, 11 Mar 2024 16:42:55 -0700 Subject: [PATCH 01/17] fix: Move sagemaker pysdk version check after bootstrap in remote job (#4487) --- .../runtime_environment/bootstrap_runtime_environment.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py index 5332f7bdd0..8fd83bfcfe 100644 --- a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py +++ b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py @@ -65,9 +65,6 @@ def main(sys_args=None): conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV") RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env) - RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version( - client_sagemaker_pysdk_version - ) user = getpass.getuser() if user != "root": From 347b59919d61b2be5cad5997b23e35f15e65b758 Mon Sep 17 00:00:00 2001 From: Haotian An <33510317+Captainia@users.noreply.github.com> Date: Tue, 12 Mar 2024 16:08:33 -0400 Subject: [PATCH 02/17] feat: support JumpStart proprietary models (#4467) * feat: add proprietary manifest/specs parsing add unittests for test_cache small refactoring address comments and more unittests fix linting and fix more tests fix: pylint feat: JumpStartModel class for prop models * remove unused imports and fix docstyle * fix: remove unused args * fix: remove unused args * fix: more unused vars * fix: slow tests * fix: unittests * added more tests to cover some lines * remove estimator warn check * chore: address comments re performance * fix: address comments * complete list experience and other fixes * fix: pylint * add doc utils and fix pylint * fix: docstyle * fix: doc * fix: default payloads * fix: doc and tags and enums * fix: jumpstart doc * rename to open_weights and fix filtering * update filter name * doc update * fix: black * rename to proprietary model and fix unittests * address comments * fix: docstyle and flake8 * address more comments and fix doc * put back doc utils for future refactoring * add prop model title in doc * doc update --------- Co-authored-by: liujiaor <128006184+liujiaorr@users.noreply.github.com> --- src/sagemaker/jumpstart/cache.py | 1 - src/sagemaker/jumpstart/types.py | 1 + .../sagemaker/instance_types/jumpstart/test_instance_types.py | 1 + tests/unit/sagemaker/jumpstart/utils.py | 1 - tests/unit/sagemaker/script_uris/jumpstart/test_common.py | 1 + 5 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 8d0f1832bf..5622b3f01e 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -429,7 +429,6 @@ def _retrieval_function( """ data_type, id_info = key.data_type, key.id_info - if data_type in { JumpStartS3FileType.OPEN_WEIGHT_MANIFEST, JumpStartS3FileType.PROPRIETARY_MANIFEST, diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 01622a5462..897976f241 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1434,6 +1434,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "model_version", "model_type", "hub_arn", + "model_type", "region", "tolerate_deprecated_model", "tolerate_vulnerable_model", diff --git a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py index f3454ca322..77d98d5437 100644 --- a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py +++ b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py @@ -126,6 +126,7 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index ad093640b7..4c5fdf2ab2 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -31,7 +31,6 @@ HubContentType, ) from sagemaker.jumpstart.enums import JumpStartModelType - from sagemaker.jumpstart.utils import get_formatted_manifest from tests.unit.sagemaker.jumpstart.constants import ( PROTOTYPICAL_MODEL_SPECS_DICT, diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index e1d3ef6ae1..4be8027046 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -54,6 +54,7 @@ def test_jumpstart_common_script_uri( s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() From 2973f23a4914de937e87c59b18aa57cc43f8f595 Mon Sep 17 00:00:00 2001 From: Ben Crabtree Date: Wed, 21 Feb 2024 10:58:01 -0500 Subject: [PATCH 03/17] feat: add hub and hubcontent support in retrieval function for jumpstart model cache (#4438) --- src/sagemaker/jumpstart/cache.py | 2 ++ src/sagemaker/jumpstart/constants.py | 3 +- src/sagemaker/jumpstart/types.py | 19 +++++++++++++ src/sagemaker/jumpstart/utils.py | 1 + tests/unit/sagemaker/jumpstart/test_utils.py | 29 ++++++++++++++++++++ tests/unit/sagemaker/jumpstart/utils.py | 1 + 6 files changed, 54 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 5622b3f01e..d831d3023b 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -58,6 +58,7 @@ DescribeHubContentsResponse, HubType, HubContentType, + HubDataType, ) from sagemaker.jumpstart.curated_hub import utils as hub_utils from sagemaker.jumpstart.enums import JumpStartModelType @@ -429,6 +430,7 @@ def _retrieval_function( """ data_type, id_info = key.data_type, key.id_info + if data_type in { JumpStartS3FileType.OPEN_WEIGHT_MANIFEST, JumpStartS3FileType.PROPRIETARY_MANIFEST, diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 1b679d44f6..21412d65ea 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -172,7 +172,8 @@ JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json" -HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$" +# works cross-partition +HUB_MODEL_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub_content/(.*?)/Model/(.*?)/(.*?)$" HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$" INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py" diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 897976f241..0cc67b72ce 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -972,6 +972,25 @@ def from_hub_content_doc(self, hub_content_doc: Dict[str, Any]) -> None: """ # TODO: Implement + def to_json(self) -> Dict[str, Any]: + """Returns json representation of JumpStartModelSpecs object.""" + json_obj = {} + for att in self.__slots__: + if hasattr(self, att): + cur_val = getattr(self, att) + if issubclass(type(cur_val), JumpStartDataHolderType): + json_obj[att] = cur_val.to_json() + elif isinstance(cur_val, list): + json_obj[att] = [] + for obj in cur_val: + if issubclass(type(obj), JumpStartDataHolderType): + json_obj[att].append(obj.to_json()) + else: + json_obj[att].append(obj) + else: + json_obj[att] = cur_val + return json_obj + def supports_prepacked_inference(self) -> bool: """Returns True if the model has a prepacked inference artifact.""" return getattr(self, "hosting_prepacked_artifact_key", None) is not None diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 4fc8752625..dd0deb1291 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -16,6 +16,7 @@ import os import re from typing import Any, Dict, List, Set, Optional, Tuple, Union +import re from urllib.parse import urlparse import boto3 from packaging.version import Version diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index fa1a3fc72f..472b2dfdd9 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -1214,6 +1214,35 @@ def test_mime_type_enum_from_str(): assert MIMEType.from_suffixed_type(mime_type_with_suffix) == mime_type +def test_extract_info_from_hub_content_arn(): + model_arn = ( + "arn:aws:sagemaker:us-west-2:000000000000:hub_content/MockHub/Model/my-mock-model/1.0.2" + ) + assert utils.extract_info_from_hub_content_arn(model_arn) == ( + "MockHub", + "us-west-2", + "my-mock-model", + "1.0.2", + ) + + hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/MockHub" + assert utils.extract_info_from_hub_content_arn(hub_arn) == ("MockHub", "us-west-2", None, None) + + invalid_arn = "arn:aws:sagemaker:us-west-2:000000000000:endpoint/my-endpoint-123" + assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) + + invalid_arn = "nonsense-string" + assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) + + invalid_arn = "" + assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) + + invalid_arn = ( + "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MyHub/Notebook/my-notebook/1.0.0" + ) + assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) + + class TestIsValidModelId(TestCase): @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs") diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 4c5fdf2ab2..da6227cdf0 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -22,6 +22,7 @@ JUMPSTART_REGION_NAME_SET, ) from sagemaker.jumpstart.types import ( + HubDataType, JumpStartCachedContentKey, JumpStartCachedContentValue, JumpStartModelSpecs, From 5bc742f3a17101ce32f1e960208a398d21901779 Mon Sep 17 00:00:00 2001 From: Ben Crabtree Date: Mon, 26 Feb 2024 14:18:18 -0500 Subject: [PATCH 04/17] feat: jsch jumpstart estimator support (#4439) --- src/sagemaker/jumpstart/accessors.py | 1 + src/sagemaker/jumpstart/cache.py | 1 + src/sagemaker/jumpstart/constants.py | 3 +- src/sagemaker/jumpstart/estimator.py | 1 + src/sagemaker/jumpstart/factory/estimator.py | 2 ++ src/sagemaker/jumpstart/factory/model.py | 2 ++ src/sagemaker/jumpstart/types.py | 8 +++++ src/sagemaker/jumpstart/utils.py | 1 + .../jumpstart/test_validate.py | 2 ++ .../image_uris/jumpstart/test_common.py | 4 +++ .../jumpstart/test_instance_types.py | 1 - .../sagemaker/jumpstart/test_accessors.py | 1 - .../jumpstart/test_notebook_utils.py | 1 + tests/unit/sagemaker/jumpstart/test_utils.py | 29 ------------------- tests/unit/sagemaker/jumpstart/utils.py | 5 +++- .../model_uris/jumpstart/test_common.py | 4 +++ .../jumpstart/test_resource_requirements.py | 1 + .../script_uris/jumpstart/test_common.py | 4 +++ 18 files changed, 37 insertions(+), 34 deletions(-) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index dfc833ec28..c9f805c225 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -257,6 +257,7 @@ def get_model_specs( hub_arn: Optional[str] = None, s3_client: Optional[boto3.client] = None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn: Optional[str] = None, ) -> JumpStartModelSpecs: """Returns model specs from JumpStart models cache. diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index d831d3023b..7bd9fa10ba 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -40,6 +40,7 @@ get_wildcard_model_version_msg, get_wildcard_proprietary_model_version_msg, ) +from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg from sagemaker.jumpstart.parameters import ( JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS, diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 21412d65ea..1b679d44f6 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -172,8 +172,7 @@ JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json" -# works cross-partition -HUB_MODEL_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub_content/(.*?)/Model/(.*?)/(.*?)$" +HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$" HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$" INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py" diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 6406932924..4cdd540111 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -534,6 +534,7 @@ def _validate_model_id_and_get_type_hook(): model_version=model_version, hub_arn=hub_arn, model_type=self.model_type, + hub_arn=hub_arn, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, role=role, diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index fb598256fa..fa04b46a7c 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -81,6 +81,7 @@ def get_init_kwargs( model_version: Optional[str] = None, hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, region: Optional[str] = None, @@ -140,6 +141,7 @@ def get_init_kwargs( model_version=model_version, hub_arn=hub_arn, model_type=model_type, + hub_arn=hub_arn, role=role, region=region, instance_count=instance_count, diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 6f7a83cef1..273257088e 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -550,6 +550,7 @@ def get_deploy_kwargs( model_version: Optional[str] = None, hub_arn: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + hub_arn: Optional[str] = None, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -584,6 +585,7 @@ def get_deploy_kwargs( model_version=model_version, hub_arn=hub_arn, model_type=model_type, + hub_arn=hub_arn, region=region, initial_instance_count=initial_instance_count, instance_type=instance_type, diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 0cc67b72ce..8a17d8de88 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1420,6 +1420,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", + "hub_arn", "initial_instance_count", "instance_type", "region", @@ -1454,6 +1455,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "model_type", "hub_arn", "model_type", + "hub_arn", "region", "tolerate_deprecated_model", "tolerate_vulnerable_model", @@ -1500,6 +1502,7 @@ def __init__( self.model_version = model_version self.hub_arn = hub_arn self.model_type = model_type + self.hub_arn = hub_arn self.initial_instance_count = initial_instance_count self.instance_type = instance_type self.region = region @@ -1536,6 +1539,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", + "hub_arn", "instance_type", "instance_count", "region", @@ -1597,6 +1601,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", + "hub_arn", } def __init__( @@ -1726,6 +1731,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", + "hub_arn", "region", "inputs", "wait", @@ -1742,6 +1748,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", + "hub_arn", "region", "tolerate_deprecated_model", "tolerate_vulnerable_model", @@ -1770,6 +1777,7 @@ def __init__( self.model_version = model_version self.hub_arn = hub_arn self.model_type = model_type + self.hub_arn = hub_arn self.region = region self.inputs = inputs self.wait = wait diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index dd0deb1291..44294d67f9 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -17,6 +17,7 @@ import re from typing import Any, Dict, List, Set, Optional, Tuple, Union import re +from typing import Any, Dict, List, Set, Optional, Tuple, Union from urllib.parse import urlparse import boto3 from packaging.version import Version diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py index 0f69cb572a..93d7098870 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -453,6 +453,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_get_model_specs.reset_mock() @@ -516,6 +517,7 @@ def test_jumpstart_validate_all_hyperparameters( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py index bd4383499d..45af6faeed 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py @@ -56,6 +56,7 @@ def test_jumpstart_common_image_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -78,6 +79,7 @@ def test_jumpstart_common_image_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -100,6 +102,7 @@ def test_jumpstart_common_image_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -122,6 +125,7 @@ def test_jumpstart_common_image_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py index 77d98d5437..f3454ca322 100644 --- a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py +++ b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py @@ -126,7 +126,6 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/jumpstart/test_accessors.py b/tests/unit/sagemaker/jumpstart/test_accessors.py index 5d527dd5a1..24945bde22 100644 --- a/tests/unit/sagemaker/jumpstart/test_accessors.py +++ b/tests/unit/sagemaker/jumpstart/test_accessors.py @@ -137,7 +137,6 @@ def test_jumpstart_proprietary_models_cache_get(mock_cache): > 0 ) - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") def test_jumpstart_models_cache_get_model_specs(mock_cache): mock_cache.get_specs = Mock() diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index a5d1ee3ac2..ed7c870a0e 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -751,4 +751,5 @@ def test_get_model_url( s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 472b2dfdd9..fa1a3fc72f 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -1214,35 +1214,6 @@ def test_mime_type_enum_from_str(): assert MIMEType.from_suffixed_type(mime_type_with_suffix) == mime_type -def test_extract_info_from_hub_content_arn(): - model_arn = ( - "arn:aws:sagemaker:us-west-2:000000000000:hub_content/MockHub/Model/my-mock-model/1.0.2" - ) - assert utils.extract_info_from_hub_content_arn(model_arn) == ( - "MockHub", - "us-west-2", - "my-mock-model", - "1.0.2", - ) - - hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/MockHub" - assert utils.extract_info_from_hub_content_arn(hub_arn) == ("MockHub", "us-west-2", None, None) - - invalid_arn = "arn:aws:sagemaker:us-west-2:000000000000:endpoint/my-endpoint-123" - assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) - - invalid_arn = "nonsense-string" - assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) - - invalid_arn = "" - assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) - - invalid_arn = ( - "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MyHub/Notebook/my-notebook/1.0.0" - ) - assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) - - class TestIsValidModelId(TestCase): @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs") diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index da6227cdf0..79f805444f 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -22,7 +22,7 @@ JUMPSTART_REGION_NAME_SET, ) from sagemaker.jumpstart.types import ( - HubDataType, + HubContentType, JumpStartCachedContentKey, JumpStartCachedContentValue, JumpStartModelSpecs, @@ -253,6 +253,9 @@ def patched_retrieval_function( model_type=JumpStartModelType.PROPRIETARY, ) ) + # TODO: Implement + if datatype == HubContentType.HUB: + return None if datatype == HubContentType.MODEL: _, _, _, model_name, model_version = id_info.split("/") diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index 2bb327c26f..06587a2074 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -54,6 +54,7 @@ def test_jumpstart_common_model_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -73,6 +74,7 @@ def test_jumpstart_common_model_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -93,6 +95,7 @@ def test_jumpstart_common_model_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -113,6 +116,7 @@ def test_jumpstart_common_model_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py index 2a4d913a75..b2e055dd3c 100644 --- a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py +++ b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py @@ -57,6 +57,7 @@ def test_jumpstart_resource_requirements( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index 4be8027046..99cec92e99 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -55,6 +55,7 @@ def test_jumpstart_common_script_uri( model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -74,6 +75,7 @@ def test_jumpstart_common_script_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -94,6 +96,7 @@ def test_jumpstart_common_script_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -114,6 +117,7 @@ def test_jumpstart_common_script_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() From 9210e494da31fa5adc4b41a408426eac3c0896fd Mon Sep 17 00:00:00 2001 From: Ben Crabtree Date: Wed, 28 Feb 2024 17:09:01 -0500 Subject: [PATCH 05/17] Master jumpstart curated hub (#4464) --- .../runtime_environment/bootstrap_runtime_environment.py | 3 +++ .../runtime_environment/runtime_environment_manager.py | 2 ++ src/sagemaker/utils.py | 1 + 3 files changed, 6 insertions(+) diff --git a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py index 8fd83bfcfe..5332f7bdd0 100644 --- a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py +++ b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py @@ -65,6 +65,9 @@ def main(sys_args=None): conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV") RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env) + RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version( + client_sagemaker_pysdk_version + ) user = getpass.getuser() if user != "root": diff --git a/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py b/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py index 13493c1d15..64e6c087f8 100644 --- a/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py +++ b/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py @@ -24,6 +24,8 @@ import dataclasses import json +import sagemaker + class _UTCFormatter(logging.Formatter): """Class that overrides the default local time provider in log formatter.""" diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 7896aac150..fe8c0b7c56 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -22,6 +22,7 @@ import random import re import shutil +import sys import tarfile import tempfile import time From 4905ceee3ed65f863d27442cc20ef0e249c2591c Mon Sep 17 00:00:00 2001 From: Ben Crabtree Date: Wed, 28 Feb 2024 17:15:59 -0500 Subject: [PATCH 06/17] add hub_arn support for accept_types, content_types, serializers, deserializers, and predictor (#4463) --- src/sagemaker/jumpstart/factory/model.py | 2 ++ src/sagemaker/jumpstart/types.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 273257088e..2fa538f33b 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -726,6 +726,7 @@ def get_init_kwargs( model_version: Optional[str] = None, hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, instance_type: Optional[str] = None, @@ -759,6 +760,7 @@ def get_init_kwargs( model_version=model_version, hub_arn=hub_arn, model_type=model_type, + hub_arn=hub_arn, instance_type=instance_type, region=region, image_uri=image_uri, diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 8a17d8de88..8df0796a22 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1310,6 +1310,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", + "hub_arn", "instance_type", "tolerate_vulnerable_model", "tolerate_deprecated_model", @@ -1342,6 +1343,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", + "hub_arn", "tolerate_vulnerable_model", "tolerate_deprecated_model", "region", @@ -1355,6 +1357,7 @@ def __init__( model_version: Optional[str] = None, hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, + hub_arn: Optional[str] = None, region: Optional[str] = None, instance_type: Optional[str] = None, image_uri: Optional[Union[str, Any]] = None, @@ -1386,6 +1389,7 @@ def __init__( self.model_version = model_version self.hub_arn = hub_arn self.model_type = model_type + self.hub_arn = hub_arn self.instance_type = instance_type self.region = region self.image_uri = image_uri From 82d0d926edc0943099a9d7b010f903d7cdca2130 Mon Sep 17 00:00:00 2001 From: Jinyoung Lim Date: Thu, 29 Feb 2024 08:42:47 -0800 Subject: [PATCH 07/17] feature: JumpStart CuratedHub class creation and function definitions (#4448) --- src/sagemaker/jumpstart/cache.py | 1 - src/sagemaker/jumpstart/types.py | 19 ----------- .../jumpstart/curated_hub/test_utils.py | 32 +++++++++++++++++++ tests/unit/sagemaker/jumpstart/test_cache.py | 3 +- tests/unit/sagemaker/jumpstart/utils.py | 4 +-- 5 files changed, 36 insertions(+), 23 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 7bd9fa10ba..f269e22d53 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -480,7 +480,6 @@ def _retrieval_function( return JumpStartCachedContentValue( formatted_content=model_specs ) - if data_type == HubType.HUB: hub_name, _, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info) response: Dict[str, Any] = self._sagemaker_session.describe_hub(hub_name=hub_name) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 8df0796a22..d0fe203037 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -972,25 +972,6 @@ def from_hub_content_doc(self, hub_content_doc: Dict[str, Any]) -> None: """ # TODO: Implement - def to_json(self) -> Dict[str, Any]: - """Returns json representation of JumpStartModelSpecs object.""" - json_obj = {} - for att in self.__slots__: - if hasattr(self, att): - cur_val = getattr(self, att) - if issubclass(type(cur_val), JumpStartDataHolderType): - json_obj[att] = cur_val.to_json() - elif isinstance(cur_val, list): - json_obj[att] = [] - for obj in cur_val: - if issubclass(type(obj), JumpStartDataHolderType): - json_obj[att].append(obj.to_json()) - else: - json_obj[att].append(obj) - else: - json_obj[att] = cur_val - return json_obj - def supports_prepacked_inference(self) -> bool: """Returns True if the model has a prepacked inference artifact.""" return getattr(self, "hosting_prepacked_artifact_key", None) is not None diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py index b4b2eaabb2..24212b5f68 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py @@ -139,6 +139,38 @@ def test_generate_hub_arn_for_init_kwargs(): utils.generate_hub_arn_for_init_kwargs(hub_arn, "us-east-1", mock_custom_session) == hub_arn ) + assert ( + utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, None, mock_custom_session) + == hub_arn + ) + + +def test_generate_default_hub_bucket_name(): + mock_sagemaker_session = Mock() + mock_sagemaker_session.account_id.return_value = "123456789123" + mock_sagemaker_session.boto_region_name = "us-east-1" + + assert ( + utils.generate_default_hub_bucket_name(sagemaker_session=mock_sagemaker_session) + == "sagemaker-hubs-us-east-1-123456789123" + ) + + +def test_create_hub_bucket_if_it_does_not_exist(): + mock_sagemaker_session = Mock() + mock_sagemaker_session.account_id.return_value = "123456789123" + mock_sagemaker_session.client("sts").get_caller_identity.return_value = { + "Account": "123456789123" + } + mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None + mock_sagemaker_session.boto_region_name = "us-east-1" + bucket_name = "sagemaker-hubs-us-east-1-123456789123" + created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist( + sagemaker_session=mock_sagemaker_session + ) + + mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once() + assert created_hub_bucket_name == bucket_name assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index d5537712a0..7f66b495be 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -28,6 +28,7 @@ JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY, JumpStartModelsCache, ) +from sagemaker.session_settings import SessionSettings from sagemaker.jumpstart.constants import ( ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, @@ -1133,7 +1134,7 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories( mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root") assert mocked_is_dir.call_count == 2 - mocked_open.assert_not_called() + assert mocked_open.call_count == 2 mocked_get_json_file_and_etag_from_s3.assert_has_calls( calls=[ call("models_manifest.json"), diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 79f805444f..bd1bc8d691 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -22,7 +22,6 @@ JUMPSTART_REGION_NAME_SET, ) from sagemaker.jumpstart.types import ( - HubContentType, JumpStartCachedContentKey, JumpStartCachedContentValue, JumpStartModelSpecs, @@ -32,6 +31,7 @@ HubContentType, ) from sagemaker.jumpstart.enums import JumpStartModelType + from sagemaker.jumpstart.utils import get_formatted_manifest from tests.unit.sagemaker.jumpstart.constants import ( PROTOTYPICAL_MODEL_SPECS_DICT, @@ -254,7 +254,7 @@ def patched_retrieval_function( ) ) # TODO: Implement - if datatype == HubContentType.HUB: + if datatype == HubType.HUB: return None if datatype == HubContentType.MODEL: From 4cef235d6ddac05a550860520d24112637326e0d Mon Sep 17 00:00:00 2001 From: Ben Crabtree Date: Tue, 12 Mar 2024 13:15:58 -0400 Subject: [PATCH 08/17] MultiPartCopy with Sync Algorithm (#4475) * first pass at sync function with util classes * adding tests and update clases * linting * file generator class inheritance * lint * multipart copy and algorithm updates * modularize sync * reformatting folders * testing for sync * do not tolerate vulnerable * remove prints * handle multithreading progress bar * update tests * optimize function and add hub bucket prefix * docstrings and linting --- .../sagemaker/jumpstart/curated_hub/test_utils.py | 11 +++++++---- tests/unit/sagemaker/jumpstart/test_cache.py | 2 +- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py index 24212b5f68..6743a969e2 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py @@ -139,10 +139,7 @@ def test_generate_hub_arn_for_init_kwargs(): utils.generate_hub_arn_for_init_kwargs(hub_arn, "us-east-1", mock_custom_session) == hub_arn ) - assert ( - utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, None, mock_custom_session) - == hub_arn - ) + assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn def test_generate_default_hub_bucket_name(): @@ -162,8 +159,14 @@ def test_create_hub_bucket_if_it_does_not_exist(): mock_sagemaker_session.client("sts").get_caller_identity.return_value = { "Account": "123456789123" } + hub_arn = "arn:aws:sagemaker:us-west-2:12346789123:hub/my-awesome-hub" + # Mock custom session with custom values + mock_custom_session = Mock() + mock_custom_session.account_id.return_value = "000000000000" + mock_custom_session.boto_region_name = "us-east-2" mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None mock_sagemaker_session.boto_region_name = "us-east-1" + bucket_name = "sagemaker-hubs-us-east-1-123456789123" created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist( sagemaker_session=mock_sagemaker_session diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 7f66b495be..1f3d685cd9 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -1134,7 +1134,7 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories( mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root") assert mocked_is_dir.call_count == 2 - assert mocked_open.call_count == 2 + mocked_open.assert_not_called() mocked_get_json_file_and_etag_from_s3.assert_has_calls( calls=[ call("models_manifest.json"), From cb81d11285ceb66054d34daae95b4661ca42212f Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Wed, 13 Mar 2024 16:06:53 +0000 Subject: [PATCH 09/17] rebase with master --- src/sagemaker/jumpstart/cache.py | 2 +- src/sagemaker/jumpstart/types.py | 1 - .../runtime_environment/runtime_environment_manager.py | 2 -- src/sagemaker/utils.py | 1 - tests/unit/sagemaker/jumpstart/test_accessors.py | 1 + tests/unit/sagemaker/jumpstart/test_cache.py | 1 - tests/unit/sagemaker/jumpstart/utils.py | 7 +++++++ 7 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index f269e22d53..d831d3023b 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -40,7 +40,6 @@ get_wildcard_model_version_msg, get_wildcard_proprietary_model_version_msg, ) -from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg from sagemaker.jumpstart.parameters import ( JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS, @@ -480,6 +479,7 @@ def _retrieval_function( return JumpStartCachedContentValue( formatted_content=model_specs ) + if data_type == HubType.HUB: hub_name, _, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info) response: Dict[str, Any] = self._sagemaker_session.describe_hub(hub_name=hub_name) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index d0fe203037..e46850d139 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -15,7 +15,6 @@ from copy import deepcopy from enum import Enum from typing import Any, Dict, List, Optional, Set, Union -from sagemaker.session import Session from sagemaker.utils import get_instance_type_family, format_tags, Tags from sagemaker.enums import EndpointType from sagemaker.model_metrics import ModelMetrics diff --git a/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py b/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py index 64e6c087f8..13493c1d15 100644 --- a/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py +++ b/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py @@ -24,8 +24,6 @@ import dataclasses import json -import sagemaker - class _UTCFormatter(logging.Formatter): """Class that overrides the default local time provider in log formatter.""" diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index fe8c0b7c56..7896aac150 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -22,7 +22,6 @@ import random import re import shutil -import sys import tarfile import tempfile import time diff --git a/tests/unit/sagemaker/jumpstart/test_accessors.py b/tests/unit/sagemaker/jumpstart/test_accessors.py index 24945bde22..5d527dd5a1 100644 --- a/tests/unit/sagemaker/jumpstart/test_accessors.py +++ b/tests/unit/sagemaker/jumpstart/test_accessors.py @@ -137,6 +137,7 @@ def test_jumpstart_proprietary_models_cache_get(mock_cache): > 0 ) + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") def test_jumpstart_models_cache_get_model_specs(mock_cache): mock_cache.get_specs = Mock() diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 1f3d685cd9..d5537712a0 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -28,7 +28,6 @@ JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY, JumpStartModelsCache, ) -from sagemaker.session_settings import SessionSettings from sagemaker.jumpstart.constants import ( ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index bd1bc8d691..3451334a70 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -253,6 +253,13 @@ def patched_retrieval_function( model_type=JumpStartModelType.PROPRIETARY, ) ) + + if datatype == HubContentType.MODEL: + _, _, _, model_name, model_version = id_info.split("/") + return JumpStartCachedContentValue( + formatted_content=get_spec_from_base_spec(model_id=model_name, version=model_version) + ) + # TODO: Implement if datatype == HubType.HUB: return None From f50de6b406d2b75fcacc4349bbdacc362e42416f Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Wed, 13 Mar 2024 20:43:29 +0000 Subject: [PATCH 10/17] bad rebase --- src/sagemaker/jumpstart/accessors.py | 1 - src/sagemaker/jumpstart/cache.py | 1 - src/sagemaker/jumpstart/estimator.py | 1 - src/sagemaker/jumpstart/factory/estimator.py | 2 -- src/sagemaker/jumpstart/factory/model.py | 4 ---- src/sagemaker/jumpstart/types.py | 12 +----------- .../hyperparameters/jumpstart/test_validate.py | 2 -- .../sagemaker/image_uris/jumpstart/test_common.py | 4 ---- .../unit/sagemaker/jumpstart/test_notebook_utils.py | 1 - .../sagemaker/model_uris/jumpstart/test_common.py | 4 ---- .../jumpstart/test_resource_requirements.py | 1 - .../sagemaker/script_uris/jumpstart/test_common.py | 3 --- 12 files changed, 1 insertion(+), 35 deletions(-) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index c9f805c225..dfc833ec28 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -257,7 +257,6 @@ def get_model_specs( hub_arn: Optional[str] = None, s3_client: Optional[boto3.client] = None, model_type=JumpStartModelType.OPEN_WEIGHTS, - hub_arn: Optional[str] = None, ) -> JumpStartModelSpecs: """Returns model specs from JumpStart models cache. diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index d831d3023b..8d0f1832bf 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -58,7 +58,6 @@ DescribeHubContentsResponse, HubType, HubContentType, - HubDataType, ) from sagemaker.jumpstart.curated_hub import utils as hub_utils from sagemaker.jumpstart.enums import JumpStartModelType diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 4cdd540111..6406932924 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -534,7 +534,6 @@ def _validate_model_id_and_get_type_hook(): model_version=model_version, hub_arn=hub_arn, model_type=self.model_type, - hub_arn=hub_arn, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, role=role, diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index fa04b46a7c..fb598256fa 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -81,7 +81,6 @@ def get_init_kwargs( model_version: Optional[str] = None, hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, - hub_arn: Optional[str] = None, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, region: Optional[str] = None, @@ -141,7 +140,6 @@ def get_init_kwargs( model_version=model_version, hub_arn=hub_arn, model_type=model_type, - hub_arn=hub_arn, role=role, region=region, instance_count=instance_count, diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 2fa538f33b..6f7a83cef1 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -550,7 +550,6 @@ def get_deploy_kwargs( model_version: Optional[str] = None, hub_arn: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, - hub_arn: Optional[str] = None, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -585,7 +584,6 @@ def get_deploy_kwargs( model_version=model_version, hub_arn=hub_arn, model_type=model_type, - hub_arn=hub_arn, region=region, initial_instance_count=initial_instance_count, instance_type=instance_type, @@ -726,7 +724,6 @@ def get_init_kwargs( model_version: Optional[str] = None, hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, - hub_arn: Optional[str] = None, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, instance_type: Optional[str] = None, @@ -760,7 +757,6 @@ def get_init_kwargs( model_version=model_version, hub_arn=hub_arn, model_type=model_type, - hub_arn=hub_arn, instance_type=instance_type, region=region, image_uri=image_uri, diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index e46850d139..c6278d7dfa 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -15,6 +15,7 @@ from copy import deepcopy from enum import Enum from typing import Any, Dict, List, Optional, Set, Union +from sagemaker.session import Session from sagemaker.utils import get_instance_type_family, format_tags, Tags from sagemaker.enums import EndpointType from sagemaker.model_metrics import ModelMetrics @@ -1290,7 +1291,6 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", - "hub_arn", "instance_type", "tolerate_vulnerable_model", "tolerate_deprecated_model", @@ -1323,7 +1323,6 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", - "hub_arn", "tolerate_vulnerable_model", "tolerate_deprecated_model", "region", @@ -1337,7 +1336,6 @@ def __init__( model_version: Optional[str] = None, hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, - hub_arn: Optional[str] = None, region: Optional[str] = None, instance_type: Optional[str] = None, image_uri: Optional[Union[str, Any]] = None, @@ -1369,7 +1367,6 @@ def __init__( self.model_version = model_version self.hub_arn = hub_arn self.model_type = model_type - self.hub_arn = hub_arn self.instance_type = instance_type self.region = region self.image_uri = image_uri @@ -1404,7 +1401,6 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", - "hub_arn", "initial_instance_count", "instance_type", "region", @@ -1486,7 +1482,6 @@ def __init__( self.model_version = model_version self.hub_arn = hub_arn self.model_type = model_type - self.hub_arn = hub_arn self.initial_instance_count = initial_instance_count self.instance_type = instance_type self.region = region @@ -1523,7 +1518,6 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", - "hub_arn", "instance_type", "instance_count", "region", @@ -1585,7 +1579,6 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", - "hub_arn", } def __init__( @@ -1715,7 +1708,6 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", - "hub_arn", "region", "inputs", "wait", @@ -1732,7 +1724,6 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", - "hub_arn", "region", "tolerate_deprecated_model", "tolerate_vulnerable_model", @@ -1761,7 +1752,6 @@ def __init__( self.model_version = model_version self.hub_arn = hub_arn self.model_type = model_type - self.hub_arn = hub_arn self.region = region self.inputs = inputs self.wait = wait diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py index 93d7098870..0f69cb572a 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -453,7 +453,6 @@ def add_options_to_hyperparameter(*largs, **kwargs): s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, - hub_arn=None, ) patched_get_model_specs.reset_mock() @@ -517,7 +516,6 @@ def test_jumpstart_validate_all_hyperparameters( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, - hub_arn=None, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py index 45af6faeed..bd4383499d 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py @@ -56,7 +56,6 @@ def test_jumpstart_common_image_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, - hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -79,7 +78,6 @@ def test_jumpstart_common_image_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, - hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -102,7 +100,6 @@ def test_jumpstart_common_image_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, - hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -125,7 +122,6 @@ def test_jumpstart_common_image_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, - hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index ed7c870a0e..a5d1ee3ac2 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -751,5 +751,4 @@ def test_get_model_url( s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, - hub_arn=None, ) diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index 06587a2074..2bb327c26f 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -54,7 +54,6 @@ def test_jumpstart_common_model_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, - hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -74,7 +73,6 @@ def test_jumpstart_common_model_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, - hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -95,7 +93,6 @@ def test_jumpstart_common_model_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, - hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -116,7 +113,6 @@ def test_jumpstart_common_model_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, - hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py index b2e055dd3c..2a4d913a75 100644 --- a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py +++ b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py @@ -57,7 +57,6 @@ def test_jumpstart_resource_requirements( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, - hub_arn=None, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index 99cec92e99..d332b22c2c 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -75,7 +75,6 @@ def test_jumpstart_common_script_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, - hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -96,7 +95,6 @@ def test_jumpstart_common_script_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, - hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -117,7 +115,6 @@ def test_jumpstart_common_script_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, - hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() From 2eff8fbe9ac539e3994e52fdcf5774086d3f3748 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Fri, 15 Mar 2024 15:32:20 +0000 Subject: [PATCH 11/17] support for gated and training unsupported --- .../curated_hub/accessors/file_generator.py | 5 ++ .../accessors/public_model_data.py | 52 ++++++++++++------- .../jumpstart/curated_hub/curated_hub.py | 1 - src/sagemaker/jumpstart/curated_hub/utils.py | 10 +++- src/sagemaker/jumpstart/utils.py | 1 + .../curated_hub/test_filegenerator.py | 41 +++++++++++++++ .../jumpstart/curated_hub/test_utils.py | 10 ++++ .../sagemaker/jumpstart/test_accessors.py | 1 + 8 files changed, 101 insertions(+), 20 deletions(-) diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py b/src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py index 0393b4234a..e5ea072d86 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py @@ -22,6 +22,7 @@ S3ObjectLocation, ) from sagemaker.jumpstart.curated_hub.accessors.public_model_data import PublicModelDataAccessor +from sagemaker.jumpstart.curated_hub.utils import is_gated_bucket from sagemaker.jumpstart.types import JumpStartModelSpecs @@ -65,6 +66,10 @@ def generate_file_infos_from_model_specs( files = [] for dependency in HubContentDependencyType: location: S3ObjectLocation = public_model_data_accessor.get_s3_reference(dependency) + # Training dependencies will return as None if training is unsupported + if not location or is_gated_bucket(location.bucket): + continue + location_type = "prefix" if location.key.endswith("/") else "object" if location_type == "prefix": diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py b/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py index 89e3a2f108..604c0f2063 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py @@ -12,7 +12,7 @@ # language governing permissions and limitations under the License. """This module accessors for the SageMaker JumpStart Public Hub.""" from __future__ import absolute_import -from typing import Dict, Any +from typing import Dict, Any, Optional from sagemaker import model_uris, script_uris from sagemaker.jumpstart.curated_hub.types import ( HubContentDependencyType, @@ -21,7 +21,7 @@ from sagemaker.jumpstart.curated_hub.utils import create_s3_object_reference_from_uri from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.types import JumpStartModelSpecs -from sagemaker.jumpstart.utils import get_jumpstart_content_bucket +from sagemaker.jumpstart.utils import get_jumpstart_content_bucket, get_jumpstart_gated_content_bucket class PublicModelDataAccessor: @@ -34,7 +34,11 @@ def __init__( studio_specs: Dict[str, Dict[str, Any]], ): self._region = region - self._bucket = get_jumpstart_content_bucket(region) + self._bucket = ( + get_jumpstart_gated_content_bucket(region) + if model_specs.gated_bucket + else get_jumpstart_content_bucket(region) + ) self.model_specs = model_specs self.studio_specs = studio_specs # Necessary for SDK - Studio metadata drift @@ -52,6 +56,8 @@ def inference_artifact_s3_reference(self): @property def training_artifact_s3_reference(self): """Retrieves s3 reference for model training artifact""" + if not self.model_specs.training_supported: + return None return create_s3_object_reference_from_uri( self._jumpstart_artifact_s3_uri(JumpStartScriptScope.TRAINING) ) @@ -66,6 +72,8 @@ def inference_script_s3_reference(self): @property def training_script_s3_reference(self): """Retrieves s3 reference for model training script""" + if not self.model_specs.training_supported: + return None return create_s3_object_reference_from_uri( self._jumpstart_script_s3_uri(JumpStartScriptScope.TRAINING) ) @@ -73,6 +81,8 @@ def training_script_s3_reference(self): @property def default_training_dataset_s3_reference(self): """Retrieves s3 reference for s3 directory containing model training datasets""" + if not self.model_specs.training_supported: + return None return S3ObjectLocation(self._get_bucket_name(), self.__get_training_dataset_prefix()) @property @@ -95,22 +105,28 @@ def _get_bucket_name(self) -> str: def __get_training_dataset_prefix(self) -> str: """Retrieves training dataset location""" - return self.studio_specs["defaultDataKey"] + return self.studio_specs.get("defaultDataKey") - def _jumpstart_script_s3_uri(self, model_scope: str) -> str: + def _jumpstart_script_s3_uri(self, model_scope: str) -> Optional[str]: """Retrieves JumpStart script s3 location""" - return script_uris.retrieve( - region=self._region, - model_id=self.model_specs.model_id, - model_version=self.model_specs.version, - script_scope=model_scope, - ) + try: + return script_uris.retrieve( + region=self._region, + model_id=self.model_specs.model_id, + model_version=self.model_specs.version, + script_scope=model_scope, + ) + except ValueError: + return None - def _jumpstart_artifact_s3_uri(self, model_scope: str) -> str: + def _jumpstart_artifact_s3_uri(self, model_scope: str) -> Optional[str]: """Retrieves JumpStart artifact s3 location""" - return model_uris.retrieve( - region=self._region, - model_id=self.model_specs.model_id, - model_version=self.model_specs.version, - model_scope=model_scope, - ) + try: + return model_uris.retrieve( + region=self._region, + model_id=self.model_specs.model_id, + model_version=self.model_specs.version, + model_scope=model_scope, + ) + except ValueError: + return None diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index a35948f138..a9d07f1bd7 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -394,7 +394,6 @@ def _sync_public_model_to_hub(self, model: JumpStartModelInfo, thread_num: int): self._sagemaker_session.import_hub_content( document_schema_version=HubContentDocument_v2.SCHEMA_VERSION, hub_content_name=model.model_id, - hub_content_version=model.version, hub_name=self.hub_name, hub_content_document=hub_content_document, hub_content_type=HubContentType.MODEL, diff --git a/src/sagemaker/jumpstart/curated_hub/utils.py b/src/sagemaker/jumpstart/curated_hub/utils.py index b116411801..71008ab5b4 100644 --- a/src/sagemaker/jumpstart/curated_hub/utils.py +++ b/src/sagemaker/jumpstart/curated_hub/utils.py @@ -133,8 +133,11 @@ def generate_default_hub_bucket_name( return f"sagemaker-hubs-{region}-{account_id}" -def create_s3_object_reference_from_uri(s3_uri: str) -> S3ObjectLocation: +def create_s3_object_reference_from_uri(s3_uri: Optional[str]) -> Optional[S3ObjectLocation]: """Utiity to help generate an S3 object reference""" + if not s3_uri: + return None + bucket, key = parse_s3_url(s3_uri) return S3ObjectLocation( @@ -164,3 +167,8 @@ def create_hub_bucket_if_it_does_not_exist( ) return bucket_name + + +def is_gated_bucket(bucket_name: str) -> bool: + """Returns true if the bucket name is the JumpStart gated bucket.""" + return bucket_name in constants.JUMPSTART_GATED_BUCKET_NAME_SET diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 44294d67f9..a48f81320e 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -870,6 +870,7 @@ def generate_studio_spec_file_prefix(model_id: str, model_version: str) -> str: """Returns the Studio Spec file prefix given a model ID and version.""" return f"studio_models/{model_id}/studio_specs_v{model_version}.json" + def extract_info_from_hub_content_arn( arn: str, ) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]: diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py index accd2a5c8d..675ba312ca 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py @@ -127,3 +127,44 @@ def test_s3_path_file_generator_with_no_objects(s3_client): s3_client.list_objects_v2.assert_called_once() assert response == [] + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_specs_file_generator_training_unsupported(patched_get_model_specs, s3_client): + specs = Mock() + specs.model_id = "mock_model_123" + specs.training_supported = False + specs.gated_bucket = False + specs.hosting_prepacked_artifact_key = "/my/inference/tarball.tgz" + specs.hosting_script_key = "/my/inference/script.py" + + response = generate_file_infos_from_model_specs(specs, {}, "us-west-2", s3_client) + + assert response == [ + FileInfo( + "jumpstart-cache-prod-us-west-2", + "/my/inference/tarball.tgz", + 123456789, + "08-14-1997 00:00:00", + ), + FileInfo( + "jumpstart-cache-prod-us-west-2", + "/my/inference/script.py", + 123456789, + "08-14-1997 00:00:00", + ), + ] + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_specs_file_generator_gated_model(patched_get_model_specs, s3_client): + specs = Mock() + specs.model_id = "mock_model_123" + specs.gated_bucket = True + specs.training_supported = True + specs.hosting_prepacked_artifact_key = "/my/inference/tarball.tgz" + specs.hosting_script_key = "/my/inference/script.py" + specs.training_prepacked_artifact_key = "/my/training/tarball.tgz" + specs.training_script_key = "/my/training/script.py" + + response = generate_file_infos_from_model_specs(specs, {}, "us-west-2", s3_client) + + assert response == [] \ No newline at end of file diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py index 6743a969e2..c3630df5e6 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py @@ -238,3 +238,13 @@ def test_create_hub_bucket_if_it_does_not_exist(): mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once() assert created_hub_bucket_name == bucket_name + + +def test_is_gated_bucket(): + assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-west-2") is True + + assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-east-1") is True + + assert utils.is_gated_bucket("jumpstart-cache-prod-us-west-2") is False + + assert utils.is_gated_bucket("") is False diff --git a/tests/unit/sagemaker/jumpstart/test_accessors.py b/tests/unit/sagemaker/jumpstart/test_accessors.py index 5d527dd5a1..d5a76a405e 100644 --- a/tests/unit/sagemaker/jumpstart/test_accessors.py +++ b/tests/unit/sagemaker/jumpstart/test_accessors.py @@ -98,6 +98,7 @@ def test_jumpstart_models_cache_get_model_specs(mock_cache): ) ) + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") def test_jumpstart_proprietary_models_cache_get(mock_cache): From b50c5570ece5ecd8d75f446af83678f97de22287 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 18 Mar 2024 13:49:26 +0000 Subject: [PATCH 12/17] merge with master-curated-jumpstart --- src/sagemaker/jumpstart/cache.py | 3 +- .../accessors/public_model_data.py | 5 +- src/sagemaker/jumpstart/utils.py | 4 +- .../curated_hub/test_filegenerator.py | 4 +- .../jumpstart/curated_hub/test_utils.py | 63 ------------------- .../sagemaker/jumpstart/test_accessors.py | 33 +--------- .../script_uris/jumpstart/test_common.py | 2 - 7 files changed, 10 insertions(+), 104 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 8d0f1832bf..417bae77c7 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -34,7 +34,6 @@ DEFAULT_JUMPSTART_SAGEMAKER_SESSION, MODEL_TYPE_TO_MANIFEST_MAP, MODEL_TYPE_TO_SPECS_MAP, - DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) from sagemaker.jumpstart.exceptions import ( get_wildcard_model_version_msg, @@ -443,7 +442,7 @@ def _retrieval_function( formatted_content=utils.get_formatted_manifest(formatted_body), md5_hash=etag, ) - + if data_type in { JumpStartS3FileType.OPEN_WEIGHT_SPECS, JumpStartS3FileType.PROPRIETARY_SPECS, diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py b/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py index 604c0f2063..ba4a56374e 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py @@ -21,7 +21,10 @@ from sagemaker.jumpstart.curated_hub.utils import create_s3_object_reference_from_uri from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.types import JumpStartModelSpecs -from sagemaker.jumpstart.utils import get_jumpstart_content_bucket, get_jumpstart_gated_content_bucket +from sagemaker.jumpstart.utils import ( + get_jumpstart_content_bucket, + get_jumpstart_gated_content_bucket, +) class PublicModelDataAccessor: diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index a48f81320e..210548511d 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -16,8 +16,6 @@ import os import re from typing import Any, Dict, List, Set, Optional, Tuple, Union -import re -from typing import Any, Dict, List, Set, Optional, Tuple, Union from urllib.parse import urlparse import boto3 from packaging.version import Version @@ -876,7 +874,7 @@ def extract_info_from_hub_content_arn( ) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]: """Extracts hub_name, content_name, and content_version from a HubContentArn""" - match = re.match(constants.HUB_MODEL_ARN_REGEX, arn) + match = re.match(constants.HUB_CONTENT_ARN_REGEX, arn) if match: hub_name = match.group(4) hub_region = match.group(2) diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py index 675ba312ca..8fcb8dd740 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py @@ -128,6 +128,7 @@ def test_s3_path_file_generator_with_no_objects(s3_client): s3_client.list_objects_v2.assert_called_once() assert response == [] + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_specs_file_generator_training_unsupported(patched_get_model_specs, s3_client): specs = Mock() @@ -154,6 +155,7 @@ def test_specs_file_generator_training_unsupported(patched_get_model_specs, s3_c ), ] + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_specs_file_generator_gated_model(patched_get_model_specs, s3_client): specs = Mock() @@ -167,4 +169,4 @@ def test_specs_file_generator_gated_model(patched_get_model_specs, s3_client): response = generate_file_infos_from_model_specs(specs, {}, "us-west-2", s3_client) - assert response == [] \ No newline at end of file + assert response == [] diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py index c3630df5e6..ac5fdaba3e 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py @@ -177,69 +177,6 @@ def test_create_hub_bucket_if_it_does_not_exist(): assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn -def test_generate_default_hub_bucket_name(): - mock_sagemaker_session = Mock() - mock_sagemaker_session.account_id.return_value = "123456789123" - mock_sagemaker_session.boto_region_name = "us-east-1" - - assert ( - utils.generate_default_hub_bucket_name(sagemaker_session=mock_sagemaker_session) - == "sagemaker-hubs-us-east-1-123456789123" - ) - - -def test_create_hub_bucket_if_it_does_not_exist(): - mock_sagemaker_session = Mock() - mock_sagemaker_session.account_id.return_value = "123456789123" - mock_sagemaker_session.client("sts").get_caller_identity.return_value = { - "Account": "123456789123" - } - hub_arn = "arn:aws:sagemaker:us-west-2:12346789123:hub/my-awesome-hub" - # Mock custom session with custom values - mock_custom_session = Mock() - mock_custom_session.account_id.return_value = "000000000000" - mock_custom_session.boto_region_name = "us-east-2" - mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None - mock_sagemaker_session.boto_region_name = "us-east-1" - - bucket_name = "sagemaker-hubs-us-east-1-123456789123" - created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist( - sagemaker_session=mock_sagemaker_session - ) - - mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once() - assert created_hub_bucket_name == bucket_name - assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn - - -def test_generate_default_hub_bucket_name(): - mock_sagemaker_session = Mock() - mock_sagemaker_session.account_id.return_value = "123456789123" - mock_sagemaker_session.boto_region_name = "us-east-1" - - assert ( - utils.generate_default_hub_bucket_name(sagemaker_session=mock_sagemaker_session) - == "sagemaker-hubs-us-east-1-123456789123" - ) - - -def test_create_hub_bucket_if_it_does_not_exist(): - mock_sagemaker_session = Mock() - mock_sagemaker_session.account_id.return_value = "123456789123" - mock_sagemaker_session.client("sts").get_caller_identity.return_value = { - "Account": "123456789123" - } - mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None - mock_sagemaker_session.boto_region_name = "us-east-1" - bucket_name = "sagemaker-hubs-us-east-1-123456789123" - created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist( - sagemaker_session=mock_sagemaker_session - ) - - mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once() - assert created_hub_bucket_name == bucket_name - - def test_is_gated_bucket(): assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-west-2") is True diff --git a/tests/unit/sagemaker/jumpstart/test_accessors.py b/tests/unit/sagemaker/jumpstart/test_accessors.py index d5a76a405e..3647cc475d 100644 --- a/tests/unit/sagemaker/jumpstart/test_accessors.py +++ b/tests/unit/sagemaker/jumpstart/test_accessors.py @@ -83,7 +83,7 @@ def test_jumpstart_models_cache_get_model_specs(mock_cache): accessors.JumpStartModelsAccessor.get_model_specs( region=region, model_id=model_id, version=version ) - mock_cache.get_specs.assert_called_once_with(model_id=model_id, semantic_version_str=version) + mock_cache.get_specs.assert_called_once_with(model_id=model_id, version_str=version, model_type=JumpStartModelType.OPEN_WEIGHTS) mock_cache.get_hub_model.assert_not_called() accessors.JumpStartModelsAccessor.get_model_specs( @@ -139,37 +139,6 @@ def test_jumpstart_proprietary_models_cache_get(mock_cache): ) -@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") -def test_jumpstart_models_cache_get_model_specs(mock_cache): - mock_cache.get_specs = Mock() - mock_cache.get_hub_model = Mock() - model_id, version = "pytorch-ic-mobilenet-v2", "*" - region = "us-west-2" - - accessors.JumpStartModelsAccessor.get_model_specs( - region=region, model_id=model_id, version=version - ) - mock_cache.get_specs.assert_called_once_with( - model_id=model_id, version_str=version, model_type=JumpStartModelType.OPEN_WEIGHTS - ) - mock_cache.get_hub_model.assert_not_called() - - accessors.JumpStartModelsAccessor.get_model_specs( - region=region, - model_id=model_id, - version=version, - hub_arn=f"arn:aws:sagemaker:{region}:123456789123:hub/my-mock-hub", - ) - mock_cache.get_hub_model.assert_called_once_with( - hub_model_arn=( - f"arn:aws:sagemaker:{region}:123456789123:hub-content/my-mock-hub/Model/{model_id}/{version}" - ) - ) - - # necessary because accessors is a static module - reload(accessors) - - @patch("sagemaker.jumpstart.cache.JumpStartModelsCache") def test_jumpstart_models_cache_set_reset(mock_model_cache: Mock): diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index d332b22c2c..e1d3ef6ae1 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -54,8 +54,6 @@ def test_jumpstart_common_script_uri( s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, - model_type=JumpStartModelType.OPEN_WEIGHTS, - hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() From 1af132eb3510fa75a10db45c199f38cecd759af5 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 18 Mar 2024 13:57:06 +0000 Subject: [PATCH 13/17] linting --- .../jumpstart/curated_hub/curated_hub.py | 1 + src/sagemaker/jumpstart/types.py | 2 -- .../sagemaker/jumpstart/test_accessors.py | 4 +++- tests/unit/sagemaker/jumpstart/utils.py | 20 ------------------- 4 files changed, 4 insertions(+), 23 deletions(-) diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index a9d07f1bd7..a35948f138 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -394,6 +394,7 @@ def _sync_public_model_to_hub(self, model: JumpStartModelInfo, thread_num: int): self._sagemaker_session.import_hub_content( document_schema_version=HubContentDocument_v2.SCHEMA_VERSION, hub_content_name=model.model_id, + hub_content_version=model.version, hub_name=self.hub_name, hub_content_document=hub_content_document, hub_content_type=HubContentType.MODEL, diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index c6278d7dfa..01622a5462 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1434,8 +1434,6 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "model_version", "model_type", "hub_arn", - "model_type", - "hub_arn", "region", "tolerate_deprecated_model", "tolerate_vulnerable_model", diff --git a/tests/unit/sagemaker/jumpstart/test_accessors.py b/tests/unit/sagemaker/jumpstart/test_accessors.py index 3647cc475d..79eeb4b7f0 100644 --- a/tests/unit/sagemaker/jumpstart/test_accessors.py +++ b/tests/unit/sagemaker/jumpstart/test_accessors.py @@ -83,7 +83,9 @@ def test_jumpstart_models_cache_get_model_specs(mock_cache): accessors.JumpStartModelsAccessor.get_model_specs( region=region, model_id=model_id, version=version ) - mock_cache.get_specs.assert_called_once_with(model_id=model_id, version_str=version, model_type=JumpStartModelType.OPEN_WEIGHTS) + mock_cache.get_specs.assert_called_once_with( + model_id=model_id, version_str=version, model_type=JumpStartModelType.OPEN_WEIGHTS + ) mock_cache.get_hub_model.assert_not_called() accessors.JumpStartModelsAccessor.get_model_specs( diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 3451334a70..410aba4d03 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -254,26 +254,6 @@ def patched_retrieval_function( ) ) - if datatype == HubContentType.MODEL: - _, _, _, model_name, model_version = id_info.split("/") - return JumpStartCachedContentValue( - formatted_content=get_spec_from_base_spec(model_id=model_name, version=model_version) - ) - - # TODO: Implement - if datatype == HubType.HUB: - return None - - if datatype == HubContentType.MODEL: - _, _, _, model_name, model_version = id_info.split("/") - return JumpStartCachedContentValue( - formatted_content=get_spec_from_base_spec(model_id=model_name, version=model_version) - ) - - # TODO: Implement - if datatype == HubType.HUB: - return None - raise ValueError(f"Bad value for datatype: {datatype}") From c9f79fd9c390a82da12af8bc03d87ed73ac916dd Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 18 Mar 2024 14:25:34 +0000 Subject: [PATCH 14/17] update types --- .../curated_hub/accessors/public_model_data.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py b/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py index ba4a56374e..9894ad05e0 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py @@ -50,14 +50,14 @@ def get_s3_reference(self, dependency_type: HubContentDependencyType): return getattr(self, dependency_type.value) @property - def inference_artifact_s3_reference(self): + def inference_artifact_s3_reference(self) -> Optional[S3ObjectLocation]: """Retrieves s3 reference for model inference artifact""" return create_s3_object_reference_from_uri( self._jumpstart_artifact_s3_uri(JumpStartScriptScope.INFERENCE) ) @property - def training_artifact_s3_reference(self): + def training_artifact_s3_reference(self) -> Optional[S3ObjectLocation]: """Retrieves s3 reference for model training artifact""" if not self.model_specs.training_supported: return None @@ -66,14 +66,14 @@ def training_artifact_s3_reference(self): ) @property - def inference_script_s3_reference(self): + def inference_script_s3_reference(self) -> Optional[S3ObjectLocation]: """Retrieves s3 reference for model inference script""" return create_s3_object_reference_from_uri( self._jumpstart_script_s3_uri(JumpStartScriptScope.INFERENCE) ) @property - def training_script_s3_reference(self): + def training_script_s3_reference(self) -> Optional[S3ObjectLocation]: """Retrieves s3 reference for model training script""" if not self.model_specs.training_supported: return None @@ -82,21 +82,21 @@ def training_script_s3_reference(self): ) @property - def default_training_dataset_s3_reference(self): + def default_training_dataset_s3_reference(self) -> S3ObjectLocation: """Retrieves s3 reference for s3 directory containing model training datasets""" if not self.model_specs.training_supported: return None return S3ObjectLocation(self._get_bucket_name(), self.__get_training_dataset_prefix()) @property - def demo_notebook_s3_reference(self): + def demo_notebook_s3_reference(self) -> S3ObjectLocation: """Retrieves s3 reference for model demo jupyter notebook""" framework = self.model_specs.get_framework() key = f"{framework}-notebooks/{self.model_specs.model_id}-inference.ipynb" return S3ObjectLocation(self._get_bucket_name(), key) @property - def markdown_s3_reference(self): + def markdown_s3_reference(self) -> S3ObjectLocation: """Retrieves s3 reference for model markdown""" framework = self.model_specs.get_framework() key = f"{framework}-metadata/{self.model_specs.model_id}.md" @@ -106,7 +106,7 @@ def _get_bucket_name(self) -> str: """Retrieves s3 bucket""" return self._bucket - def __get_training_dataset_prefix(self) -> str: + def _get_training_dataset_prefix(self) -> Optional[str]: """Retrieves training dataset location""" return self.studio_specs.get("defaultDataKey") From 29733c83dfbab87d46c936363fa6b6a37d961895 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 18 Mar 2024 14:46:53 +0000 Subject: [PATCH 15/17] update --- .../jumpstart/curated_hub/accessors/public_model_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py b/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py index 9894ad05e0..a4e339591b 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py @@ -86,7 +86,7 @@ def default_training_dataset_s3_reference(self) -> S3ObjectLocation: """Retrieves s3 reference for s3 directory containing model training datasets""" if not self.model_specs.training_supported: return None - return S3ObjectLocation(self._get_bucket_name(), self.__get_training_dataset_prefix()) + return S3ObjectLocation(self._get_bucket_name(), self._get_training_dataset_prefix()) @property def demo_notebook_s3_reference(self) -> S3ObjectLocation: From 352f1ac7847a43962e34972b751d50a1d07b4b5a Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 18 Mar 2024 16:34:45 +0000 Subject: [PATCH 16/17] update bootstrap --- .../runtime_environment/bootstrap_runtime_environment.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py index 5332f7bdd0..d5d879cb08 100644 --- a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py +++ b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py @@ -89,10 +89,6 @@ def main(sys_args=None): client_python_version, conda_env, dependency_settings ) - RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version( - client_sagemaker_pysdk_version - ) - exit_code = SUCCESS_EXIT_CODE except Exception as e: # pylint: disable=broad-except logger.exception("Error encountered while bootstrapping runtime environment: %s", e) From aa8e4cbc5d4f6798e98d221b6b72023f2ec8e794 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 18 Mar 2024 17:17:44 +0000 Subject: [PATCH 17/17] fix codecov --- .../runtime_environment/bootstrap_runtime_environment.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py index d5d879cb08..8fd83bfcfe 100644 --- a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py +++ b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py @@ -65,9 +65,6 @@ def main(sys_args=None): conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV") RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env) - RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version( - client_sagemaker_pysdk_version - ) user = getpass.getuser() if user != "root": @@ -89,6 +86,10 @@ def main(sys_args=None): client_python_version, conda_env, dependency_settings ) + RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version( + client_sagemaker_pysdk_version + ) + exit_code = SUCCESS_EXIT_CODE except Exception as e: # pylint: disable=broad-except logger.exception("Error encountered while bootstrapping runtime environment: %s", e)