Skip to content

Commit 15f730f

Browse files
committed
feat: tagging jumpstart models
1 parent b7c9586 commit 15f730f

File tree

5 files changed

+339
-1
lines changed

5 files changed

+339
-1
lines changed

src/sagemaker/jumpstart/constants.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@
112112
}
113113
JUMPSTART_REGION_NAME_SET = {region.region_name for region in JUMPSTART_LAUNCHED_REGIONS}
114114

115+
JUMPSTART_BUCKET_NAME_SET = {region.content_bucket for region in JUMPSTART_LAUNCHED_REGIONS}
116+
115117
JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2"
116118

117119
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
@@ -149,3 +151,10 @@ class VariableScope(str, Enum):
149151

150152
CONTAINER = "container"
151153
ALGORITHM = "algorithm"
154+
155+
156+
class JumpStartTag(str, Enum):
    """Tag names that can be attached to JumpStart models.

    Members are ``str`` subclasses, so each member compares equal to its
    string value.
    """

    # Tag name under which the inference model artifact URI is recorded.
    INFERENCE_MODEL_URI = "jumpstart-inference-model-uri"
    # Tag name under which the inference script URI is recorded.
    INFERENCE_SCRIPT_URI = "jumpstart-inference-script-uri"

src/sagemaker/jumpstart/utils.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414
from __future__ import absolute_import
1515
from typing import Dict, List, Optional
1616
from packaging.version import Version
17+
from urllib.parse import urlparse
1718
import sagemaker
1819
from sagemaker.jumpstart import constants
1920
from sagemaker.jumpstart import accessors
2021
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId
22+
from sagemaker.s3 import parse_s3_url
2123

2224

2325
def get_jumpstart_launched_regions_message() -> str:
@@ -136,3 +138,53 @@ def is_jumpstart_model_input(model_id: Optional[str], version: Optional[str]) ->
136138
)
137139
return True
138140
return False
141+
142+
143+
def is_jumpstart_model_uri(uri: Optional[str]) -> bool:
    """Returns True if URI corresponds to a JumpStart-hosted model.

    A URI qualifies when it uses the ``s3`` scheme and the bucket extracted
    from it is a member of ``constants.JUMPSTART_BUCKET_NAME_SET``.

    Args:
        uri (Optional[str]): uri for inference/training job.

    Returns:
        bool: True if ``uri`` is an S3 URI whose bucket is a JumpStart
        bucket, False otherwise (including when ``uri`` is None).
    """
    # Guard explicitly instead of handing None (or another non-string) to
    # ``urlparse``; a missing URI can never be a JumpStart URI.
    if not isinstance(uri, str):
        return False

    # Local paths and non-S3 schemes are never JumpStart-hosted.
    if urlparse(uri).scheme != "s3":
        return False

    bucket, _ = parse_s3_url(uri)
    return bucket in constants.JUMPSTART_BUCKET_NAME_SET
155+
156+
157+
def add_jumpstart_tags(
    tags: Optional[List[Dict[str, str]]],
    inference_model_uri: Optional[str],
    inference_script_uri: Optional[str],
) -> Optional[List[Dict[str, str]]]:
    """Adds tags for JumpStart models. Returns original tags for non-JumpStart models.

    The caller's ``tags`` list is never modified: when JumpStart tags apply, a
    new list holding the existing tags followed by the JumpStart tags is
    returned.

    Args:
        tags (Optional[List[Dict[str,str]]]): Current tags for JumpStart inference
            or training job.
        inference_model_uri (Optional[str]): S3 URI for inference model artifact.
        inference_script_uri (Optional[str]): S3 URI for inference script tarball.

    Returns:
        Optional[List[Dict[str, str]]]: ``tags`` itself (possibly None) when
        neither URI is a JumpStart URI; otherwise a new list with the model-URI
        tag and/or script-URI tag appended, in that order.
    """
    # NOTE(review): each tag is emitted as ``{tag_name: uri}``. Confirm that
    # downstream consumers accept this shape rather than ``{"Key": ..., "Value": ...}``.
    candidates = (
        (constants.JumpStartTag.INFERENCE_MODEL_URI, inference_model_uri),
        (constants.JumpStartTag.INFERENCE_SCRIPT_URI, inference_script_uri),
    )
    jumpstart_tags = [
        {tag.value: uri} for tag, uri in candidates if is_jumpstart_model_uri(uri)
    ]

    if not jumpstart_tags:
        # Nothing to add: hand back exactly what the caller gave us.
        return tags

    # Build a fresh list so the caller's list is not mutated as a side effect.
    return list(tags or []) + jumpstart_tags

src/sagemaker/model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from sagemaker.deprecations import removed_kwargs
3535
from sagemaker.predictor import PredictorBase
3636
from sagemaker.transformer import Transformer
37+
from sagemaker.jumpstart.utils import add_jumpstart_tags
3738

3839
LOGGER = logging.getLogger("sagemaker")
3940

@@ -987,6 +988,10 @@ def deploy(
987988
removed_kwargs("update_endpoint", kwargs)
988989
self._init_sagemaker_session_if_does_not_exist(instance_type)
989990

991+
tags = add_jumpstart_tags(
992+
tags=tags, inference_model_uri=self.model_data, inference_script_uri=self.source_dir
993+
)
994+
990995
if self.role is None:
991996
raise ValueError("Role can not be null for deploying a model")
992997

tests/unit/sagemaker/jumpstart/test_utils.py

Lines changed: 175 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,20 @@
1313
from __future__ import absolute_import
1414
from mock.mock import Mock, patch
1515
import pytest
16+
import random
1617
from sagemaker.jumpstart import utils
17-
from sagemaker.jumpstart.constants import JUMPSTART_REGION_NAME_SET
18+
from sagemaker.jumpstart.constants import (
19+
JUMPSTART_BUCKET_NAME_SET,
20+
JUMPSTART_REGION_NAME_SET,
21+
JumpStartTag,
22+
)
1823
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId
1924

2025

26+
def random_jumpstart_s3_uri(key):
    """Build an S3 URI for ``key`` in a randomly chosen JumpStart bucket."""
    candidate_buckets = list(JUMPSTART_BUCKET_NAME_SET)
    chosen_bucket = random.choice(candidate_buckets)
    return "s3://{}/{}".format(chosen_bucket, key)
28+
29+
2130
def test_get_jumpstart_content_bucket():
2231
bad_region = "bad_region"
2332
assert bad_region not in JUMPSTART_REGION_NAME_SET
@@ -112,3 +121,168 @@ def test_get_sagemaker_version(patched_parse_sm_version: Mock):
112121
utils.get_sagemaker_version()
113122
utils.get_sagemaker_version()
114123
assert patched_parse_sm_version.called_only_once()
124+
125+
126+
def test_is_jumpstart_model_uri():
    """Only S3 URIs that point into a JumpStart bucket are recognised."""
    # Bare strings, local paths and S3 URIs in foreign buckets must be rejected.
    rejected_uris = (
        "fdsfdsf",
        "s3://not-jumpstart-bucket/sdfsdfds",
        "some/actual/localfile",
    )
    for rejected in rejected_uris:
        assert not utils.is_jumpstart_model_uri(rejected)

    # Any key inside a JumpStart bucket must be accepted.
    accepted_keys = (
        "source_directory_tarballs/sourcedir.tar.gz",
        "random_key",
    )
    for accepted_key in accepted_keys:
        assert utils.is_jumpstart_model_uri(random_jumpstart_s3_uri(accepted_key))
136+
137+
138+
def test_add_jumpstart_tags():
    """Exercise add_jumpstart_tags over every URI kind / initial tags combination.

    Each scenario pairs a JumpStart or non-JumpStart model URI with a
    JumpStart or non-JumpStart script URI, and is run against three starting
    tag values: ``None``, an empty list, and a list with one existing tag.
    """
    # (model URI is JumpStart, script URI is JumpStart, non-JumpStart model URI)
    scenarios = (
        (False, False, "dfsdfsd"),
        (True, False, "dfsdfsd"),
        (False, True, "dfsdfs"),
        (True, True, "dfsdfs"),
    )
    # Factories rather than values, so every scenario gets a fresh tags object.
    initial_tag_factories = (
        lambda: None,
        lambda: [],
        lambda: [{"some": "tag"}],
    )

    for model_is_jumpstart, script_is_jumpstart, plain_model_uri in scenarios:
        for make_tags in initial_tag_factories:
            tags = make_tags()

            inference_script_uri = (
                random_jumpstart_s3_uri("random_key") if script_is_jumpstart else "dfsdfs"
            )
            inference_model_uri = (
                random_jumpstart_s3_uri("random_key") if model_is_jumpstart else plain_model_uri
            )

            # Expected result: existing tags first, then model tag, then script tag.
            expected = make_tags()
            if model_is_jumpstart:
                expected = (expected or []) + [
                    {JumpStartTag.INFERENCE_MODEL_URI.value: inference_model_uri}
                ]
            if script_is_jumpstart:
                expected = (expected or []) + [
                    {JumpStartTag.INFERENCE_SCRIPT_URI.value: inference_script_uri}
                ]

            actual = utils.add_jumpstart_tags(
                tags=tags,
                inference_model_uri=inference_model_uri,
                inference_script_uri=inference_script_uri,
            )

            if expected is None:
                # No starting tags and nothing to add: None must pass through.
                assert actual is None
            else:
                assert actual == expected

0 commit comments

Comments
 (0)