Skip to content

Commit 5cba331

Browse files
committed
change: add support for preexisting js tag, add tests, cleanup docstrings
1 parent 21d1853 commit 5cba331

File tree

4 files changed

+72
-18
lines changed

4 files changed

+72
-18
lines changed

src/sagemaker/jumpstart/constants.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ class VariableScope(str, Enum):
154154

155155

156156
class JumpStartTag(str, Enum):
157-
"""Enum class for tags to apply to JumpStart models."""
157+
"""Enum class for tag keys to apply to JumpStart models."""
158158

159-
INFERENCE_MODEL_URI = "jumpstart-inference-model-uri"
160-
INFERENCE_SCRIPT_URI = "jumpstart-inference-script-uri"
159+
INFERENCE_MODEL_URI = "aws-jumpstart-inference-model-uri"
160+
INFERENCE_SCRIPT_URI = "aws-jumpstart-inference-script-uri"

src/sagemaker/jumpstart/utils.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains utilities related to SageMaker JumpStart."""
1414
from __future__ import absolute_import
15+
from functools import reduce
1516
from typing import Dict, List, Optional
16-
from packaging.version import Version
1717
from urllib.parse import urlparse
18+
from packaging.version import Version
1819
import sagemaker
1920
from sagemaker.jumpstart import constants
2021
from sagemaker.jumpstart import accessors
@@ -154,13 +155,26 @@ def is_jumpstart_model_uri(uri: Optional[str]) -> bool:
154155
return bucket in constants.JUMPSTART_BUCKET_NAME_SET
155156

156157

158+
def tag_key_in_array(tag_key: str, tag_array: List[Dict[str, str]]) -> bool:
159+
"""Returns True if ``tag_key`` is in the ``tag_array``.
160+
161+
Args:
162+
tag_key (str): the tag key to check if it's already in the ``tag_array``.
163+
tag_array (List[Dict[str, str]]): array of tags to check for ``tag_key``.
164+
"""
165+
if len(tag_array) == 0:
166+
return False
167+
return tag_key in reduce(lambda a, b: set(a.keys()).union(set(b.keys())), tag_array)
168+
169+
157170
def add_jumpstart_tags(
158171
tags: Optional[List[Dict[str, str]]],
159172
inference_model_uri: Optional[str],
160173
inference_script_uri: Optional[str],
161174
) -> List[Dict[str, str]]:
162-
"""Adds tags for JumpStart models. Returns original tags for non-JumpStart
163-
models.
175+
"""Add custom tags to JumpStart models, return the updated tags.
176+
177+
No-op if this is not a JumpStart model related resource.
164178
165179
Args:
166180
tags (Optional[List[Dict[str,str]]): Current tags for JumpStart inference
@@ -172,19 +186,21 @@ def add_jumpstart_tags(
172186
if is_jumpstart_model_uri(inference_model_uri):
173187
if tags is None:
174188
tags = []
175-
tags.append(
176-
{
177-
constants.JumpStartTag.INFERENCE_MODEL_URI.value: inference_model_uri,
178-
}
179-
)
189+
if not tag_key_in_array(constants.JumpStartTag.INFERENCE_MODEL_URI.value, tags):
190+
tags.append(
191+
{
192+
constants.JumpStartTag.INFERENCE_MODEL_URI.value: inference_model_uri,
193+
}
194+
)
180195

181196
if is_jumpstart_model_uri(inference_script_uri):
182197
if tags is None:
183198
tags = []
184-
tags.append(
185-
{
186-
constants.JumpStartTag.INFERENCE_SCRIPT_URI.value: inference_script_uri,
187-
}
188-
)
199+
if not tag_key_in_array(constants.JumpStartTag.INFERENCE_SCRIPT_URI.value, tags):
200+
tags.append(
201+
{
202+
constants.JumpStartTag.INFERENCE_SCRIPT_URI.value: inference_script_uri,
203+
}
204+
)
189205

190206
return tags

tests/unit/sagemaker/jumpstart/test_utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,3 +286,42 @@ def test_add_jumpstart_tags():
286286
},
287287
{JumpStartTag.INFERENCE_SCRIPT_URI.value: inference_script_uri},
288288
]
289+
290+
tags = [{JumpStartTag.INFERENCE_MODEL_URI.value: "garbage-value"}]
291+
inference_script_uri = random_jumpstart_s3_uri("random_key")
292+
inference_model_uri = random_jumpstart_s3_uri("random_key")
293+
assert utils.add_jumpstart_tags(
294+
tags=tags,
295+
inference_model_uri=inference_model_uri,
296+
inference_script_uri=inference_script_uri,
297+
) == [
298+
{JumpStartTag.INFERENCE_MODEL_URI.value: "garbage-value"},
299+
{JumpStartTag.INFERENCE_SCRIPT_URI.value: inference_script_uri},
300+
]
301+
302+
tags = [{JumpStartTag.INFERENCE_SCRIPT_URI.value: "garbage-value"}]
303+
inference_script_uri = random_jumpstart_s3_uri("random_key")
304+
inference_model_uri = random_jumpstart_s3_uri("random_key")
305+
assert utils.add_jumpstart_tags(
306+
tags=tags,
307+
inference_model_uri=inference_model_uri,
308+
inference_script_uri=inference_script_uri,
309+
) == [
310+
{JumpStartTag.INFERENCE_SCRIPT_URI.value: "garbage-value"},
311+
{JumpStartTag.INFERENCE_MODEL_URI.value: inference_model_uri},
312+
]
313+
314+
tags = [
315+
{JumpStartTag.INFERENCE_SCRIPT_URI.value: "garbage-value"},
316+
{JumpStartTag.INFERENCE_MODEL_URI.value: "garbage-value-2"},
317+
]
318+
inference_script_uri = random_jumpstart_s3_uri("random_key")
319+
inference_model_uri = random_jumpstart_s3_uri("random_key")
320+
assert utils.add_jumpstart_tags(
321+
tags=tags,
322+
inference_model_uri=inference_model_uri,
323+
inference_script_uri=inference_script_uri,
324+
) == [
325+
{JumpStartTag.INFERENCE_SCRIPT_URI.value: "garbage-value"},
326+
{JumpStartTag.INFERENCE_MODEL_URI.value: "garbage-value-2"},
327+
]

tests/unit/sagemaker/model/test_model.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from sagemaker.model import FrameworkModel, Model
2121
from sagemaker.huggingface.model import HuggingFaceModel
2222
from sagemaker.jumpstart.constants import JUMPSTART_BUCKET_NAME_SET, JumpStartTag
23-
from sagemaker.model import FrameworkModel, Model
2423
from sagemaker.mxnet.model import MXNetModel
2524
from sagemaker.pytorch.model import PyTorchModel
2625
from sagemaker.sklearn.model import SKLearnModel
@@ -482,7 +481,7 @@ def test_script_mode_model_tags_jumpstart_models(repack_model, sagemaker_session
482481
},
483482
]
484483

485-
non_jumpstart_source_dir = f"s3://blah/blah/blah"
484+
non_jumpstart_source_dir = "s3://blah/blah/blah"
486485
t = Model(
487486
entry_point=ENTRY_POINT_INFERENCE,
488487
role=ROLE,

0 commit comments

Comments
 (0)