Skip to content

Commit 5bc742f

Browse files
committed
feat: jsch jumpstart estimator support (aws#4439)
1 parent 2973f23 commit 5bc742f

File tree

18 files changed

+37
-34
lines changed

18 files changed

+37
-34
lines changed

src/sagemaker/jumpstart/accessors.py

+1
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ def get_model_specs(
257257
hub_arn: Optional[str] = None,
258258
s3_client: Optional[boto3.client] = None,
259259
model_type=JumpStartModelType.OPEN_WEIGHTS,
260+
hub_arn: Optional[str] = None,
260261
) -> JumpStartModelSpecs:
261262
"""Returns model specs from JumpStart models cache.
262263

src/sagemaker/jumpstart/cache.py

+1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
get_wildcard_model_version_msg,
4141
get_wildcard_proprietary_model_version_msg,
4242
)
43+
from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg
4344
from sagemaker.jumpstart.parameters import (
4445
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
4546
JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS,

src/sagemaker/jumpstart/constants.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,7 @@
172172
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
173173
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json"
174174

175-
# works cross-partition
176-
HUB_MODEL_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub_content/(.*?)/Model/(.*?)/(.*?)$"
175+
HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$"
177176
HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$"
178177

179178
INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py"

src/sagemaker/jumpstart/estimator.py

+1
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,7 @@ def _validate_model_id_and_get_type_hook():
534534
model_version=model_version,
535535
hub_arn=hub_arn,
536536
model_type=self.model_type,
537+
hub_arn=hub_arn,
537538
tolerate_vulnerable_model=tolerate_vulnerable_model,
538539
tolerate_deprecated_model=tolerate_deprecated_model,
539540
role=role,

src/sagemaker/jumpstart/factory/estimator.py

+2
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def get_init_kwargs(
8181
model_version: Optional[str] = None,
8282
hub_arn: Optional[str] = None,
8383
model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
84+
hub_arn: Optional[str] = None,
8485
tolerate_vulnerable_model: Optional[bool] = None,
8586
tolerate_deprecated_model: Optional[bool] = None,
8687
region: Optional[str] = None,
@@ -140,6 +141,7 @@ def get_init_kwargs(
140141
model_version=model_version,
141142
hub_arn=hub_arn,
142143
model_type=model_type,
144+
hub_arn=hub_arn,
143145
role=role,
144146
region=region,
145147
instance_count=instance_count,

src/sagemaker/jumpstart/factory/model.py

+2
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,7 @@ def get_deploy_kwargs(
550550
model_version: Optional[str] = None,
551551
hub_arn: Optional[str] = None,
552552
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
553+
hub_arn: Optional[str] = None,
553554
region: Optional[str] = None,
554555
initial_instance_count: Optional[int] = None,
555556
instance_type: Optional[str] = None,
@@ -584,6 +585,7 @@ def get_deploy_kwargs(
584585
model_version=model_version,
585586
hub_arn=hub_arn,
586587
model_type=model_type,
588+
hub_arn=hub_arn,
587589
region=region,
588590
initial_instance_count=initial_instance_count,
589591
instance_type=instance_type,

src/sagemaker/jumpstart/types.py

+8
Original file line numberDiff line numberDiff line change
@@ -1420,6 +1420,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
14201420
"model_version",
14211421
"hub_arn",
14221422
"model_type",
1423+
"hub_arn",
14231424
"initial_instance_count",
14241425
"instance_type",
14251426
"region",
@@ -1454,6 +1455,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
14541455
"model_type",
14551456
"hub_arn",
14561457
"model_type",
1458+
"hub_arn",
14571459
"region",
14581460
"tolerate_deprecated_model",
14591461
"tolerate_vulnerable_model",
@@ -1500,6 +1502,7 @@ def __init__(
15001502
self.model_version = model_version
15011503
self.hub_arn = hub_arn
15021504
self.model_type = model_type
1505+
self.hub_arn = hub_arn
15031506
self.initial_instance_count = initial_instance_count
15041507
self.instance_type = instance_type
15051508
self.region = region
@@ -1536,6 +1539,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
15361539
"model_version",
15371540
"hub_arn",
15381541
"model_type",
1542+
"hub_arn",
15391543
"instance_type",
15401544
"instance_count",
15411545
"region",
@@ -1597,6 +1601,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
15971601
"model_version",
15981602
"hub_arn",
15991603
"model_type",
1604+
"hub_arn",
16001605
}
16011606

16021607
def __init__(
@@ -1726,6 +1731,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs):
17261731
"model_version",
17271732
"hub_arn",
17281733
"model_type",
1734+
"hub_arn",
17291735
"region",
17301736
"inputs",
17311737
"wait",
@@ -1742,6 +1748,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs):
17421748
"model_version",
17431749
"hub_arn",
17441750
"model_type",
1751+
"hub_arn",
17451752
"region",
17461753
"tolerate_deprecated_model",
17471754
"tolerate_vulnerable_model",
@@ -1770,6 +1777,7 @@ def __init__(
17701777
self.model_version = model_version
17711778
self.hub_arn = hub_arn
17721779
self.model_type = model_type
1780+
self.hub_arn = hub_arn
17731781
self.region = region
17741782
self.inputs = inputs
17751783
self.wait = wait

src/sagemaker/jumpstart/utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import re
1818
from typing import Any, Dict, List, Set, Optional, Tuple, Union
1919
import re
20+
from typing import Any, Dict, List, Set, Optional, Tuple, Union
2021
from urllib.parse import urlparse
2122
import boto3
2223
from packaging.version import Version

tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py

+2
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,7 @@ def add_options_to_hyperparameter(*largs, **kwargs):
453453
s3_client=mock_client,
454454
hub_arn=None,
455455
model_type=JumpStartModelType.OPEN_WEIGHTS,
456+
hub_arn=None,
456457
)
457458

458459
patched_get_model_specs.reset_mock()
@@ -516,6 +517,7 @@ def test_jumpstart_validate_all_hyperparameters(
516517
s3_client=mock_client,
517518
hub_arn=None,
518519
model_type=JumpStartModelType.OPEN_WEIGHTS,
520+
hub_arn=None,
519521
)
520522

521523
patched_get_model_specs.reset_mock()

tests/unit/sagemaker/image_uris/jumpstart/test_common.py

+4
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def test_jumpstart_common_image_uri(
5656
s3_client=mock_client,
5757
hub_arn=None,
5858
model_type=JumpStartModelType.OPEN_WEIGHTS,
59+
hub_arn=None,
5960
)
6061
patched_verify_model_region_and_return_specs.assert_called_once()
6162

@@ -78,6 +79,7 @@ def test_jumpstart_common_image_uri(
7879
s3_client=mock_client,
7980
hub_arn=None,
8081
model_type=JumpStartModelType.OPEN_WEIGHTS,
82+
hub_arn=None,
8183
)
8284
patched_verify_model_region_and_return_specs.assert_called_once()
8385

@@ -100,6 +102,7 @@ def test_jumpstart_common_image_uri(
100102
s3_client=mock_client,
101103
hub_arn=None,
102104
model_type=JumpStartModelType.OPEN_WEIGHTS,
105+
hub_arn=None,
103106
)
104107
patched_verify_model_region_and_return_specs.assert_called_once()
105108

@@ -122,6 +125,7 @@ def test_jumpstart_common_image_uri(
122125
s3_client=mock_client,
123126
hub_arn=None,
124127
model_type=JumpStartModelType.OPEN_WEIGHTS,
128+
hub_arn=None,
125129
)
126130
patched_verify_model_region_and_return_specs.assert_called_once()
127131

tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py

-1
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,6 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode
126126
model_type=JumpStartModelType.OPEN_WEIGHTS,
127127
hub_arn=None,
128128
s3_client=mock_client,
129-
model_type=JumpStartModelType.OPEN_WEIGHTS,
130129
)
131130

132131
patched_get_model_specs.reset_mock()

tests/unit/sagemaker/jumpstart/test_accessors.py

-1
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,6 @@ def test_jumpstart_proprietary_models_cache_get(mock_cache):
137137
> 0
138138
)
139139

140-
141140
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
142141
def test_jumpstart_models_cache_get_model_specs(mock_cache):
143142
mock_cache.get_specs = Mock()

tests/unit/sagemaker/jumpstart/test_notebook_utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -751,4 +751,5 @@ def test_get_model_url(
751751
s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client,
752752
hub_arn=None,
753753
model_type=JumpStartModelType.OPEN_WEIGHTS,
754+
hub_arn=None,
754755
)

tests/unit/sagemaker/jumpstart/test_utils.py

-29
Original file line numberDiff line numberDiff line change
@@ -1214,35 +1214,6 @@ def test_mime_type_enum_from_str():
12141214
assert MIMEType.from_suffixed_type(mime_type_with_suffix) == mime_type
12151215

12161216

1217-
def test_extract_info_from_hub_content_arn():
1218-
model_arn = (
1219-
"arn:aws:sagemaker:us-west-2:000000000000:hub_content/MockHub/Model/my-mock-model/1.0.2"
1220-
)
1221-
assert utils.extract_info_from_hub_content_arn(model_arn) == (
1222-
"MockHub",
1223-
"us-west-2",
1224-
"my-mock-model",
1225-
"1.0.2",
1226-
)
1227-
1228-
hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/MockHub"
1229-
assert utils.extract_info_from_hub_content_arn(hub_arn) == ("MockHub", "us-west-2", None, None)
1230-
1231-
invalid_arn = "arn:aws:sagemaker:us-west-2:000000000000:endpoint/my-endpoint-123"
1232-
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)
1233-
1234-
invalid_arn = "nonsense-string"
1235-
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)
1236-
1237-
invalid_arn = ""
1238-
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)
1239-
1240-
invalid_arn = (
1241-
"arn:aws:sagemaker:us-west-2:000000000000:hub-content/MyHub/Notebook/my-notebook/1.0.0"
1242-
)
1243-
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)
1244-
1245-
12461217
class TestIsValidModelId(TestCase):
12471218
@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest")
12481219
@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs")

tests/unit/sagemaker/jumpstart/utils.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
JUMPSTART_REGION_NAME_SET,
2323
)
2424
from sagemaker.jumpstart.types import (
25-
HubDataType,
25+
HubContentType,
2626
JumpStartCachedContentKey,
2727
JumpStartCachedContentValue,
2828
JumpStartModelSpecs,
@@ -253,6 +253,9 @@ def patched_retrieval_function(
253253
model_type=JumpStartModelType.PROPRIETARY,
254254
)
255255
)
256+
# TODO: Implement
257+
if datatype == HubContentType.HUB:
258+
return None
256259

257260
if datatype == HubContentType.MODEL:
258261
_, _, _, model_name, model_version = id_info.split("/")

tests/unit/sagemaker/model_uris/jumpstart/test_common.py

+4
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def test_jumpstart_common_model_uri(
5454
s3_client=mock_client,
5555
hub_arn=None,
5656
model_type=JumpStartModelType.OPEN_WEIGHTS,
57+
hub_arn=None,
5758
)
5859
patched_verify_model_region_and_return_specs.assert_called_once()
5960

@@ -73,6 +74,7 @@ def test_jumpstart_common_model_uri(
7374
s3_client=mock_client,
7475
hub_arn=None,
7576
model_type=JumpStartModelType.OPEN_WEIGHTS,
77+
hub_arn=None,
7678
)
7779
patched_verify_model_region_and_return_specs.assert_called_once()
7880

@@ -93,6 +95,7 @@ def test_jumpstart_common_model_uri(
9395
s3_client=mock_client,
9496
hub_arn=None,
9597
model_type=JumpStartModelType.OPEN_WEIGHTS,
98+
hub_arn=None,
9699
)
97100
patched_verify_model_region_and_return_specs.assert_called_once()
98101

@@ -113,6 +116,7 @@ def test_jumpstart_common_model_uri(
113116
s3_client=mock_client,
114117
hub_arn=None,
115118
model_type=JumpStartModelType.OPEN_WEIGHTS,
119+
hub_arn=None,
116120
)
117121
patched_verify_model_region_and_return_specs.assert_called_once()
118122

tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py

+1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def test_jumpstart_resource_requirements(
5757
s3_client=mock_client,
5858
hub_arn=None,
5959
model_type=JumpStartModelType.OPEN_WEIGHTS,
60+
hub_arn=None,
6061
)
6162
patched_get_model_specs.reset_mock()
6263

tests/unit/sagemaker/script_uris/jumpstart/test_common.py

+4
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def test_jumpstart_common_script_uri(
5555
model_type=JumpStartModelType.OPEN_WEIGHTS,
5656
hub_arn=None,
5757
model_type=JumpStartModelType.OPEN_WEIGHTS,
58+
hub_arn=None,
5859
)
5960
patched_verify_model_region_and_return_specs.assert_called_once()
6061

@@ -74,6 +75,7 @@ def test_jumpstart_common_script_uri(
7475
s3_client=mock_client,
7576
hub_arn=None,
7677
model_type=JumpStartModelType.OPEN_WEIGHTS,
78+
hub_arn=None,
7779
)
7880
patched_verify_model_region_and_return_specs.assert_called_once()
7981

@@ -94,6 +96,7 @@ def test_jumpstart_common_script_uri(
9496
s3_client=mock_client,
9597
hub_arn=None,
9698
model_type=JumpStartModelType.OPEN_WEIGHTS,
99+
hub_arn=None,
97100
)
98101
patched_verify_model_region_and_return_specs.assert_called_once()
99102

@@ -114,6 +117,7 @@ def test_jumpstart_common_script_uri(
114117
s3_client=mock_client,
115118
hub_arn=None,
116119
model_type=JumpStartModelType.OPEN_WEIGHTS,
120+
hub_arn=None,
117121
)
118122
patched_verify_model_region_and_return_specs.assert_called_once()
119123

0 commit comments

Comments
 (0)