Skip to content

Commit 354b33e

Browse files
committed
add important unit test
1 parent a39ae5f commit 354b33e

File tree

3 files changed

+31
-1
lines changed

3 files changed

+31
-1
lines changed

src/sagemaker/jumpstart/accessors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def get_model_specs(
268268
hub_model_arn = utils.construct_hub_model_arn_from_inputs(
269269
hub_arn=hub_arn, model_name=model_id, version=version
270270
)
271-
return JumpStartModelsAccessor._cache.get_hub_model(hub_model_arn)
271+
return JumpStartModelsAccessor._cache.get_hub_model(hub_model_arn=hub_model_arn)
272272

273273
return JumpStartModelsAccessor._cache.get_specs( # type: ignore
274274
model_id=model_id, semantic_version_str=version

tests/unit/sagemaker/jumpstart/test_accessors.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,35 @@ def test_jumpstart_models_cache_get_fxs(mock_cache):
7272
reload(accessors)
7373

7474

75+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
76+
def test_jumpstart_models_cache_get_model_specs(mock_cache):
77+
mock_cache.get_specs = Mock()
78+
mock_cache.get_hub_model = Mock()
79+
model_id, version = "pytorch-ic-mobilenet-v2", "*"
80+
region = "us-west-2"
81+
82+
accessors.JumpStartModelsAccessor.get_model_specs(
83+
region=region, model_id=model_id, version=version
84+
)
85+
mock_cache.get_specs.assert_called_once_with(model_id=model_id, semantic_version_str=version)
86+
mock_cache.get_hub_model.assert_not_called()
87+
88+
accessors.JumpStartModelsAccessor.get_model_specs(
89+
region=region,
90+
model_id=model_id,
91+
version=version,
92+
hub_arn=f"arn:aws:sagemaker:{region}:123456789123:hub/my-mock-hub",
93+
)
94+
mock_cache.get_hub_model.assert_called_once_with(
95+
hub_model_arn=(
96+
f"arn:aws:sagemaker:{region}:123456789123:hub-content/my-mock-hub/Model/{model_id}/{version}"
97+
)
98+
)
99+
100+
# necessary because accessors is a static module
101+
reload(accessors)
102+
103+
75104
@patch("sagemaker.jumpstart.cache.JumpStartModelsCache")
76105
def test_jumpstart_models_cache_set_reset_fxs(mock_model_cache: Mock):
77106

tests/unit/sagemaker/jumpstart/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def get_spec_from_base_spec(
148148
semantic_version_str: str = None,
149149
version: str = None,
150150
hub_arn: Optional[str] = None,
151+
hub_model_arn: Optional[str] = None,
151152
s3_client: boto3.client = None,
152153
) -> JumpStartModelSpecs:
153154

0 commit comments

Comments
 (0)