Skip to content

Commit febc83f

Browse files
Captainia authored and benieric committed
feat: tag JumpStart resource with config names (aws#4608)
* tag config name * format * resolving comments * format * format * update * fix * format * updates inference component config name * fix: tests
1 parent 0c526a6 commit febc83f

File tree

13 files changed

+313
-162
lines changed

13 files changed

+313
-162
lines changed

src/sagemaker/jumpstart/enums.py

+1
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ class JumpStartTag(str, Enum):
9292
MODEL_ID = "sagemaker-sdk:jumpstart-model-id"
9393
MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version"
9494
MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type"
95+
MODEL_CONFIG_NAME = "sagemaker-sdk:jumpstart-model-config-name"
9596

9697

9798
class SerializerType(str, Enum):

src/sagemaker/jumpstart/estimator.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
from sagemaker.jumpstart.factory.estimator import get_deploy_kwargs, get_fit_kwargs, get_init_kwargs
3535
from sagemaker.jumpstart.factory.model import get_default_predictor
36-
from sagemaker.jumpstart.session_utils import get_model_id_version_from_training_job
36+
from sagemaker.jumpstart.session_utils import get_model_info_from_training_job
3737
from sagemaker.jumpstart.types import JumpStartMetadataConfig
3838
from sagemaker.jumpstart.utils import (
3939
get_jumpstart_configs,
@@ -734,10 +734,10 @@ def attach(
734734
ValueError: if the model ID or version cannot be inferred from the training job.
735735
736736
"""
737-
737+
config_name = None
738738
if model_id is None:
739739

740-
model_id, model_version = get_model_id_version_from_training_job(
740+
model_id, model_version, config_name = get_model_info_from_training_job(
741741
training_job_name=training_job_name, sagemaker_session=sagemaker_session
742742
)
743743

@@ -758,6 +758,7 @@ def attach(
758758
tolerate_deprecated_model=True, # model is already trained, so tolerate if deprecated
759759
tolerate_vulnerable_model=True, # model is already trained, so tolerate if vulnerable
760760
sagemaker_session=sagemaker_session,
761+
config_name=config_name,
761762
)
762763

763764
# eula was already accepted if the model was successfully trained
@@ -1111,7 +1112,7 @@ def deploy(
11111112
tolerate_deprecated_model=self.tolerate_deprecated_model,
11121113
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
11131114
sagemaker_session=self.sagemaker_session,
1114-
# config_name=self.config_name,
1115+
config_name=self.config_name,
11151116
)
11161117

11171118
# If a predictor class was passed, do not mutate predictor

src/sagemaker/jumpstart/factory/estimator.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,10 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima
480480

481481
if kwargs.sagemaker_session.settings.include_jumpstart_tags:
482482
kwargs.tags = add_jumpstart_model_id_version_tags(
483-
kwargs.tags, kwargs.model_id, full_model_version
483+
kwargs.tags,
484+
kwargs.model_id,
485+
full_model_version,
486+
config_name=kwargs.config_name,
484487
)
485488
return kwargs
486489

src/sagemaker/jumpstart/factory/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]:
496496

497497
if kwargs.sagemaker_session.settings.include_jumpstart_tags:
498498
kwargs.tags = add_jumpstart_model_id_version_tags(
499-
kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type
499+
kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type, kwargs.config_name
500500
)
501501

502502
return kwargs

src/sagemaker/jumpstart/session_utils.py

+30-26
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@
2222
from sagemaker.utils import aws_partition
2323

2424

25-
def get_model_id_version_from_endpoint(
25+
def get_model_info_from_endpoint(
2626
endpoint_name: str,
2727
inference_component_name: Optional[str] = None,
2828
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
29-
) -> Tuple[str, str, Optional[str]]:
30-
"""Given an endpoint and optionally inference component names, return the model ID and version.
29+
) -> Tuple[str, str, Optional[str], Optional[str]]:
30+
"""Given an endpoint and optionally inference component names, return the model ID, version and config name.
3131
3232
Infers the model ID and version based on the resource tags. Returns a tuple of the model ID
3333
and version. A third string element is included in the tuple for any inferred inference
@@ -46,30 +46,32 @@ def get_model_id_version_from_endpoint(
4646
(
4747
model_id,
4848
model_version,
49-
) = _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301
49+
config_name,
50+
) = _get_model_info_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301
5051
inference_component_name, sagemaker_session
5152
)
5253

5354
else:
5455
(
5556
model_id,
5657
model_version,
58+
config_name,
5759
inference_component_name,
58-
) = _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301
60+
) = _get_model_info_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301
5961
endpoint_name, sagemaker_session
6062
)
6163

6264
else:
63-
model_id, model_version = _get_model_id_version_from_model_based_endpoint(
65+
model_id, model_version, config_name = _get_model_info_from_model_based_endpoint(
6466
endpoint_name, inference_component_name, sagemaker_session
6567
)
66-
return model_id, model_version, inference_component_name
68+
return model_id, model_version, inference_component_name, config_name
6769

6870

69-
def _get_model_id_version_from_inference_component_endpoint_without_inference_component_name(
71+
def _get_model_info_from_inference_component_endpoint_without_inference_component_name(
7072
endpoint_name: str, sagemaker_session: Session
71-
) -> Tuple[str, str, str]:
72-
"""Given an endpoint name, derives the model ID, version, and inferred inference component name.
73+
) -> Tuple[str, str, str, str]:
74+
"""Derives the model ID, version, config name and inferred inference component name.
7375
7476
This function assumes the endpoint corresponds to an inference-component-based endpoint.
7577
An endpoint is inference-component-based if and only if the associated endpoint config
@@ -98,14 +100,14 @@ def _get_model_id_version_from_inference_component_endpoint_without_inference_co
98100
)
99101
inference_component_name = inference_component_names[0]
100102
return (
101-
*_get_model_id_version_from_inference_component_endpoint_with_inference_component_name(
103+
*_get_model_info_from_inference_component_endpoint_with_inference_component_name(
102104
inference_component_name, sagemaker_session
103105
),
104106
inference_component_name,
105107
)
106108

107109

108-
def _get_model_id_version_from_inference_component_endpoint_with_inference_component_name(
110+
def _get_model_info_from_inference_component_endpoint_with_inference_component_name(
109111
inference_component_name: str, sagemaker_session: Session
110112
):
111113
"""Returns the model ID and version inferred from a SageMaker inference component.
@@ -123,7 +125,7 @@ def _get_model_id_version_from_inference_component_endpoint_with_inference_compo
123125
f"inference-component/{inference_component_name}"
124126
)
125127

126-
model_id, model_version = get_jumpstart_model_id_version_from_resource_arn(
128+
model_id, model_version, config_name = get_jumpstart_model_id_version_from_resource_arn(
127129
inference_component_arn, sagemaker_session
128130
)
129131

@@ -134,15 +136,15 @@ def _get_model_id_version_from_inference_component_endpoint_with_inference_compo
134136
"when retrieving default predictor for this inference component."
135137
)
136138

137-
return model_id, model_version
139+
return model_id, model_version, config_name
138140

139141

140-
def _get_model_id_version_from_model_based_endpoint(
142+
def _get_model_info_from_model_based_endpoint(
141143
endpoint_name: str,
142144
inference_component_name: Optional[str],
143145
sagemaker_session: Session,
144-
) -> Tuple[str, str]:
145-
"""Returns the model ID and version inferred from a model-based endpoint.
146+
) -> Tuple[str, str, Optional[str]]:
147+
"""Returns the model ID, version and config name inferred from a model-based endpoint.
146148
147149
Raises:
148150
ValueError: If an inference component name is supplied, or if the endpoint does
@@ -161,7 +163,7 @@ def _get_model_id_version_from_model_based_endpoint(
161163

162164
endpoint_arn = f"arn:{partition}:sagemaker:{region}:{account_id}:endpoint/{endpoint_name}"
163165

164-
model_id, model_version = get_jumpstart_model_id_version_from_resource_arn(
166+
model_id, model_version, config_name = get_jumpstart_model_id_version_from_resource_arn(
165167
endpoint_arn, sagemaker_session
166168
)
167169

@@ -172,14 +174,14 @@ def _get_model_id_version_from_model_based_endpoint(
172174
"predictor for this endpoint."
173175
)
174176

175-
return model_id, model_version
177+
return model_id, model_version, config_name
176178

177179

178-
def get_model_id_version_from_training_job(
180+
def get_model_info_from_training_job(
179181
training_job_name: str,
180182
sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
181-
) -> Tuple[str, str]:
182-
"""Returns the model ID and version inferred from a training job.
183+
) -> Tuple[str, str, Optional[str]]:
184+
"""Returns the model ID and version and config name inferred from a training job.
183185
184186
Raises:
185187
ValueError: If the training job does not have tags from which the model ID
@@ -194,9 +196,11 @@ def get_model_id_version_from_training_job(
194196
f"arn:{partition}:sagemaker:{region}:{account_id}:training-job/{training_job_name}"
195197
)
196198

197-
model_id, inferred_model_version = get_jumpstart_model_id_version_from_resource_arn(
198-
training_job_arn, sagemaker_session
199-
)
199+
(
200+
model_id,
201+
inferred_model_version,
202+
config_name,
203+
) = get_jumpstart_model_id_version_from_resource_arn(training_job_arn, sagemaker_session)
200204

201205
model_version = inferred_model_version or None
202206

@@ -207,4 +211,4 @@ def get_model_id_version_from_training_job(
207211
"for this training job."
208212
)
209213

210-
return model_id, model_version
214+
return model_id, model_version, config_name

src/sagemaker/jumpstart/types.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -1163,20 +1163,17 @@ def get_top_config_from_ranking(
11631163
) -> Optional[JumpStartMetadataConfig]:
11641164
"""Gets the best config based on config ranking.
11651165
1166+
Fallback to use the ordering in config names if
1167+
ranking is not available.
11661168
Args:
11671169
ranking_name (str):
11681170
The ranking name that config priority is based on.
11691171
instance_type (Optional[str]):
11701172
The instance type which the config selection is based on.
11711173
11721174
Raises:
1173-
ValueError: If the config exists but missing config ranking.
11741175
NotImplementedError: If the scope is unrecognized.
11751176
"""
1176-
if self.configs and (
1177-
not self.config_rankings or not self.config_rankings.get(ranking_name)
1178-
):
1179-
raise ValueError(f"Config exists but missing config ranking {ranking_name}.")
11801177

11811178
if self.scope == JumpStartScriptScope.INFERENCE:
11821179
instance_type_attribute = "supported_inference_instance_types"
@@ -1185,8 +1182,14 @@ def get_top_config_from_ranking(
11851182
else:
11861183
raise NotImplementedError(f"Unknown script scope {self.scope}")
11871184

1188-
rankings = self.config_rankings.get(ranking_name)
1189-
for config_name in rankings.rankings:
1185+
if self.configs and (
1186+
not self.config_rankings or not self.config_rankings.get(ranking_name)
1187+
):
1188+
ranked_config_names = sorted(list(self.configs.keys()))
1189+
else:
1190+
rankings = self.config_rankings.get(ranking_name)
1191+
ranked_config_names = rankings.rankings
1192+
for config_name in ranked_config_names:
11901193
resolved_config = self.configs[config_name].resolved_config
11911194
if instance_type and instance_type not in getattr(
11921195
resolved_config, instance_type_attribute

src/sagemaker/jumpstart/utils.py

+62-33
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,7 @@ def add_single_jumpstart_tag(
318318
tag_key_in_array(enums.JumpStartTag.MODEL_ID, curr_tags)
319319
or tag_key_in_array(enums.JumpStartTag.MODEL_VERSION, curr_tags)
320320
or tag_key_in_array(enums.JumpStartTag.MODEL_TYPE, curr_tags)
321+
or tag_key_in_array(enums.JumpStartTag.MODEL_CONFIG_NAME, curr_tags)
321322
)
322323
if is_uri
323324
else False
@@ -353,6 +354,7 @@ def add_jumpstart_model_id_version_tags(
353354
model_id: str,
354355
model_version: str,
355356
model_type: Optional[enums.JumpStartModelType] = None,
357+
config_name: Optional[str] = None,
356358
) -> List[TagsDict]:
357359
"""Add custom model ID and version tags to JumpStart related resources."""
358360
if model_id is None or model_version is None:
@@ -376,6 +378,13 @@ def add_jumpstart_model_id_version_tags(
376378
tags,
377379
is_uri=False,
378380
)
381+
if config_name:
382+
tags = add_single_jumpstart_tag(
383+
config_name,
384+
enums.JumpStartTag.MODEL_CONFIG_NAME,
385+
tags,
386+
is_uri=False,
387+
)
379388
return tags
380389

381390

@@ -800,52 +809,72 @@ def validate_model_id_and_get_type(
800809
return None
801810

802811

812+
def _extract_value_from_list_of_tags(
813+
tag_keys: List[str],
814+
list_tags_result: List[str],
815+
resource_name: str,
816+
resource_arn: str,
817+
):
818+
"""Extracts value from list of tags with check of duplicate tags.
819+
820+
Returns None if no value is found.
821+
"""
822+
resolved_value = None
823+
for tag_key in tag_keys:
824+
try:
825+
value_from_tag = get_tag_value(tag_key, list_tags_result)
826+
except KeyError:
827+
continue
828+
if value_from_tag is not None:
829+
if resolved_value is not None and value_from_tag != resolved_value:
830+
constants.JUMPSTART_LOGGER.warning(
831+
"Found multiple %s tags on the following resource: %s",
832+
resource_name,
833+
resource_arn,
834+
)
835+
resolved_value = None
836+
break
837+
resolved_value = value_from_tag
838+
return resolved_value
839+
840+
803841
def get_jumpstart_model_id_version_from_resource_arn(
804842
resource_arn: str,
805843
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
806-
) -> Tuple[Optional[str], Optional[str]]:
807-
"""Returns the JumpStart model ID and version if in resource tags.
844+
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
845+
"""Returns the JumpStart model ID, version and config name if in resource tags.
808846
809-
Returns 'None' if model ID or version cannot be inferred from tags.
847+
Returns 'None' if model ID or version or config name cannot be inferred from tags.
810848
"""
811849

812850
list_tags_result = sagemaker_session.list_tags(resource_arn)
813851

814-
model_id: Optional[str] = None
815-
model_version: Optional[str] = None
816-
817852
model_id_keys = [enums.JumpStartTag.MODEL_ID, *constants.EXTRA_MODEL_ID_TAGS]
818853
model_version_keys = [enums.JumpStartTag.MODEL_VERSION, *constants.EXTRA_MODEL_VERSION_TAGS]
854+
model_config_name_keys = [enums.JumpStartTag.MODEL_CONFIG_NAME]
819855

820-
for model_id_key in model_id_keys:
821-
try:
822-
model_id_from_tag = get_tag_value(model_id_key, list_tags_result)
823-
except KeyError:
824-
continue
825-
if model_id_from_tag is not None:
826-
if model_id is not None and model_id_from_tag != model_id:
827-
constants.JUMPSTART_LOGGER.warning(
828-
"Found multiple model ID tags on the following resource: %s", resource_arn
829-
)
830-
model_id = None
831-
break
832-
model_id = model_id_from_tag
856+
model_id: Optional[str] = _extract_value_from_list_of_tags(
857+
tag_keys=model_id_keys,
858+
list_tags_result=list_tags_result,
859+
resource_name="model ID",
860+
resource_arn=resource_arn,
861+
)
833862

834-
for model_version_key in model_version_keys:
835-
try:
836-
model_version_from_tag = get_tag_value(model_version_key, list_tags_result)
837-
except KeyError:
838-
continue
839-
if model_version_from_tag is not None:
840-
if model_version is not None and model_version_from_tag != model_version:
841-
constants.JUMPSTART_LOGGER.warning(
842-
"Found multiple model version tags on the following resource: %s", resource_arn
843-
)
844-
model_version = None
845-
break
846-
model_version = model_version_from_tag
863+
model_version: Optional[str] = _extract_value_from_list_of_tags(
864+
tag_keys=model_version_keys,
865+
list_tags_result=list_tags_result,
866+
resource_name="model version",
867+
resource_arn=resource_arn,
868+
)
869+
870+
config_name: Optional[str] = _extract_value_from_list_of_tags(
871+
tag_keys=model_config_name_keys,
872+
list_tags_result=list_tags_result,
873+
resource_name="model config name",
874+
resource_arn=resource_arn,
875+
)
847876

848-
return model_id, model_version
877+
return model_id, model_version, config_name
849878

850879

851880
def get_region_fallback(

0 commit comments

Comments
 (0)