From 567c864cf58bac5883c508bca5e8656000e14783 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 10 Mar 2025 16:33:04 +0000 Subject: [PATCH 1/4] remove s3 output location requirement from hub class init --- src/sagemaker/jumpstart/hub/hub.py | 55 +++--------------- src/sagemaker/jumpstart/hub/utils.py | 57 ------------------- .../unit/sagemaker/jumpstart/hub/test_hub.py | 32 ++++------- .../sagemaker/jumpstart/hub/test_utils.py | 41 ------------- 4 files changed, 19 insertions(+), 166 deletions(-) diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index 402b2ce534..db5b9a7e70 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -16,15 +16,11 @@ from datetime import datetime import logging from typing import Optional, Dict, List, Any, Union -from botocore import exceptions from sagemaker.jumpstart.constants import JUMPSTART_MODEL_HUB_NAME from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.session import Session -from sagemaker.jumpstart.constants import ( - JUMPSTART_LOGGER, -) from sagemaker.jumpstart.types import ( HubContentType, ) @@ -32,9 +28,6 @@ from sagemaker.jumpstart.hub.utils import ( get_hub_model_version, get_info_from_hub_resource_arn, - create_hub_bucket_if_it_does_not_exist, - generate_default_hub_bucket_name, - create_s3_object_reference_from_uri, construct_hub_arn_from_name, ) @@ -42,9 +35,6 @@ list_jumpstart_models, ) -from sagemaker.jumpstart.hub.types import ( - S3ObjectLocation, -) from sagemaker.jumpstart.hub.interfaces import ( DescribeHubResponse, DescribeHubContentResponse, @@ -78,41 +68,11 @@ def __init__( """ self.hub_name = hub_name self.region = sagemaker_session.boto_region_name + self.bucket_name = bucket_name self._sagemaker_session = ( sagemaker_session or utils.get_default_jumpstart_session_with_user_agent_suffix(is_hub_content=True) ) - self.hub_storage_location = self._generate_hub_storage_location(bucket_name) - - def _fetch_hub_bucket_name(self) -> str: - """Retrieves hub bucket name from Hub config if exists""" - try: - hub_response = self._sagemaker_session.describe_hub(hub_name=self.hub_name) - hub_output_location = hub_response["S3StorageConfig"].get("S3OutputPath") - if hub_output_location: - location = create_s3_object_reference_from_uri(hub_output_location) - return location.bucket - default_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session) - JUMPSTART_LOGGER.warning( - "There is not a Hub bucket associated with %s. Using %s", - self.hub_name, - default_bucket_name, - ) - return default_bucket_name - except exceptions.ClientError: - hub_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session) - JUMPSTART_LOGGER.warning( - "There is not a Hub bucket associated with %s. Using %s", - self.hub_name, - hub_bucket_name, - ) - return hub_bucket_name - - def _generate_hub_storage_location(self, bucket_name: Optional[str] = None) -> None: - """Generates an ``S3ObjectLocation`` given a Hub name.""" - hub_bucket_name = bucket_name or self._fetch_hub_bucket_name() - curr_timestamp = datetime.now().timestamp() - return S3ObjectLocation(bucket=hub_bucket_name, key=f"{self.hub_name}-{curr_timestamp}") def _get_latest_model_version(self, model_id: str) -> str: """Populates the lastest version of a model from specs no matter what is passed. @@ -132,17 +92,20 @@ def create( tags: Optional[str] = None, ) -> Dict[str, str]: """Creates a hub with the given description""" - - create_hub_bucket_if_it_does_not_exist( - self.hub_storage_location.bucket, self._sagemaker_session - ) + curr_timestamp = datetime.now().timestamp() return self._sagemaker_session.create_hub( hub_name=self.hub_name, hub_description=description, hub_display_name=display_name, hub_search_keywords=search_keywords, - s3_storage_config={"S3OutputPath": self.hub_storage_location.get_uri()}, + s3_storage_config={ + "S3OutputPath": ( + f"s3://{self.bucket_name}/{self.hub_name}-{curr_timestamp}" + if self.bucket_name + else None + ) + }, tags=tags, ) diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py index 1bbc6198a2..7e00694cb2 100644 --- a/src/sagemaker/jumpstart/hub/utils.py +++ b/src/sagemaker/jumpstart/hub/utils.py @@ -15,8 +15,6 @@ from __future__ import absolute_import import re from typing import Optional, List, Any -from sagemaker.jumpstart.hub.types import S3ObjectLocation -from sagemaker.s3_utils import parse_s3_url from sagemaker.session import Session from sagemaker.utils import aws_partition from sagemaker.jumpstart.types import HubContentType, HubArnExtractedInfo @@ -138,61 +136,6 @@ def generate_hub_arn_for_init_kwargs( return hub_arn -def generate_default_hub_bucket_name( - sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> str: - """Return the name of the default bucket to use in relevant Amazon SageMaker Hub interactions. - - Returns: - str: The name of the default bucket. If the name was not explicitly specified through - the Session or sagemaker_config, the bucket will take the form: - ``sagemaker-hubs-{region}-{AWS account ID}``. - """ - - region: str = sagemaker_session.boto_region_name - account_id: str = sagemaker_session.account_id() - - # TODO: Validate and fast fail - - return f"sagemaker-hubs-{region}-{account_id}" - - -def create_s3_object_reference_from_uri(s3_uri: Optional[str]) -> Optional[S3ObjectLocation]: - """Utiity to help generate an S3 object reference""" - if not s3_uri: - return None - - bucket, key = parse_s3_url(s3_uri) - - return S3ObjectLocation( - bucket=bucket, - key=key, - ) - - -def create_hub_bucket_if_it_does_not_exist( - bucket_name: Optional[str] = None, - sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> str: - """Creates the default SageMaker Hub bucket if it does not exist. - - Returns: - str: The name of the default bucket. Takes the form: - ``sagemaker-hubs-{region}-{AWS account ID}``. - """ - - region: str = sagemaker_session.boto_region_name - if bucket_name is None: - bucket_name: str = generate_default_hub_bucket_name(sagemaker_session) - - sagemaker_session._create_s3_bucket_if_it_does_not_exist( - bucket_name=bucket_name, - region=region, - ) - - return bucket_name - - def is_gated_bucket(bucket_name: str) -> bool: """Returns true if the bucket name is the JumpStart gated bucket.""" return bucket_name in constants.JUMPSTART_GATED_BUCKET_NAME_SET diff --git a/tests/unit/sagemaker/jumpstart/hub/test_hub.py b/tests/unit/sagemaker/jumpstart/hub/test_hub.py index 06f5473322..c492207b89 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_hub.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_hub.py @@ -16,7 +16,6 @@ import pytest from mock import Mock from sagemaker.jumpstart.hub.hub import Hub -from sagemaker.jumpstart.hub.types import S3ObjectLocation REGION = "us-east-1" @@ -60,48 +59,35 @@ def test_instantiates(sagemaker_session): @pytest.mark.parametrize( - ("hub_name,hub_description,hub_bucket_name,hub_display_name,hub_search_keywords,tags"), + ("hub_name,hub_description,,hub_display_name,hub_search_keywords,tags"), [ - pytest.param("MockHub1", "this is my sagemaker hub", None, None, None, None), + pytest.param("MockHub1", "this is my sagemaker hub", None, None, None), pytest.param( "MockHub2", "this is my sagemaker hub two", - None, "DisplayMockHub2", ["mock", "hub", "123"], [{"Key": "tag-key-1", "Value": "tag-value-1"}], ), ], ) -@patch("sagemaker.jumpstart.hub.hub.Hub._generate_hub_storage_location") def test_create_with_no_bucket_name( - mock_generate_hub_storage_location, sagemaker_session, hub_name, hub_description, - hub_bucket_name, hub_display_name, hub_search_keywords, tags, ): - storage_location = S3ObjectLocation( - "sagemaker-hubs-us-east-1-123456789123", f"{hub_name}-{FAKE_TIME.timestamp()}" - ) - mock_generate_hub_storage_location.return_value = storage_location create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} sagemaker_session.create_hub = Mock(return_value=create_hub) - sagemaker_session.describe_hub.return_value = { - "S3StorageConfig": {"S3OutputPath": f"s3://{hub_bucket_name}/{storage_location.key}"} - } hub = Hub(hub_name=hub_name, sagemaker_session=sagemaker_session) request = { "hub_name": hub_name, "hub_description": hub_description, "hub_display_name": hub_display_name, "hub_search_keywords": hub_search_keywords, - "s3_storage_config": { - "S3OutputPath": f"s3://sagemaker-hubs-us-east-1-123456789123/{storage_location.key}" - }, + "s3_storage_config": {"S3OutputPath": None}, "tags": tags, } response = hub.create( @@ -128,9 +114,9 @@ def test_create_with_no_bucket_name( ), ], ) -@patch("sagemaker.jumpstart.hub.hub.Hub._generate_hub_storage_location") +@patch("sagemaker.jumpstart.hub.hub.datetime") def test_create_with_bucket_name( - mock_generate_hub_storage_location, + mock_datetime, sagemaker_session, hub_name, hub_description, @@ -139,8 +125,8 @@ def test_create_with_bucket_name( hub_search_keywords, tags, ): - storage_location = S3ObjectLocation(hub_bucket_name, f"{hub_name}-{FAKE_TIME.timestamp()}") - mock_generate_hub_storage_location.return_value = storage_location + mock_datetime.now.return_value = FAKE_TIME + create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} sagemaker_session.create_hub = Mock(return_value=create_hub) hub = Hub(hub_name=hub_name, sagemaker_session=sagemaker_session, bucket_name=hub_bucket_name) @@ -149,7 +135,9 @@ def test_create_with_bucket_name( "hub_description": hub_description, "hub_display_name": hub_display_name, "hub_search_keywords": hub_search_keywords, - "s3_storage_config": {"S3OutputPath": f"s3://mock-bucket-123/{storage_location.key}"}, + "s3_storage_config": { + "S3OutputPath": f"s3://mock-bucket-123/{hub_name}-{FAKE_TIME.timestamp()}" + }, "tags": tags, } response = hub.create( diff --git a/tests/unit/sagemaker/jumpstart/hub/test_utils.py b/tests/unit/sagemaker/jumpstart/hub/test_utils.py index a0b824fc9b..5745a7f79c 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_utils.py @@ -173,30 +173,6 @@ def test_generate_hub_arn_for_init_kwargs(): assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn -def test_create_hub_bucket_if_it_does_not_exist_hub_arn(): - mock_sagemaker_session = Mock() - mock_sagemaker_session.account_id.return_value = "123456789123" - mock_sagemaker_session.client("sts").get_caller_identity.return_value = { - "Account": "123456789123" - } - hub_arn = "arn:aws:sagemaker:us-west-2:12346789123:hub/my-awesome-hub" - # Mock custom session with custom values - mock_custom_session = Mock() - mock_custom_session.account_id.return_value = "000000000000" - mock_custom_session.boto_region_name = "us-east-2" - mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None - mock_sagemaker_session.boto_region_name = "us-east-1" - - bucket_name = "sagemaker-hubs-us-east-1-123456789123" - created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist( - sagemaker_session=mock_sagemaker_session - ) - - mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once() - assert created_hub_bucket_name == bucket_name - assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn - - def test_is_gated_bucket(): assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-west-2") is True @@ -207,23 +183,6 @@ def test_is_gated_bucket(): assert utils.is_gated_bucket("") is False -def test_create_hub_bucket_if_it_does_not_exist(): - mock_sagemaker_session = Mock() - mock_sagemaker_session.account_id.return_value = "123456789123" - mock_sagemaker_session.client("sts").get_caller_identity.return_value = { - "Account": "123456789123" - } - mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None - mock_sagemaker_session.boto_region_name = "us-east-1" - bucket_name = "sagemaker-hubs-us-east-1-123456789123" - created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist( - sagemaker_session=mock_sagemaker_session - ) - - mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once() - assert created_hub_bucket_name == bucket_name - - @patch("sagemaker.session.Session") def test_get_hub_model_version_success(mock_session): hub_name = "test_hub" From 49f0248db4d98db335cf2b48f99cf302fe88b691 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 10 Mar 2025 17:47:09 +0000 Subject: [PATCH 2/4] fix integ test hub --- src/sagemaker/jumpstart/hub/hub.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index db5b9a7e70..4a412f81d3 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -56,8 +56,8 @@ class Hub: def __init__( self, hub_name: str, + sagemaker_session: Session, bucket_name: Optional[str] = None, - sagemaker_session: Optional[Session] = None, ) -> None: """Instantiates a SageMaker ``Hub``. @@ -94,20 +94,22 @@ def create( """Creates a hub with the given description""" curr_timestamp = datetime.now().timestamp() - return self._sagemaker_session.create_hub( - hub_name=self.hub_name, - hub_description=description, - hub_display_name=display_name, - hub_search_keywords=search_keywords, - s3_storage_config={ + request = { + "hub_name": self.hub_name, + "hub_description": description, + "hub_display_name": display_name, + "hub_search_keywords": search_keywords, + "tags": tags + } + + if self.bucket_name: + request["s3_storage_config"] = { "S3OutputPath": ( f"s3://{self.bucket_name}/{self.hub_name}-{curr_timestamp}" - if self.bucket_name - else None ) - }, - tags=tags, - ) + } + + return self._sagemaker_session.create_hub(**request) def describe(self, hub_name: Optional[str] = None) -> DescribeHubResponse: """Returns descriptive information about the Hub""" From 1e23db4ec97e7d53ee0b88e33a25fb460c317e94 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 10 Mar 2025 17:55:28 +0000 Subject: [PATCH 3/4] lint --- src/sagemaker/jumpstart/hub/hub.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index 4a412f81d3..692966cee4 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -99,14 +99,12 @@ def create( "hub_description": description, "hub_display_name": display_name, "hub_search_keywords": search_keywords, - "tags": tags + "tags": tags, } - + if self.bucket_name: request["s3_storage_config"] = { - "S3OutputPath": ( - f"s3://{self.bucket_name}/{self.hub_name}-{curr_timestamp}" - ) + "S3OutputPath": (f"s3://{self.bucket_name}/{self.hub_name}-{curr_timestamp}") } return self._sagemaker_session.create_hub(**request) From 9bc4a4bf9b00b7def0b3ce9e020b87a9c8c17c51 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 10 Mar 2025 20:17:28 +0000 Subject: [PATCH 4/4] fix test --- tests/unit/sagemaker/jumpstart/hub/test_hub.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/sagemaker/jumpstart/hub/test_hub.py b/tests/unit/sagemaker/jumpstart/hub/test_hub.py index c492207b89..29efb6b31f 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_hub.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_hub.py @@ -87,7 +87,6 @@ def test_create_with_no_bucket_name( "hub_description": hub_description, "hub_display_name": hub_display_name, "hub_search_keywords": hub_search_keywords, - "s3_storage_config": {"S3OutputPath": None}, "tags": tags, } response = hub.create(