Skip to content

Commit a8a2453

Browse files
committed
fix: Adding more code coverage
1 parent 273449c commit a8a2453

File tree

3 files changed

+20
-5
lines changed

3 files changed

+20
-5
lines changed

src/sagemaker/jumpstart/hub/utils.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -219,18 +219,17 @@ def get_hub_model_version(
219219
except Exception as ex:
220220
raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}")
221221

222-
marketplace_hub_content_version = _get_hub_model_version_for_marketplace_version(
223-
hub_content_summaries, hub_model_version
224-
)
225-
226222
try:
227223
return _get_hub_model_version_for_open_weight_version(
228224
hub_content_summaries, hub_model_version
229225
)
230226
except KeyError as e:
227+
marketplace_hub_content_version = _get_hub_model_version_for_marketplace_version(
228+
hub_content_summaries, hub_model_version
229+
)
231230
if marketplace_hub_content_version:
232231
return marketplace_hub_content_version
233-
raise e
232+
raise
234233

235234

236235
def _get_hub_model_version_for_open_weight_version(

tests/unit/sagemaker/jumpstart/constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -9178,6 +9178,7 @@
91789178
"TrainingArtifactS3DataType": "S3Prefix",
91799179
"TrainingArtifactCompressionType": "None",
91809180
"TrainingArtifactUri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501
9181+
"ModelTypes": ["OPEN_WEIGHTS", "PROPRIETARY"],
91819182
"Hyperparameters": [
91829183
{
91839184
"Name": "peft_type",

tests/unit/sagemaker/jumpstart/hub/test_parser_utils.py

+15
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
from __future__ import absolute_import
1414
import pytest
1515
from sagemaker.jumpstart.hub.parser_utils import camel_to_snake
16+
from sagemaker.jumpstart.hub.parsers import make_model_specs_from_describe_hub_content_response
17+
from sagemaker.jumpstart.hub.interfaces import HubModelDocument
18+
from tests.unit.sagemaker.jumpstart.constants import HUB_MODEL_DOCUMENT_DICTS
19+
from unittest.mock import MagicMock
1620

1721
REGION = "us-east-1"
1822
ACCOUNT_ID = "123456789123"
@@ -32,3 +36,14 @@
3236
)
3337
def test_parse_(input_string, expected):
3438
assert expected == camel_to_snake(input_string)
39+
40+
41+
def test_make_model_specs_from_describe_hub_content_response():
42+
mock_describe_response = MagicMock()
43+
region = "us-west-2"
44+
mock_describe_response.get_hub_region.return_value = region
45+
mock_describe_response.hub_content_version = "1.0.0"
46+
json_obj = HUB_MODEL_DOCUMENT_DICTS["huggingface-llm-gemma-2b-instruct"]
47+
mock_describe_response.hub_content_document = HubModelDocument(json_obj=json_obj, region=region)
48+
49+
make_model_specs_from_describe_hub_content_response(mock_describe_response)

0 commit comments

Comments
 (0)