diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 03a0ebe545..eff7f53717 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -222,8 +222,6 @@ JUMPSTART_MODEL_HUB_NAME = "SageMakerPublicHub" -JUMPSTART_MODEL_HUB_NAME = "SageMakerPublicHub" - JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json" diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index c9354e020b..e6952b2154 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -447,6 +447,7 @@ def retrieve_example_payload(self) -> JumpStartSerializablePayload: return payloads.retrieve_example( model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, model_type=self.model_type, region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, @@ -1036,6 +1037,7 @@ def _get_deployment_configs( image_uri=image_uri, region=self.region, model_version=self.model_version, + hub_arn=self.hub_arn, ) deploy_kwargs = get_deploy_kwargs( model_id=self.model_id, @@ -1043,6 +1045,7 @@ def _get_deployment_configs( sagemaker_session=self.sagemaker_session, region=self.region, model_version=self.model_version, + hub_arn=self.hub_arn, ) deployment_config_metadata = DeploymentConfigMetadata( diff --git a/tests/integ/sagemaker/jumpstart/conftest.py b/tests/integ/sagemaker/jumpstart/conftest.py index c7554f3e51..260b0f2b22 100644 --- a/tests/integ/sagemaker/jumpstart/conftest.py +++ b/tests/integ/sagemaker/jumpstart/conftest.py @@ -16,24 +16,43 @@ import boto3 import pytest from botocore.config import Config +from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME +from sagemaker.jumpstart.hub.hub import Hub from sagemaker.session import Session from tests.integ.sagemaker.jumpstart.constants import ( ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID, + 
def _setup():
    """Publish the suite id and hub name via env vars and create the session's private hub."""
    print("Setting up...")

    suite_id = get_test_suite_id()
    hub_name = f"{HUB_NAME_PREFIX}{suite_id}"

    os.environ.update(
        {
            ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID: suite_id,
            ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME: hub_name,
        }
    )

    # Create a private hub to use for the test session
    session_hub = Hub(hub_name=hub_name, sagemaker_session=get_sm_session())
    session_hub.create(description="PySDK Integ Test Private Hub")


def _delete_hubs(sagemaker_session, hub_name):
    """Delete every listed model reference in ``hub_name``, then delete the hub itself.

    Only the summaries returned by the single ``list_hub_contents`` call are
    processed; no continuation token is followed here.
    """
    reference_type = HubContentType.MODEL_REFERENCE.value

    # list and delete all hub contents first
    listing = sagemaker_session.list_hub_contents(
        hub_name=hub_name, hub_content_type=reference_type
    )
    for summary in listing["HubContentSummaries"]:
        _delete_hub_contents(sagemaker_session, hub_name, summary)

    sagemaker_session.delete_hub(hub_name)


@with_exponential_backoff()
def _delete_hub_contents(sagemaker_session, hub_name, model):
    """Delete one model reference from ``hub_name``.

    ``model`` is a hub content summary mapping; only its ``"HubContentName"``
    entry is read. The call is wrapped by ``with_exponential_backoff``.
    """
    content_name = model["HubContentName"]
    sagemaker_session.delete_hub_content_reference(
        hub_name=hub_name,
        hub_content_type=HubContentType.MODEL_REFERENCE.value,
        hub_content_name=content_name,
    )
# Upper bound, in seconds, asserted on JumpStartModel construction time.
MAX_INIT_TIME_SECONDS = 5

# Public-hub model ids that get a model reference in the test hub.
TEST_MODEL_IDS = {
    "catboost-classification-model",
    "huggingface-txt2img-conflictx-complex-lineart",
    "meta-textgeneration-llama-2-7b",
    "meta-textgeneration-llama-3-2-1b",
    "catboost-regression-model",
}


@with_exponential_backoff()
def create_model_reference(hub_instance, model_arn):
    """Add a reference to ``model_arn`` in ``hub_instance``.

    The call is wrapped by ``with_exponential_backoff``; the hub's response is
    discarded.
    """
    add_reference = hub_instance.create_model_reference
    add_reference(model_arn=model_arn)
@pytest.fixture(scope="session")
def add_model_references():
    """Create a model reference in the test hub for every id in ``TEST_MODEL_IDS``.

    The hub name is read from the ``ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME``
    environment variable.
    """
    hub_instance = Hub(
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
        sagemaker_session=get_sm_session(),
    )
    for model_id in TEST_MODEL_IDS:
        model_arn = get_public_hub_model_arn(hub_instance, model_id)
        create_model_reference(hub_instance, model_arn)


def test_jumpstart_hub_model(setup, add_model_references):
    """Deploy a model through the private hub and check the endpoint is in service."""
    model_id = "catboost-classification-model"

    sagemaker_session = get_sm_session()

    model = JumpStartModel(
        model_id=model_id,
        role=sagemaker_session.get_caller_identity_arn(),
        sagemaker_session=sagemaker_session,
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
    )

    predictor = model.deploy(
        tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
    )

    assert sagemaker_session.endpoint_in_service_or_not(predictor.endpoint_name)


def test_jumpstart_hub_gated_model(setup, add_model_references):
    """Deploy a gated model through the private hub and run its example payload."""
    model_id = "meta-textgeneration-llama-3-2-1b"

    # One session serves both the role lookup and the model.
    sagemaker_session = get_sm_session()

    model = JumpStartModel(
        model_id=model_id,
        role=sagemaker_session.get_caller_identity_arn(),
        sagemaker_session=sagemaker_session,
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
    )

    predictor = model.deploy(
        accept_eula=True,
        tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
    )

    payload = model.retrieve_example_payload()

    response = predictor.predict(payload)

    assert response is not None


def test_jumpstart_gated_model_inference_component_enabled(setup, add_model_references):
    """Deploy a gated hub model on an inference-component endpoint, predict, then re-attach."""
    model_id = "meta-textgeneration-llama-2-7b"

    hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]

    region = tests.integ.test_region()

    sagemaker_session = get_sm_session()

    hub_arn = generate_hub_arn_for_init_kwargs(
        hub_name=hub_name, region=region, session=sagemaker_session
    )

    model = JumpStartModel(
        model_id=model_id,
        role=sagemaker_session.get_caller_identity_arn(),
        sagemaker_session=sagemaker_session,
        hub_name=hub_name,
    )

    model.deploy(
        tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
        accept_eula=True,
        endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED,
    )

    predictor = retrieve_default(
        endpoint_name=model.endpoint_name,
        sagemaker_session=sagemaker_session,
        tolerate_vulnerable_model=True,
        hub_arn=hub_arn,
    )

    payload = model.retrieve_example_payload()

    response = predictor.predict(payload)

    assert response is not None

    model = JumpStartModel.attach(
        predictor.endpoint_name, sagemaker_session=sagemaker_session, hub_name=hub_name
    )
    assert model.model_id == model_id
    assert model.endpoint_name == predictor.endpoint_name
    assert model.inference_component_name == predictor.component_name


def test_instantiating_model(setup, add_model_references):
    """Check that constructing a hub-backed ``JumpStartModel`` stays within the time budget."""
    model_id = "catboost-regression-model"

    # Resolve the session, role and hub name before starting the clock so the
    # measurement covers only the JumpStartModel constructor.
    sagemaker_session = get_sm_session()
    role = sagemaker_session.get_caller_identity_arn()
    hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]

    start_time = time.perf_counter()

    JumpStartModel(
        model_id=model_id,
        role=role,
        sagemaker_session=sagemaker_session,
        hub_name=hub_name,
    )

    elapsed_time = time.perf_counter() - start_time

    assert elapsed_time <= MAX_INIT_TIME_SECONDS
@pytest.fixture
def hub_instance():
    """Yield a ``Hub`` handle named from ``HUB_NAME_PREFIX`` plus a test suite id.

    The fixture only builds the handle; the test creates and deletes the hub.
    """
    # HUB_NAME_PREFIX already ends with "-", so no extra separator is added
    # (same naming scheme as the session hub in conftest).
    hub_name = f"{HUB_NAME_PREFIX}{get_test_suite_id()}"
    hub = Hub(hub_name, sagemaker_session=get_sm_session())
    yield hub


def test_private_hub(setup, hub_instance):
    """Exercise the create / describe / delete life cycle of a private hub."""
    # Create Hub
    create_hub_response = hub_instance.create(
        description="This is a Test Private Hub.",
        display_name="PySDK integration tests Hub",
        search_keywords=["jumpstart-sdk-integ-test"],
    )

    try:
        # Create Hub Verifications
        assert create_hub_response is not None

        # Describe Hub
        hub_description = hub_instance.describe()
        assert hub_description is not None
    finally:
        # Delete Hub, even when a check above failed, so the hub is not leaked.
        delete_hub_response = hub_instance.delete()

    assert delete_hub_response is not None
def test_hub_model_reference(setup):
    """Create, describe and delete a model reference in the session's test hub."""
    model_id = "meta-textgenerationneuron-llama-3-2-1b-instruct"

    hub_instance = Hub(
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
        sagemaker_session=get_sm_session(),
    )

    create_model_response = hub_instance.create_model_reference(
        model_arn=get_public_hub_model_arn(hub_instance, model_id)
    )

    try:
        assert create_model_response is not None

        describe_model_response = hub_instance.describe_model(model_name=model_id)
        assert describe_model_response is not None
        assert isinstance(describe_model_response, DescribeHubContentResponse)
        assert describe_model_response.hub_content_name == model_id
        assert describe_model_response.hub_content_type == "ModelReference"
    finally:
        # Remove the reference even when a check above failed.
        delete_model_response = hub_instance.delete_model_reference(model_name=model_id)

    assert delete_model_response is not None
def get_public_hub_model_arn(hub: "Hub", model_id: str) -> str:
    """Return the hub content ARN of the first public-hub model matching ``model_id``.

    Args:
        hub: Object exposing ``list_sagemaker_public_hub_models(filter=...)``.
        model_id: Model id placed in the ``model_id == <id>`` filter expression.

    Returns:
        The ``"hub_content_arn"`` of the first entry in the response's
        ``"hub_content_summaries"``.

    Raises:
        IndexError: If the lookup returns no summaries.
    """
    filter_value = f"model_id == {model_id}"
    response = hub.list_sagemaker_public_hub_models(filter=filter_value)

    models = response["hub_content_summaries"]
    if not models:
        # Same exception type an empty ``models[0]`` would raise, with a usable message.
        raise IndexError(f"No SageMaker public hub model found for model_id '{model_id}'")

    return models[0]["hub_content_arn"]


def with_exponential_backoff(
    max_retries=5,
    initial_delay=1,
    max_delay=60,
    retryable_error_codes=("ThrottlingException", "TooManyRequestsException"),
):
    """Build a decorator that retries a call on selected ``ClientError`` codes.

    Args:
        max_retries: Number of retries after the first attempt.
        initial_delay: Base delay in seconds; doubled on each retry.
        max_delay: Upper bound in seconds for any single delay.
        retryable_error_codes: Error codes (``e.response["Error"]["Code"]``)
            that trigger a retry. A single string is treated as one code.

    Returns:
        A decorator. The wrapped function returns whatever the original
        returns, and re-raises the ``ClientError`` once retries are exhausted
        or when the error code is not retryable. Other exceptions propagate
        immediately.
    """
    if isinstance(retryable_error_codes, str):
        retryable_error_codes = (retryable_error_codes,)
    retryable = frozenset(retryable_error_codes)

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            retries = 0
            while True:
                try:
                    return func(*args, **kwargs)
                except ClientError as e:
                    error_code = e.response["Error"]["Code"]
                    if retries >= max_retries or error_code not in retryable:
                        raise
                    # Exponential growth plus up to one second of jitter, capped at max_delay.
                    delay = min(initial_delay * (2**retries) + random.random(), max_delay)
                    print(
                        f"Retrying {func.__name__} in {delay:.2f} seconds... "
                        f"(Attempt {retries + 1}/{max_retries})"
                    )
                    time.sleep(delay)
                    retries += 1

        return wrapper

    return decorator