Skip to content

Commit 0ee6065

Browse files
committed
feat: tagging jumpstart models
1 parent 574b9a2 commit 0ee6065

File tree

5 files changed

+456
-6
lines changed

5 files changed

+456
-6
lines changed

src/sagemaker/jumpstart/constants.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,115 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module stores constants related to SageMaker JumpStart."""
1414
from __future__ import absolute_import
15+
from enum import Enum
1516
from typing import Set
1617
import boto3
1718
from sagemaker.jumpstart.types import JumpStartLaunchedRegionInfo
1819

1920

20-
JUMPSTART_LAUNCHED_REGIONS: Set[JumpStartLaunchedRegionInfo] = set()
21+
# Names of the regions that get a ``JumpStartLaunchedRegionInfo`` entry below.
_JUMPSTART_LAUNCHED_REGION_NAMES = (
    "us-west-2",
    "us-east-1",
    "us-east-2",
    "eu-west-1",
    "eu-central-1",
    "eu-north-1",
    "me-south-1",
    "ap-south-1",
    "eu-west-3",
    "af-south-1",
    "sa-east-1",
    "ap-east-1",
    "ap-northeast-2",
    "eu-west-2",
    "eu-south-1",
    "ap-northeast-1",
    "us-west-1",
    "ap-southeast-1",
    "ap-southeast-2",
    "ca-central-1",
    "cn-north-1",
)

# Each region's content bucket is named ``jumpstart-cache-prod-<region name>``.
JUMPSTART_LAUNCHED_REGIONS: Set[JumpStartLaunchedRegionInfo] = {
    JumpStartLaunchedRegionInfo(
        region_name=launched_region_name,
        content_bucket=f"jumpstart-cache-prod-{launched_region_name}",
    )
    for launched_region_name in _JUMPSTART_LAUNCHED_REGION_NAMES
}
21109

22110
# Launched-region records keyed by their region name.
JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT = {
    launched_region.region_name: launched_region
    for launched_region in JUMPSTART_LAUNCHED_REGIONS
}

# Every region name that appears in JUMPSTART_LAUNCHED_REGIONS.
JUMPSTART_REGION_NAME_SET = {
    launched_region.region_name for launched_region in JUMPSTART_LAUNCHED_REGIONS
}

# Every content bucket that appears in JUMPSTART_LAUNCHED_REGIONS.
JUMPSTART_BUCKET_NAME_SET = {
    launched_region.content_bucket for launched_region in JUMPSTART_LAUNCHED_REGIONS
}

# Region name reported by a default boto3 session; evaluated when this module
# is imported.
JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name

JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
120+
121+
122+
class JumpStartTag(str, Enum):
    """Tag keys attached to resources created from JumpStart-hosted artifacts.

    Members are also ``str`` instances, so they compare equal to their values.
    """

    # Key of the tag whose value is the S3 URI of the inference model artifact.
    INFERENCE_MODEL_URI = "jumpstart-inference-model-uri"

    # Key of the tag whose value is the S3 URI of the inference script tarball.
    INFERENCE_SCRIPT_URI = "jumpstart-inference-script-uri"

src/sagemaker/jumpstart/utils.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains utilities related to SageMaker JumpStart."""
1414
from __future__ import absolute_import
15-
from typing import Dict, List
15+
from typing import Dict, List, Optional
16+
from urllib.parse import urlparse
1617
import semantic_version
1718
import sagemaker
1819
from sagemaker.jumpstart import constants
1920
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId
21+
from sagemaker.s3 import parse_s3_url
2022

2123

2224
class SageMakerSettings(object):
@@ -128,3 +130,53 @@ def parse_sagemaker_version() -> str:
128130
semantic_version.Version(parsed_version)
129131

130132
return parsed_version
133+
134+
135+
def is_jumpstart_model_uri(uri: Optional[str]) -> bool:
    """Return True if ``uri`` is an S3 URI whose bucket is a JumpStart bucket.

    Args:
        uri (Optional[str]): URI for an inference/training artifact. ``None``
            and non-S3 values (for example a local path) yield False.

    Returns:
        bool: True only when ``uri`` uses the ``s3`` scheme and its bucket is a
            member of ``constants.JUMPSTART_BUCKET_NAME_SET``.
    """
    # Handle a missing URI explicitly rather than depending on how ``urlparse``
    # happens to coerce ``None``.
    if uri is None:
        return False

    # Only S3 URIs can point at a JumpStart bucket; anything else is rejected
    # before attempting to split it into bucket and key.
    if urlparse(uri).scheme != "s3":
        return False

    bucket, _ = parse_s3_url(uri)
    return bucket in constants.JUMPSTART_BUCKET_NAME_SET
147+
148+
149+
def add_jumpstart_tags(
    tags: Optional[List[Dict[str, str]]],
    inference_model_uri: Optional[str],
    inference_script_uri: Optional[str],
) -> Optional[List[Dict[str, str]]]:
    """Add tags for JumpStart models; return the original tags otherwise.

    A tag is added for each URI that ``is_jumpstart_model_uri`` recognizes. The
    model tag, when present, comes before the script tag, and both follow any
    tags that were passed in.

    Args:
        tags (Optional[List[Dict[str, str]]]): Current tags for the JumpStart
            inference or training job. This list is never modified.
        inference_model_uri (Optional[str]): S3 URI for inference model artifact.
        inference_script_uri (Optional[str]): S3 URI for inference script tarball.

    Returns:
        Optional[List[Dict[str, str]]]: ``tags`` itself (possibly ``None``) when
            neither URI is JumpStart-hosted; otherwise a new list holding the
            existing tags followed by the JumpStart tags.
    """
    uri_by_tag = (
        (constants.JumpStartTag.INFERENCE_MODEL_URI, inference_model_uri),
        (constants.JumpStartTag.INFERENCE_SCRIPT_URI, inference_script_uri),
    )

    # NOTE(review): each tag is emitted as a single-entry ``{tag name: uri}``
    # dict. The SageMaker API appears to expect ``{"Key": ..., "Value": ...}``
    # entries -- confirm against the consumer of these tags before relying on it.
    jumpstart_tags = [
        {tag.value: uri} for tag, uri in uri_by_tag if is_jumpstart_model_uri(uri)
    ]

    if not jumpstart_tags:
        return tags

    # Build a new list so the caller's ``tags`` object is left untouched; reusing
    # it would accumulate duplicate tags across repeated calls.
    return list(tags or []) + jumpstart_tags

src/sagemaker/model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from sagemaker.deprecations import removed_kwargs
3535
from sagemaker.predictor import PredictorBase
3636
from sagemaker.transformer import Transformer
37+
from sagemaker.jumpstart.utils import add_jumpstart_tags
3738

3839
LOGGER = logging.getLogger("sagemaker")
3940

@@ -982,6 +983,10 @@ def deploy(
982983
removed_kwargs("update_endpoint", kwargs)
983984
self._init_sagemaker_session_if_does_not_exist(instance_type)
984985

986+
tags = add_jumpstart_tags(
987+
tags=tags, inference_model_uri=self.model_data, inference_script_uri=self.source_dir
988+
)
989+
985990
if self.role is None:
986991
raise ValueError("Role can not be null for deploying a model")
987992

tests/unit/sagemaker/jumpstart/test_utils.py

Lines changed: 175 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,20 @@
1313
from __future__ import absolute_import
1414
from mock.mock import Mock, patch
1515
import pytest
16+
import random
1617
from sagemaker.jumpstart import utils
17-
from sagemaker.jumpstart.constants import JUMPSTART_REGION_NAME_SET
18+
from sagemaker.jumpstart.constants import (
19+
JUMPSTART_BUCKET_NAME_SET,
20+
JUMPSTART_REGION_NAME_SET,
21+
JumpStartTag,
22+
)
1823
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId
1924

2025

26+
def random_jumpstart_s3_uri(key):
    """Build an S3 URI for ``key`` in a randomly chosen JumpStart content bucket."""
    bucket = random.choice(list(JUMPSTART_BUCKET_NAME_SET))
    return f"s3://{bucket}/{key}"
28+
29+
2130
def test_get_jumpstart_content_bucket():
2231
bad_region = "bad_region"
2332
assert bad_region not in JUMPSTART_REGION_NAME_SET
@@ -112,3 +121,168 @@ def test_get_sagemaker_version(patched_parse_sm_version: Mock):
112121
utils.get_sagemaker_version()
113122
utils.get_sagemaker_version()
114123
assert patched_parse_sm_version.called_only_once()
124+
125+
126+
def test_is_jumpstart_model_uri():
    """Only S3 URIs that point into a JumpStart bucket are recognized."""
    # A bare string, an S3 URI in some other bucket, and a local path.
    for non_jumpstart_uri in (
        "fdsfdsf",
        "s3://not-jumpstart-bucket/sdfsdfds",
        "some/actual/localfile",
    ):
        assert not utils.is_jumpstart_model_uri(non_jumpstart_uri)

    # Any key inside a JumpStart bucket counts, whatever it looks like.
    for key in ("source_directory_tarballs/sourcedir.tar.gz", "random_key"):
        assert utils.is_jumpstart_model_uri(random_jumpstart_s3_uri(key))
136+
137+
138+
def test_add_jumpstart_tags():
    """Check ``utils.add_jumpstart_tags`` for each mix of starting tags and URIs.

    Every scenario is run with ``None``, an empty list and a one-tag list as the
    starting tags.
    """
    model_tag = JumpStartTag.INFERENCE_MODEL_URI.value
    script_tag = JumpStartTag.INFERENCE_SCRIPT_URI.value

    def tagged(initial_tags, model_uri, script_uri):
        return utils.add_jumpstart_tags(
            tags=initial_tags,
            inference_model_uri=model_uri,
            inference_script_uri=script_uri,
        )

    def starting_tags():
        # Fresh objects on every call: (tags passed in, tags expected to be kept).
        return (
            (None, []),
            ([], []),
            ([{"some": "tag"}], [{"some": "tag"}]),
        )

    # Neither URI is JumpStart-hosted: the tags come back as they went in.
    assert tagged(None, "dfsdfsd", "dfsdfs") is None
    assert tagged([], "dfsdfsd", "dfsdfs") == []
    assert tagged([{"some": "tag"}], "dfsdfsd", "dfsdfs") == [{"some": "tag"}]

    # Only the model artifact is JumpStart-hosted.
    for initial_tags, kept in starting_tags():
        model_uri = random_jumpstart_s3_uri("random_key")
        assert tagged(initial_tags, model_uri, "dfsdfs") == kept + [{model_tag: model_uri}]

    # Only the inference script is JumpStart-hosted.
    for initial_tags, kept in starting_tags():
        script_uri = random_jumpstart_s3_uri("random_key")
        assert tagged(initial_tags, "dfsdfs", script_uri) == kept + [{script_tag: script_uri}]

    # Both are JumpStart-hosted: the model tag comes before the script tag.
    for initial_tags, kept in starting_tags():
        script_uri = random_jumpstart_s3_uri("random_key")
        model_uri = random_jumpstart_s3_uri("random_key")
        assert tagged(initial_tags, model_uri, script_uri) == kept + [
            {model_tag: model_uri},
            {script_tag: script_uri},
        ]

0 commit comments

Comments
 (0)