Skip to content

Commit bc51dc1

Browse files
authored
Use separate tags for inference and training configs (#4635)
* Use separate tags for inference and training * format * format * format * format
1 parent d13f9e2 commit bc51dc1

File tree

14 files changed

+417
-180
lines changed

14 files changed

+417
-180
lines changed

src/sagemaker/jumpstart/enums.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,9 @@ class JumpStartTag(str, Enum):
9292
MODEL_ID = "sagemaker-sdk:jumpstart-model-id"
9393
MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version"
9494
MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type"
95-
MODEL_CONFIG_NAME = "sagemaker-sdk:jumpstart-model-config-name"
95+
96+
INFERENCE_CONFIG_NAME = "sagemaker-sdk:jumpstart-inference-config-name"
97+
TRAINING_CONFIG_NAME = "sagemaker-sdk:jumpstart-training-config-name"
9698

9799

98100
class SerializerType(str, Enum):

src/sagemaker/jumpstart/estimator.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -733,7 +733,7 @@ def attach(
733733
config_name = None
734734
if model_id is None:
735735

736-
model_id, model_version, config_name = get_model_info_from_training_job(
736+
model_id, model_version, _, config_name = get_model_info_from_training_job(
737737
training_job_name=training_job_name, sagemaker_session=sagemaker_session
738738
)
739739

@@ -1139,7 +1139,9 @@ def set_training_config(self, config_name: str) -> None:
11391139
Args:
11401140
config_name (str): The name of the config.
11411141
"""
1142-
self.__init__(**self.init_kwargs, config_name=config_name)
1142+
self.__init__(
1143+
model_id=self.model_id, model_version=self.model_version, config_name=config_name
1144+
)
11431145

11441146
def __str__(self) -> str:
11451147
"""Overriding str(*) method to make more human-readable."""

src/sagemaker/jumpstart/factory/estimator.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
JumpStartModelInitKwargs,
6262
)
6363
from sagemaker.jumpstart.utils import (
64-
add_jumpstart_model_id_version_tags,
64+
add_jumpstart_model_info_tags,
6565
get_eula_message,
6666
update_dict_if_key_not_present,
6767
resolve_estimator_sagemaker_config_field,
@@ -477,11 +477,12 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima
477477
).version
478478

479479
if kwargs.sagemaker_session.settings.include_jumpstart_tags:
480-
kwargs.tags = add_jumpstart_model_id_version_tags(
480+
kwargs.tags = add_jumpstart_model_info_tags(
481481
kwargs.tags,
482482
kwargs.model_id,
483483
full_model_version,
484484
config_name=kwargs.config_name,
485+
scope=JumpStartScriptScope.TRAINING,
485486
)
486487
return kwargs
487488

src/sagemaker/jumpstart/factory/model.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
JumpStartModelRegisterKwargs,
4545
)
4646
from sagemaker.jumpstart.utils import (
47-
add_jumpstart_model_id_version_tags,
47+
add_jumpstart_model_info_tags,
4848
update_dict_if_key_not_present,
4949
resolve_model_sagemaker_config_field,
5050
verify_model_region_and_return_specs,
@@ -495,8 +495,13 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]:
495495
).version
496496

497497
if kwargs.sagemaker_session.settings.include_jumpstart_tags:
498-
kwargs.tags = add_jumpstart_model_id_version_tags(
499-
kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type, kwargs.config_name
498+
kwargs.tags = add_jumpstart_model_info_tags(
499+
kwargs.tags,
500+
kwargs.model_id,
501+
full_model_version,
502+
kwargs.model_type,
503+
config_name=kwargs.config_name,
504+
scope=JumpStartScriptScope.INFERENCE,
500505
)
501506

502507
return kwargs

src/sagemaker/jumpstart/session_utils.py

+39-19
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from typing import Optional, Tuple
1818
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
1919

20-
from sagemaker.jumpstart.utils import get_jumpstart_model_id_version_from_resource_arn
20+
from sagemaker.jumpstart.utils import get_jumpstart_model_info_from_resource_arn
2121
from sagemaker.session import Session
2222
from sagemaker.utils import aws_partition
2323

@@ -26,7 +26,7 @@ def get_model_info_from_endpoint(
2626
endpoint_name: str,
2727
inference_component_name: Optional[str] = None,
2828
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
29-
) -> Tuple[str, str, Optional[str], Optional[str]]:
29+
) -> Tuple[str, str, Optional[str], Optional[str], Optional[str]]:
3030
"""Optionally inference component names, return the model ID, version and config name.
3131
3232
Infers the model ID and version based on the resource tags. Returns a tuple of the model ID
@@ -46,7 +46,8 @@ def get_model_info_from_endpoint(
4646
(
4747
model_id,
4848
model_version,
49-
config_name,
49+
inference_config_name,
50+
training_config_name,
5051
) = _get_model_info_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301
5152
inference_component_name, sagemaker_session
5253
)
@@ -55,17 +56,29 @@ def get_model_info_from_endpoint(
5556
(
5657
model_id,
5758
model_version,
58-
config_name,
59+
inference_config_name,
60+
training_config_name,
5961
inference_component_name,
6062
) = _get_model_info_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301
6163
endpoint_name, sagemaker_session
6264
)
6365

6466
else:
65-
model_id, model_version, config_name = _get_model_info_from_model_based_endpoint(
67+
(
68+
model_id,
69+
model_version,
70+
inference_config_name,
71+
training_config_name,
72+
) = _get_model_info_from_model_based_endpoint(
6673
endpoint_name, inference_component_name, sagemaker_session
6774
)
68-
return model_id, model_version, inference_component_name, config_name
75+
return (
76+
model_id,
77+
model_version,
78+
inference_component_name,
79+
inference_config_name,
80+
training_config_name,
81+
)
6982

7083

7184
def _get_model_info_from_inference_component_endpoint_without_inference_component_name(
@@ -125,9 +138,12 @@ def _get_model_info_from_inference_component_endpoint_with_inference_component_n
125138
f"inference-component/{inference_component_name}"
126139
)
127140

128-
model_id, model_version, config_name = get_jumpstart_model_id_version_from_resource_arn(
129-
inference_component_arn, sagemaker_session
130-
)
141+
(
142+
model_id,
143+
model_version,
144+
inference_config_name,
145+
training_config_name,
146+
) = get_jumpstart_model_info_from_resource_arn(inference_component_arn, sagemaker_session)
131147

132148
if not model_id:
133149
raise ValueError(
@@ -136,14 +152,14 @@ def _get_model_info_from_inference_component_endpoint_with_inference_component_n
136152
"when retrieving default predictor for this inference component."
137153
)
138154

139-
return model_id, model_version, config_name
155+
return model_id, model_version, inference_config_name, training_config_name
140156

141157

142158
def _get_model_info_from_model_based_endpoint(
143159
endpoint_name: str,
144160
inference_component_name: Optional[str],
145161
sagemaker_session: Session,
146-
) -> Tuple[str, str, Optional[str]]:
162+
) -> Tuple[str, str, Optional[str], Optional[str]]:
147163
"""Returns the model ID, version and config name inferred from a model-based endpoint.
148164
149165
Raises:
@@ -163,9 +179,12 @@ def _get_model_info_from_model_based_endpoint(
163179

164180
endpoint_arn = f"arn:{partition}:sagemaker:{region}:{account_id}:endpoint/{endpoint_name}"
165181

166-
model_id, model_version, config_name = get_jumpstart_model_id_version_from_resource_arn(
167-
endpoint_arn, sagemaker_session
168-
)
182+
(
183+
model_id,
184+
model_version,
185+
inference_config_name,
186+
training_config_name,
187+
) = get_jumpstart_model_info_from_resource_arn(endpoint_arn, sagemaker_session)
169188

170189
if not model_id:
171190
raise ValueError(
@@ -174,13 +193,13 @@ def _get_model_info_from_model_based_endpoint(
174193
"predictor for this endpoint."
175194
)
176195

177-
return model_id, model_version, config_name
196+
return model_id, model_version, inference_config_name, training_config_name
178197

179198

180199
def get_model_info_from_training_job(
181200
training_job_name: str,
182201
sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
183-
) -> Tuple[str, str, Optional[str]]:
202+
) -> Tuple[str, str, Optional[str], Optional[str]]:
184203
"""Returns the model ID and version and config name inferred from a training job.
185204
186205
Raises:
@@ -199,8 +218,9 @@ def get_model_info_from_training_job(
199218
(
200219
model_id,
201220
inferred_model_version,
202-
config_name,
203-
) = get_jumpstart_model_id_version_from_resource_arn(training_job_arn, sagemaker_session)
221+
inference_config_name,
222+
trainig_config_name,
223+
) = get_jumpstart_model_info_from_resource_arn(training_job_arn, sagemaker_session)
204224

205225
model_version = inferred_model_version or None
206226

@@ -211,4 +231,4 @@ def get_model_info_from_training_job(
211231
"for this training job."
212232
)
213233

214-
return model_id, model_version, config_name
234+
return model_id, model_version, inference_config_name, trainig_config_name

src/sagemaker/jumpstart/utils.py

+28-11
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,8 @@ def add_single_jumpstart_tag(
320320
tag_key_in_array(enums.JumpStartTag.MODEL_ID, curr_tags)
321321
or tag_key_in_array(enums.JumpStartTag.MODEL_VERSION, curr_tags)
322322
or tag_key_in_array(enums.JumpStartTag.MODEL_TYPE, curr_tags)
323-
or tag_key_in_array(enums.JumpStartTag.MODEL_CONFIG_NAME, curr_tags)
323+
or tag_key_in_array(enums.JumpStartTag.INFERENCE_CONFIG_NAME, curr_tags)
324+
or tag_key_in_array(enums.JumpStartTag.TRAINING_CONFIG_NAME, curr_tags)
324325
)
325326
if is_uri
326327
else False
@@ -351,12 +352,13 @@ def get_jumpstart_base_name_if_jumpstart_model(
351352
return None
352353

353354

354-
def add_jumpstart_model_id_version_tags(
355+
def add_jumpstart_model_info_tags(
355356
tags: Optional[List[TagsDict]],
356357
model_id: str,
357358
model_version: str,
358359
model_type: Optional[enums.JumpStartModelType] = None,
359360
config_name: Optional[str] = None,
361+
scope: enums.JumpStartScriptScope = None,
360362
) -> List[TagsDict]:
361363
"""Add custom model ID and version tags to JumpStart related resources."""
362364
if model_id is None or model_version is None:
@@ -380,10 +382,17 @@ def add_jumpstart_model_id_version_tags(
380382
tags,
381383
is_uri=False,
382384
)
383-
if config_name:
385+
if config_name and scope == enums.JumpStartScriptScope.INFERENCE:
384386
tags = add_single_jumpstart_tag(
385387
config_name,
386-
enums.JumpStartTag.MODEL_CONFIG_NAME,
388+
enums.JumpStartTag.INFERENCE_CONFIG_NAME,
389+
tags,
390+
is_uri=False,
391+
)
392+
if config_name and scope == enums.JumpStartScriptScope.TRAINING:
393+
tags = add_single_jumpstart_tag(
394+
config_name,
395+
enums.JumpStartTag.TRAINING_CONFIG_NAME,
387396
tags,
388397
is_uri=False,
389398
)
@@ -840,10 +849,10 @@ def _extract_value_from_list_of_tags(
840849
return resolved_value
841850

842851

843-
def get_jumpstart_model_id_version_from_resource_arn(
852+
def get_jumpstart_model_info_from_resource_arn(
844853
resource_arn: str,
845854
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
846-
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
855+
) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
847856
"""Returns the JumpStart model ID, version and config name if in resource tags.
848857
849858
Returns 'None' if model ID or version or config name cannot be inferred from tags.
@@ -853,7 +862,8 @@ def get_jumpstart_model_id_version_from_resource_arn(
853862

854863
model_id_keys = [enums.JumpStartTag.MODEL_ID, *constants.EXTRA_MODEL_ID_TAGS]
855864
model_version_keys = [enums.JumpStartTag.MODEL_VERSION, *constants.EXTRA_MODEL_VERSION_TAGS]
856-
model_config_name_keys = [enums.JumpStartTag.MODEL_CONFIG_NAME]
865+
inference_config_name_keys = [enums.JumpStartTag.INFERENCE_CONFIG_NAME]
866+
training_config_name_keys = [enums.JumpStartTag.TRAINING_CONFIG_NAME]
857867

858868
model_id: Optional[str] = _extract_value_from_list_of_tags(
859869
tag_keys=model_id_keys,
@@ -869,14 +879,21 @@ def get_jumpstart_model_id_version_from_resource_arn(
869879
resource_arn=resource_arn,
870880
)
871881

872-
config_name: Optional[str] = _extract_value_from_list_of_tags(
873-
tag_keys=model_config_name_keys,
882+
inference_config_name: Optional[str] = _extract_value_from_list_of_tags(
883+
tag_keys=inference_config_name_keys,
884+
list_tags_result=list_tags_result,
885+
resource_name="inference config name",
886+
resource_arn=resource_arn,
887+
)
888+
889+
training_config_name: Optional[str] = _extract_value_from_list_of_tags(
890+
tag_keys=training_config_name_keys,
874891
list_tags_result=list_tags_result,
875-
resource_name="model config name",
892+
resource_name="training config name",
876893
resource_arn=resource_arn,
877894
)
878895

879-
return model_id, model_version, config_name
896+
return model_id, model_version, inference_config_name, training_config_name
880897

881898

882899
def get_region_fallback(

src/sagemaker/predictor.py

+1
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def retrieve_default(
8282
inferred_model_version,
8383
inferred_inference_component_name,
8484
inferred_config_name,
85+
_,
8586
) = get_model_info_from_endpoint(endpoint_name, inference_component_name, sagemaker_session)
8687

8788
if not inferred_model_id:

0 commit comments

Comments
 (0)