Skip to content

fix: keep sagemaker_session from being overridden to None #5021

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ def __init__(
if s3_client_config
else boto3.client("s3", region_name=self._region)
)
self._sagemaker_session = sagemaker_session
# Fallback in case a caller overrides sagemaker_session to None
self._sagemaker_session = sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION

def set_region(self, region: str) -> None:
"""Set region for cache. Clears cache after new region is set."""
Expand Down
6 changes: 6 additions & 0 deletions src/sagemaker/jumpstart/hub/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def construct_hub_arn_from_name(
account_id: Optional[str] = None,
) -> str:
"""Constructs a Hub arn from the Hub name using default Session values."""
if session is None:
# session is overridden to none by some callers
session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION

account_id = account_id or session.account_id()
region = region or session.boto_region_name
Expand Down Expand Up @@ -211,6 +214,9 @@ def get_hub_model_version(
ClientError: If the specified model is not found in the hub.
KeyError: If the specified model version is not found.
"""
if sagemaker_session is None:
# sagemaker_session is overridden to none by some callers
sagemaker_session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION

try:
hub_content_summaries = sagemaker_session.list_hub_content_versions(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,23 @@ def test_jumpstart_hub_model(setup, add_model_references):
assert sagemaker_session.endpoint_in_service_or_not(predictor.endpoint_name)


def test_jumpstart_hub_model_with_default_session(setup, add_model_references):
model_version = "*"
hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]

model_id = "catboost-classification-model"

sagemaker_session = get_sm_session()

model = JumpStartModel(model_id=model_id, model_version=model_version, hub_name=hub_name)

predictor = model.deploy(
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
)

assert sagemaker_session.endpoint_in_service_or_not(predictor.endpoint_name)


def test_jumpstart_hub_gated_model(setup, add_model_references):

model_id = "meta-textgeneration-llama-3-2-1b"
Expand Down
16 changes: 15 additions & 1 deletion tests/unit/sagemaker/jumpstart/hub/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@

from unittest.mock import patch, Mock
from sagemaker.jumpstart.types import HubArnExtractedInfo
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
from sagemaker.jumpstart.constants import (
JUMPSTART_DEFAULT_REGION_NAME,
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
)
from sagemaker.jumpstart.hub import parser_utils, utils


Expand Down Expand Up @@ -80,6 +83,17 @@ def test_construct_hub_arn_from_name():
)


def test_construct_hub_arn_from_name_with_session_none():
hub_name = "my-cool-hub"
account_id = DEFAULT_JUMPSTART_SAGEMAKER_SESSION.account_id()
boto_region_name = DEFAULT_JUMPSTART_SAGEMAKER_SESSION.boto_region_name

assert (
utils.construct_hub_arn_from_name(hub_name=hub_name, session=None)
== f"arn:aws:sagemaker:{boto_region_name}:{account_id}:hub/{hub_name}"
)


def test_construct_hub_model_arn_from_inputs():
model_name, version = "pytorch-ic-imagenet-v2", "1.0.2"
hub_arn = "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub"
Expand Down
20 changes: 20 additions & 0 deletions tests/unit/sagemaker/jumpstart/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from sagemaker.jumpstart.cache import (
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY,
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
JumpStartModelsCache,
)
from sagemaker.jumpstart.constants import (
Expand Down Expand Up @@ -57,6 +58,25 @@
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket


@patch("sagemaker.jumpstart.utils.get_region_fallback", lambda *args, **kwargs: "dummy-region")
@patch(
"sagemaker.jumpstart.utils.get_jumpstart_content_bucket", lambda *args, **kwargs: "dummy-bucket"
)
@patch("boto3.client")
def test_jumpstart_cache_init(mock_boto3_client):
cache = JumpStartModelsCache()
assert cache._region == "dummy-region"
assert cache.s3_bucket_name == "dummy-bucket"
assert cache._manifest_file_s3_key == JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY
assert cache._proprietary_manifest_s3_key == JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY
assert cache._sagemaker_session == DEFAULT_JUMPSTART_SAGEMAKER_SESSION
mock_boto3_client.assert_called_once_with("s3", region_name="dummy-region")

# Some callers override the session to None, should still be set to default
cache = JumpStartModelsCache(sagemaker_session=None)
assert cache._sagemaker_session == DEFAULT_JUMPSTART_SAGEMAKER_SESSION


@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function)
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3")
def test_jumpstart_cache_get_header():
Expand Down
Loading