File tree 3 files changed +31
-1
lines changed
tests/unit/sagemaker/jumpstart 3 files changed +31
-1
lines changed Original file line number Diff line number Diff line change @@ -268,7 +268,7 @@ def get_model_specs(
268
268
hub_model_arn = utils .construct_hub_model_arn_from_inputs (
269
269
hub_arn = hub_arn , model_name = model_id , version = version
270
270
)
271
- return JumpStartModelsAccessor ._cache .get_hub_model (hub_model_arn )
271
+ return JumpStartModelsAccessor ._cache .get_hub_model (hub_model_arn = hub_model_arn )
272
272
273
273
return JumpStartModelsAccessor ._cache .get_specs ( # type: ignore
274
274
model_id = model_id , semantic_version_str = version
Original file line number Diff line number Diff line change @@ -72,6 +72,35 @@ def test_jumpstart_models_cache_get_fxs(mock_cache):
72
72
reload (accessors )
73
73
74
74
75
+ @patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache" )
76
+ def test_jumpstart_models_cache_get_model_specs (mock_cache ):
77
+ mock_cache .get_specs = Mock ()
78
+ mock_cache .get_hub_model = Mock ()
79
+ model_id , version = "pytorch-ic-mobilenet-v2" , "*"
80
+ region = "us-west-2"
81
+
82
+ accessors .JumpStartModelsAccessor .get_model_specs (
83
+ region = region , model_id = model_id , version = version
84
+ )
85
+ mock_cache .get_specs .assert_called_once_with (model_id = model_id , semantic_version_str = version )
86
+ mock_cache .get_hub_model .assert_not_called ()
87
+
88
+ accessors .JumpStartModelsAccessor .get_model_specs (
89
+ region = region ,
90
+ model_id = model_id ,
91
+ version = version ,
92
+ hub_arn = f"arn:aws:sagemaker:{ region } :123456789123:hub/my-mock-hub" ,
93
+ )
94
+ mock_cache .get_hub_model .assert_called_once_with (
95
+ hub_model_arn = (
96
+ f"arn:aws:sagemaker:{ region } :123456789123:hub-content/my-mock-hub/Model/{ model_id } /{ version } "
97
+ )
98
+ )
99
+
100
+ # necessary because accessors is a static module
101
+ reload (accessors )
102
+
103
+
75
104
@patch ("sagemaker.jumpstart.cache.JumpStartModelsCache" )
76
105
def test_jumpstart_models_cache_set_reset_fxs (mock_model_cache : Mock ):
77
106
Original file line number Diff line number Diff line change @@ -148,6 +148,7 @@ def get_spec_from_base_spec(
148
148
semantic_version_str : str = None ,
149
149
version : str = None ,
150
150
hub_arn : Optional [str ] = None ,
151
+ hub_model_arn : Optional [str ] = None ,
151
152
s3_client : boto3 .client = None ,
152
153
) -> JumpStartModelSpecs :
153
154
You can’t perform that action at this time.
0 commit comments