File tree 3 files changed +20
-5
lines changed
src/sagemaker/jumpstart/hub
tests/unit/sagemaker/jumpstart
3 files changed +20
-5
lines changed Original file line number Diff line number Diff line change @@ -219,18 +219,17 @@ def get_hub_model_version(
219
219
except Exception as ex :
220
220
raise Exception (f"Failed calling list_hub_content_versions: { str (ex )} " )
221
221
222
- marketplace_hub_content_version = _get_hub_model_version_for_marketplace_version (
223
- hub_content_summaries , hub_model_version
224
- )
225
-
226
222
try :
227
223
return _get_hub_model_version_for_open_weight_version (
228
224
hub_content_summaries , hub_model_version
229
225
)
230
226
except KeyError as e :
227
+ marketplace_hub_content_version = _get_hub_model_version_for_marketplace_version (
228
+ hub_content_summaries , hub_model_version
229
+ )
231
230
if marketplace_hub_content_version :
232
231
return marketplace_hub_content_version
233
- raise e
232
+ raise
234
233
235
234
236
235
def _get_hub_model_version_for_open_weight_version (
Original file line number Diff line number Diff line change 9178
9178
"TrainingArtifactS3DataType" : "S3Prefix" ,
9179
9179
"TrainingArtifactCompressionType" : "None" ,
9180
9180
"TrainingArtifactUri" : "s3://jumpstart-cache-prod-us-west-2/huggingface-training/train-huggingface-llm-gemma-2b-instruct.tar.gz" , # noqa: E501
9181
+ "ModelTypes" : ["OPEN_WEIGHTS" , "PROPRIETARY" ],
9181
9182
"Hyperparameters" : [
9182
9183
{
9183
9184
"Name" : "peft_type" ,
Original file line number Diff line number Diff line change 13
13
from __future__ import absolute_import
14
14
import pytest
15
15
from sagemaker .jumpstart .hub .parser_utils import camel_to_snake
16
+ from sagemaker .jumpstart .hub .parsers import make_model_specs_from_describe_hub_content_response
17
+ from sagemaker .jumpstart .hub .interfaces import HubModelDocument
18
+ from tests .unit .sagemaker .jumpstart .constants import HUB_MODEL_DOCUMENT_DICTS
19
+ from unittest .mock import MagicMock
16
20
17
21
REGION = "us-east-1"
18
22
ACCOUNT_ID = "123456789123"
32
36
)
33
37
def test_parse_ (input_string , expected ):
34
38
assert expected == camel_to_snake (input_string )
39
+
40
+
41
+ def test_make_model_specs_from_describe_hub_content_response ():
42
+ mock_describe_response = MagicMock ()
43
+ region = "us-west-2"
44
+ mock_describe_response .get_hub_region .return_value = region
45
+ mock_describe_response .hub_content_version = "1.0.0"
46
+ json_obj = HUB_MODEL_DOCUMENT_DICTS ["huggingface-llm-gemma-2b-instruct" ]
47
+ mock_describe_response .hub_content_document = HubModelDocument (json_obj = json_obj , region = region )
48
+
49
+ make_model_specs_from_describe_hub_content_response (mock_describe_response )
You can’t perform that action at this time.
0 commit comments