Skip to content

Commit 18d388c

Browse files
committed
chore: add unit tests
1 parent 0fb849f commit 18d388c

File tree

3 files changed

+146
-3
lines changed

3 files changed

+146
-3
lines changed

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from unittest import mock
1717
import unittest
1818
from inspect import signature
19+
from mock import Mock
1920

2021
import pytest
2122
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
@@ -30,12 +31,16 @@
3031
from sagemaker.jumpstart.artifacts.metric_definitions import (
3132
_retrieve_default_training_metric_definitions,
3233
)
33-
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
34+
from sagemaker.jumpstart.constants import (
35+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
36+
JUMPSTART_DEFAULT_REGION_NAME,
37+
)
3438
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag
3539

3640
from sagemaker.jumpstart.estimator import JumpStartEstimator
3741

3842
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
43+
from sagemaker.session import Session
3944
from sagemaker.session_settings import SessionSettings
4045
from tests.integ.sagemaker.jumpstart.utils import get_training_dataset_for_model_and_version
4146
from sagemaker.model import Model
@@ -44,6 +49,7 @@
4449
get_special_model_spec,
4550
overwrite_dictionary,
4651
)
52+
import boto3
4753

4854

4955
execution_role = "fake role! do not use!"
@@ -1773,6 +1779,56 @@ def test_model_artifact_variant_estimator(
17731779
],
17741780
)
17751781

1782+
@mock.patch("sagemaker.jumpstart.estimator.get_default_predictor")
1783+
@mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
1784+
@mock.patch("sagemaker.jumpstart.estimator.Estimator.fit")
1785+
@mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy")
1786+
@mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id")
1787+
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
1788+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
1789+
def test_jumpstart_estimator_session(
1790+
self,
1791+
mock_get_model_specs: mock.Mock,
1792+
mock_is_valid_model_id: mock.Mock,
1793+
mock_deploy,
1794+
mock_fit,
1795+
mock_init,
1796+
get_default_predictor,
1797+
):
1798+
1799+
mock_is_valid_model_id.return_value = True
1800+
1801+
model_id, _ = "js-trainable-model", "*"
1802+
1803+
mock_get_model_specs.side_effect = get_special_model_spec
1804+
1805+
region = "eu-west-1" # some non-default region
1806+
1807+
if region == JUMPSTART_DEFAULT_REGION_NAME:
1808+
region = "us-west-2"
1809+
1810+
session = Session(boto_session=boto3.session.Session(region_name=region))
1811+
1812+
assert session.boto_region_name != JUMPSTART_DEFAULT_REGION_NAME
1813+
1814+
session.get_caller_identity_arn = Mock(return_value="blah")
1815+
1816+
estimator = JumpStartEstimator(model_id=model_id, sagemaker_session=session)
1817+
estimator.fit()
1818+
1819+
estimator.deploy()
1820+
1821+
assert len(mock_get_model_specs.call_args_list) > 1
1822+
1823+
regions = {call[1]["region"] for call in mock_get_model_specs.call_args_list}
1824+
1825+
assert len(regions) == 1
1826+
assert list(regions)[0] == region
1827+
1828+
s3_clients = {call[1]["s3_client"] for call in mock_get_model_specs.call_args_list}
1829+
assert len(s3_clients) == 1
1830+
assert list(s3_clients)[0] == session.s3_client
1831+
17761832

17771833
def test_jumpstart_estimator_requires_model_id():
17781834
with pytest.raises(ValueError):

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,22 @@
1515
from typing import Optional, Set
1616
from unittest import mock
1717
import unittest
18-
from mock import MagicMock
18+
from mock import MagicMock, Mock
1919
import pytest
2020
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
2121
from sagemaker.jumpstart.artifacts.environment_variables import (
2222
_retrieve_default_environment_variables,
2323
)
24-
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
24+
from sagemaker.jumpstart.constants import (
25+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
26+
JUMPSTART_DEFAULT_REGION_NAME,
27+
)
2528
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag
2629

2730
from sagemaker.jumpstart.model import JumpStartModel
2831
from sagemaker.model import Model
2932
from sagemaker.predictor import Predictor
33+
from sagemaker.session import Session
3034
from sagemaker.session_settings import SessionSettings
3135
from sagemaker.enums import EndpointType
3236
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
@@ -36,6 +40,7 @@
3640
overwrite_dictionary,
3741
get_special_model_spec_for_inference_component_based_endpoint,
3842
)
43+
import boto3
3944

4045
execution_role = "fake role! do not use!"
4146
region = "us-west-2"
@@ -1252,6 +1257,52 @@ def test_model_registry_accept_and_response_types(
12521257
response_types=["application/json;verbose", "application/json"],
12531258
)
12541259

1260+
@mock.patch("sagemaker.jumpstart.model.get_default_predictor")
1261+
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
1262+
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
1263+
@mock.patch("sagemaker.jumpstart.model.is_valid_model_id")
1264+
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
1265+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
1266+
def test_jumpstart_model_session(
1267+
self,
1268+
mock_get_model_specs: mock.Mock,
1269+
mock_is_valid_model_id: mock.Mock,
1270+
mock_deploy,
1271+
mock_init,
1272+
get_default_predictor,
1273+
):
1274+
1275+
mock_is_valid_model_id.return_value = True
1276+
1277+
model_id, _ = "model_data_s3_prefix_model", "*"
1278+
1279+
mock_get_model_specs.side_effect = get_special_model_spec
1280+
1281+
region = "eu-west-1" # some non-default region
1282+
1283+
if region == JUMPSTART_DEFAULT_REGION_NAME:
1284+
region = "us-west-2"
1285+
1286+
session = Session(boto_session=boto3.session.Session(region_name=region))
1287+
1288+
assert session.boto_region_name != JUMPSTART_DEFAULT_REGION_NAME
1289+
1290+
session.get_caller_identity_arn = Mock(return_value="blah")
1291+
1292+
model = JumpStartModel(model_id=model_id, sagemaker_session=session)
1293+
model.deploy()
1294+
1295+
assert len(mock_get_model_specs.call_args_list) > 1
1296+
1297+
regions = {call[1]["region"] for call in mock_get_model_specs.call_args_list}
1298+
1299+
assert len(regions) == 1
1300+
assert list(regions)[0] == region
1301+
1302+
s3_clients = {call[1]["s3_client"] for call in mock_get_model_specs.call_args_list}
1303+
assert len(s3_clients) == 1
1304+
assert list(s3_clients)[0] == session.s3_client
1305+
12551306

12561307
def test_jumpstart_model_requires_model_id():
12571308
with pytest.raises(ValueError):

tests/unit/sagemaker/jumpstart/test_cache.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from unittest.mock import Mock, call, mock_open
1919
from botocore.stub import Stubber
2020
import botocore
21+
import boto3
2122

2223
from mock.mock import MagicMock
2324
import pytest
@@ -27,6 +28,7 @@
2728
from sagemaker.jumpstart.constants import (
2829
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
2930
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
31+
JUMPSTART_DEFAULT_REGION_NAME,
3032
)
3133
from sagemaker.jumpstart.types import (
3234
JumpStartModelHeader,
@@ -854,3 +856,37 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(
854856
),
855857
]
856858
)
859+
860+
861+
@pytest.mark.parametrize(
862+
"s3_bucket_name, s3_client, region",
863+
[
864+
(
865+
"jumpstart-cache-prod",
866+
boto3.client("s3", region_name="blah-blah"),
867+
JUMPSTART_DEFAULT_REGION_NAME,
868+
),
869+
(
870+
"jumpstart-cache-prod-us-west-2",
871+
boto3.client("s3", region_name="us-west-2"),
872+
"us-west-2",
873+
),
874+
("jumpstart-cache-prod", boto3.client("s3", region_name="us-east-2"), "us-east-2"),
875+
],
876+
)
877+
def test_get_region_fallback_success(s3_bucket_name, s3_client, region):
878+
cache = JumpStartModelsCache()
879+
assert region == cache._get_region_fallback(s3_bucket_name, s3_client)
880+
881+
882+
@pytest.mark.parametrize(
883+
"s3_bucket_name, s3_client",
884+
[
885+
("jumpstart-cache-prod-us-west-2", boto3.client("s3", region_name="us-east-2")),
886+
("jumpstart-cache-prod-us-west-2-us-east-2", boto3.client("s3", region_name="us-east-2")),
887+
],
888+
)
889+
def test_get_region_fallback_failure(s3_bucket_name, s3_client):
890+
cache = JumpStartModelsCache()
891+
with pytest.raises(ValueError):
892+
cache._get_region_fallback(s3_bucket_name, s3_client)

0 commit comments

Comments
 (0)