Skip to content

Commit 7732979

Browse files
committed
feat: jsch jumpstart estimator support (aws#4439)
1 parent ea5bb5a commit 7732979

File tree

18 files changed

+64
-34
lines changed

18 files changed

+64
-34
lines changed

src/sagemaker/jumpstart/accessors.py

+1
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ def get_model_specs(
257257
hub_arn: Optional[str] = None,
258258
s3_client: Optional[boto3.client] = None,
259259
model_type=JumpStartModelType.OPEN_WEIGHTS,
260+
hub_arn: Optional[str] = None,
260261
) -> JumpStartModelSpecs:
261262
"""Returns model specs from JumpStart models cache.
262263

src/sagemaker/jumpstart/cache.py

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
get_wildcard_model_version_msg,
4040
get_wildcard_proprietary_model_version_msg,
4141
)
42+
from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg
4243
from sagemaker.jumpstart.parameters import (
4344
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
4445
JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS,

src/sagemaker/jumpstart/constants.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,7 @@
172172
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
173173
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json"
174174

175-
# works cross-partition
176-
HUB_MODEL_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub_content/(.*?)/Model/(.*?)/(.*?)$"
175+
HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$"
177176
HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$"
178177

179178
INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py"

src/sagemaker/jumpstart/estimator.py

+1
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,7 @@ def _validate_model_id_and_get_type_hook():
534534
model_version=model_version,
535535
hub_arn=hub_arn,
536536
model_type=self.model_type,
537+
hub_arn=hub_arn,
537538
tolerate_vulnerable_model=tolerate_vulnerable_model,
538539
tolerate_deprecated_model=tolerate_deprecated_model,
539540
role=role,

src/sagemaker/jumpstart/factory/estimator.py

+2
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def get_init_kwargs(
8181
model_version: Optional[str] = None,
8282
hub_arn: Optional[str] = None,
8383
model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
84+
hub_arn: Optional[str] = None,
8485
tolerate_vulnerable_model: Optional[bool] = None,
8586
tolerate_deprecated_model: Optional[bool] = None,
8687
region: Optional[str] = None,
@@ -140,6 +141,7 @@ def get_init_kwargs(
140141
model_version=model_version,
141142
hub_arn=hub_arn,
142143
model_type=model_type,
144+
hub_arn=hub_arn,
143145
role=role,
144146
region=region,
145147
instance_count=instance_count,

src/sagemaker/jumpstart/factory/model.py

+2
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,7 @@ def get_deploy_kwargs(
549549
model_version: Optional[str] = None,
550550
hub_arn: Optional[str] = None,
551551
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
552+
hub_arn: Optional[str] = None,
552553
region: Optional[str] = None,
553554
initial_instance_count: Optional[int] = None,
554555
instance_type: Optional[str] = None,
@@ -583,6 +584,7 @@ def get_deploy_kwargs(
583584
model_version=model_version,
584585
hub_arn=hub_arn,
585586
model_type=model_type,
587+
hub_arn=hub_arn,
586588
region=region,
587589
initial_instance_count=initial_instance_count,
588590
instance_type=instance_type,

src/sagemaker/jumpstart/types.py

+8
Original file line numberDiff line numberDiff line change
@@ -2354,6 +2354,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
23542354
"model_version",
23552355
"hub_arn",
23562356
"model_type",
2357+
"hub_arn",
23572358
"initial_instance_count",
23582359
"instance_type",
23592360
"region",
@@ -2388,6 +2389,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
23882389
"model_type",
23892390
"hub_arn",
23902391
"model_type",
2392+
"hub_arn",
23912393
"region",
23922394
"tolerate_deprecated_model",
23932395
"tolerate_vulnerable_model",
@@ -2434,6 +2436,7 @@ def __init__(
24342436
self.model_version = model_version
24352437
self.hub_arn = hub_arn
24362438
self.model_type = model_type
2439+
self.hub_arn = hub_arn
24372440
self.initial_instance_count = initial_instance_count
24382441
self.instance_type = instance_type
24392442
self.region = region
@@ -2470,6 +2473,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
24702473
"model_version",
24712474
"hub_arn",
24722475
"model_type",
2476+
"hub_arn",
24732477
"instance_type",
24742478
"instance_count",
24752479
"region",
@@ -2531,6 +2535,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
25312535
"model_version",
25322536
"hub_arn",
25332537
"model_type",
2538+
"hub_arn",
25342539
}
25352540

25362541
def __init__(
@@ -2660,6 +2665,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs):
26602665
"model_version",
26612666
"hub_arn",
26622667
"model_type",
2668+
"hub_arn",
26632669
"region",
26642670
"inputs",
26652671
"wait",
@@ -2676,6 +2682,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs):
26762682
"model_version",
26772683
"hub_arn",
26782684
"model_type",
2685+
"hub_arn",
26792686
"region",
26802687
"tolerate_deprecated_model",
26812688
"tolerate_vulnerable_model",
@@ -2704,6 +2711,7 @@ def __init__(
27042711
self.model_version = model_version
27052712
self.hub_arn = hub_arn
27062713
self.model_type = model_type
2714+
self.hub_arn = hub_arn
27072715
self.region = region
27082716
self.inputs = inputs
27092717
self.wait = wait

src/sagemaker/jumpstart/utils.py

-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import os
1717
import re
1818
from typing import Any, Dict, List, Set, Optional, Tuple, Union
19-
import re
2019
from urllib.parse import urlparse
2120
import boto3
2221
from packaging.version import Version

tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py

+2
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,7 @@ def add_options_to_hyperparameter(*largs, **kwargs):
453453
s3_client=mock_client,
454454
hub_arn=None,
455455
model_type=JumpStartModelType.OPEN_WEIGHTS,
456+
hub_arn=None,
456457
)
457458

458459
patched_get_model_specs.reset_mock()
@@ -516,6 +517,7 @@ def test_jumpstart_validate_all_hyperparameters(
516517
s3_client=mock_client,
517518
hub_arn=None,
518519
model_type=JumpStartModelType.OPEN_WEIGHTS,
520+
hub_arn=None,
519521
)
520522

521523
patched_get_model_specs.reset_mock()

tests/unit/sagemaker/image_uris/jumpstart/test_common.py

+4
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def test_jumpstart_common_image_uri(
5656
s3_client=mock_client,
5757
hub_arn=None,
5858
model_type=JumpStartModelType.OPEN_WEIGHTS,
59+
hub_arn=None,
5960
)
6061
patched_verify_model_region_and_return_specs.assert_called_once()
6162

@@ -78,6 +79,7 @@ def test_jumpstart_common_image_uri(
7879
s3_client=mock_client,
7980
hub_arn=None,
8081
model_type=JumpStartModelType.OPEN_WEIGHTS,
82+
hub_arn=None,
8183
)
8284
patched_verify_model_region_and_return_specs.assert_called_once()
8385

@@ -100,6 +102,7 @@ def test_jumpstart_common_image_uri(
100102
s3_client=mock_client,
101103
hub_arn=None,
102104
model_type=JumpStartModelType.OPEN_WEIGHTS,
105+
hub_arn=None,
103106
)
104107
patched_verify_model_region_and_return_specs.assert_called_once()
105108

@@ -122,6 +125,7 @@ def test_jumpstart_common_image_uri(
122125
s3_client=mock_client,
123126
hub_arn=None,
124127
model_type=JumpStartModelType.OPEN_WEIGHTS,
128+
hub_arn=None,
125129
)
126130
patched_verify_model_region_and_return_specs.assert_called_once()
127131

tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py

-1
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,6 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode
126126
model_type=JumpStartModelType.OPEN_WEIGHTS,
127127
hub_arn=None,
128128
s3_client=mock_client,
129-
model_type=JumpStartModelType.OPEN_WEIGHTS,
130129
)
131130

132131
patched_get_model_specs.reset_mock()

tests/unit/sagemaker/jumpstart/test_accessors.py

+28
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,34 @@ def test_jumpstart_proprietary_models_cache_get(mock_cache):
137137
> 0
138138
)
139139

140+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
141+
def test_jumpstart_models_cache_get_model_specs(mock_cache):
142+
mock_cache.get_specs = Mock()
143+
mock_cache.get_hub_model = Mock()
144+
model_id, version = "pytorch-ic-mobilenet-v2", "*"
145+
region = "us-west-2"
146+
147+
accessors.JumpStartModelsAccessor.get_model_specs(
148+
region=region, model_id=model_id, version=version
149+
)
150+
mock_cache.get_specs.assert_called_once_with(model_id=model_id, semantic_version_str=version)
151+
mock_cache.get_hub_model.assert_not_called()
152+
153+
accessors.JumpStartModelsAccessor.get_model_specs(
154+
region=region,
155+
model_id=model_id,
156+
version=version,
157+
hub_arn=f"arn:aws:sagemaker:{region}:123456789123:hub/my-mock-hub",
158+
)
159+
mock_cache.get_hub_model.assert_called_once_with(
160+
hub_model_arn=(
161+
f"arn:aws:sagemaker:{region}:123456789123:hub-content/my-mock-hub/Model/{model_id}/{version}"
162+
)
163+
)
164+
165+
# necessary because accessors is a static module
166+
reload(accessors)
167+
140168

141169
@patch("sagemaker.jumpstart.cache.JumpStartModelsCache")
142170
def test_jumpstart_models_cache_set_reset(mock_model_cache: Mock):

tests/unit/sagemaker/jumpstart/test_notebook_utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -751,4 +751,5 @@ def test_get_model_url(
751751
s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client,
752752
hub_arn=None,
753753
model_type=JumpStartModelType.OPEN_WEIGHTS,
754+
hub_arn=None,
754755
)

tests/unit/sagemaker/jumpstart/test_utils.py

-29
Original file line numberDiff line numberDiff line change
@@ -1206,35 +1206,6 @@ def test_mime_type_enum_from_str():
12061206
assert MIMEType.from_suffixed_type(mime_type_with_suffix) == mime_type
12071207

12081208

1209-
def test_extract_info_from_hub_content_arn():
1210-
model_arn = (
1211-
"arn:aws:sagemaker:us-west-2:000000000000:hub_content/MockHub/Model/my-mock-model/1.0.2"
1212-
)
1213-
assert utils.extract_info_from_hub_content_arn(model_arn) == (
1214-
"MockHub",
1215-
"us-west-2",
1216-
"my-mock-model",
1217-
"1.0.2",
1218-
)
1219-
1220-
hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/MockHub"
1221-
assert utils.extract_info_from_hub_content_arn(hub_arn) == ("MockHub", "us-west-2", None, None)
1222-
1223-
invalid_arn = "arn:aws:sagemaker:us-west-2:000000000000:endpoint/my-endpoint-123"
1224-
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)
1225-
1226-
invalid_arn = "nonsense-string"
1227-
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)
1228-
1229-
invalid_arn = ""
1230-
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)
1231-
1232-
invalid_arn = (
1233-
"arn:aws:sagemaker:us-west-2:000000000000:hub-content/MyHub/Notebook/my-notebook/1.0.0"
1234-
)
1235-
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)
1236-
1237-
12381209
class TestIsValidModelId(TestCase):
12391210
@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest")
12401211
@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs")

tests/unit/sagemaker/jumpstart/utils.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
JUMPSTART_REGION_NAME_SET,
2323
)
2424
from sagemaker.jumpstart.types import (
25-
HubDataType,
25+
HubContentType,
2626
JumpStartCachedContentKey,
2727
JumpStartCachedContentValue,
2828
JumpStartModelSpecs,
@@ -253,6 +253,9 @@ def patched_retrieval_function(
253253
model_type=JumpStartModelType.PROPRIETARY,
254254
)
255255
)
256+
# TODO: Implement
257+
if datatype == HubContentType.HUB:
258+
return None
256259

257260
raise ValueError(f"Bad value for datatype: {datatype}")
258261

tests/unit/sagemaker/model_uris/jumpstart/test_common.py

+4
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def test_jumpstart_common_model_uri(
5454
s3_client=mock_client,
5555
hub_arn=None,
5656
model_type=JumpStartModelType.OPEN_WEIGHTS,
57+
hub_arn=None,
5758
)
5859
patched_verify_model_region_and_return_specs.assert_called_once()
5960

@@ -73,6 +74,7 @@ def test_jumpstart_common_model_uri(
7374
s3_client=mock_client,
7475
hub_arn=None,
7576
model_type=JumpStartModelType.OPEN_WEIGHTS,
77+
hub_arn=None,
7678
)
7779
patched_verify_model_region_and_return_specs.assert_called_once()
7880

@@ -93,6 +95,7 @@ def test_jumpstart_common_model_uri(
9395
s3_client=mock_client,
9496
hub_arn=None,
9597
model_type=JumpStartModelType.OPEN_WEIGHTS,
98+
hub_arn=None,
9699
)
97100
patched_verify_model_region_and_return_specs.assert_called_once()
98101

@@ -113,6 +116,7 @@ def test_jumpstart_common_model_uri(
113116
s3_client=mock_client,
114117
hub_arn=None,
115118
model_type=JumpStartModelType.OPEN_WEIGHTS,
119+
hub_arn=None,
116120
)
117121
patched_verify_model_region_and_return_specs.assert_called_once()
118122

tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py

+1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def test_jumpstart_resource_requirements(
5757
s3_client=mock_client,
5858
hub_arn=None,
5959
model_type=JumpStartModelType.OPEN_WEIGHTS,
60+
hub_arn=None,
6061
)
6162
patched_get_model_specs.reset_mock()
6263

tests/unit/sagemaker/script_uris/jumpstart/test_common.py

+4
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def test_jumpstart_common_script_uri(
5555
model_type=JumpStartModelType.OPEN_WEIGHTS,
5656
hub_arn=None,
5757
model_type=JumpStartModelType.OPEN_WEIGHTS,
58+
hub_arn=None,
5859
)
5960
patched_verify_model_region_and_return_specs.assert_called_once()
6061

@@ -74,6 +75,7 @@ def test_jumpstart_common_script_uri(
7475
s3_client=mock_client,
7576
hub_arn=None,
7677
model_type=JumpStartModelType.OPEN_WEIGHTS,
78+
hub_arn=None,
7779
)
7880
patched_verify_model_region_and_return_specs.assert_called_once()
7981

@@ -94,6 +96,7 @@ def test_jumpstart_common_script_uri(
9496
s3_client=mock_client,
9597
hub_arn=None,
9698
model_type=JumpStartModelType.OPEN_WEIGHTS,
99+
hub_arn=None,
97100
)
98101
patched_verify_model_region_and_return_specs.assert_called_once()
99102

@@ -114,6 +117,7 @@ def test_jumpstart_common_script_uri(
114117
s3_client=mock_client,
115118
hub_arn=None,
116119
model_type=JumpStartModelType.OPEN_WEIGHTS,
120+
hub_arn=None,
117121
)
118122
patched_verify_model_region_and_return_specs.assert_called_once()
119123

0 commit comments

Comments
 (0)