Skip to content

Commit f8d4fb5

Browse files
committed
Change: More pythonic tags
1 parent cebfd71 commit f8d4fb5

31 files changed

+225
-179
lines changed

src/sagemaker/algorithm.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from sagemaker.session import Session
2929
from sagemaker.workflow.entities import PipelineVariable
3030
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
31+
from sagemaker.utils import format_tags, Tags
3132

3233
from sagemaker.workflow import is_pipeline_variable
3334

@@ -58,7 +59,7 @@ def __init__(
5859
base_job_name: Optional[str] = None,
5960
sagemaker_session: Optional[Session] = None,
6061
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
61-
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
62+
tags: Optional[Tags] = None,
6263
subnets: Optional[List[Union[str, PipelineVariable]]] = None,
6364
security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
6465
model_uri: Optional[str] = None,
@@ -121,7 +122,7 @@ def __init__(
121122
interactions with Amazon SageMaker APIs and any other AWS services needed. If
122123
not specified, the estimator creates one using the default
123124
AWS configuration chain.
124-
tags (list[dict[str, str] or list[dict[str, PipelineVariable]]): List of tags for
125+
tags (Union[Tags]): Tags for
125126
labeling a training job. For more, see
126127
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
127128
subnets (list[str] or list[PipelineVariable]): List of subnet ids. If not specified
@@ -170,7 +171,7 @@ def __init__(
170171
output_kms_key=output_kms_key,
171172
base_job_name=base_job_name,
172173
sagemaker_session=sagemaker_session,
173-
tags=tags,
174+
tags=format_tags(tags),
174175
subnets=subnets,
175176
security_group_ids=security_group_ids,
176177
model_uri=model_uri,

src/sagemaker/automl/automl.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
)
2929
from sagemaker.job import _Job
3030
from sagemaker.session import Session
31-
from sagemaker.utils import name_from_base, resolve_value_from_config
31+
from sagemaker.utils import name_from_base, resolve_value_from_config, format_tags, Tags
3232
from sagemaker.workflow.entities import PipelineVariable
3333
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
3434

@@ -127,7 +127,7 @@ def __init__(
127127
total_job_runtime_in_seconds: Optional[int] = None,
128128
job_objective: Optional[Dict[str, str]] = None,
129129
generate_candidate_definitions_only: Optional[bool] = False,
130-
tags: Optional[List[Dict[str, str]]] = None,
130+
tags: Optional[Tags] = None,
131131
content_type: Optional[str] = None,
132132
s3_data_type: Optional[str] = None,
133133
feature_specification_s3_uri: Optional[str] = None,
@@ -167,8 +167,7 @@ def __init__(
167167
In the format of: {"MetricName": str}
168168
generate_candidate_definitions_only (bool): Whether to generates
169169
possible candidates without training the models.
170-
tags (List[dict[str, str]]): The list of tags to attach to this
171-
specific endpoint.
170+
tags (Optional[Tags]): Tags to attach to this specific endpoint.
172171
content_type (str): The content type of the data from the input source.
173172
s3_data_type (str): The data type for S3 data source.
174173
Valid values: ManifestFile or S3Prefix.
@@ -203,7 +202,7 @@ def __init__(
203202
self.target_attribute_name = target_attribute_name
204203
self.job_objective = job_objective
205204
self.generate_candidate_definitions_only = generate_candidate_definitions_only
206-
self.tags = tags
205+
self.tags = format_tags(tags)
207206
self.content_type = content_type
208207
self.s3_data_type = s3_data_type
209208
self.feature_specification_s3_uri = feature_specification_s3_uri

src/sagemaker/clarify.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from sagemaker.session import Session
3434
from sagemaker.network import NetworkConfig
3535
from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
36+
from sagemaker.utils import format_tags, Tags
3637

3738
logger = logging.getLogger(__name__)
3839

@@ -1417,7 +1418,7 @@ def __init__(
14171418
max_runtime_in_seconds: Optional[int] = None,
14181419
sagemaker_session: Optional[Session] = None,
14191420
env: Optional[Dict[str, str]] = None,
1420-
tags: Optional[List[Dict[str, str]]] = None,
1421+
tags: Optional[Tags] = None,
14211422
network_config: Optional[NetworkConfig] = None,
14221423
job_name_prefix: Optional[str] = None,
14231424
version: Optional[str] = None,
@@ -1454,7 +1455,7 @@ def __init__(
14541455
using the default AWS configuration chain.
14551456
env (dict[str, str]): Environment variables to be passed to
14561457
the processing jobs (default: None).
1457-
tags (list[dict]): List of tags to be passed to the processing job
1458+
tags (Optional[Tags]): Tags to be passed to the processing job
14581459
(default: None). For more, see
14591460
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
14601461
network_config (:class:`~sagemaker.network.NetworkConfig`):
@@ -1482,7 +1483,7 @@ def __init__(
14821483
None, # We set method-specific job names below.
14831484
sagemaker_session,
14841485
env,
1485-
tags,
1486+
format_tags(tags),
14861487
network_config,
14871488
)
14881489

src/sagemaker/estimator.py

+23-18
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@
9898
to_string,
9999
check_and_get_run_experiment_config,
100100
resolve_value_from_config,
101+
format_tags,
102+
Tags,
101103
)
102104
from sagemaker.workflow import is_pipeline_variable
103105
from sagemaker.workflow.entities import PipelineVariable
@@ -144,7 +146,7 @@ def __init__(
144146
output_kms_key: Optional[Union[str, PipelineVariable]] = None,
145147
base_job_name: Optional[str] = None,
146148
sagemaker_session: Optional[Session] = None,
147-
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
149+
tags: Optional[Tags] = None,
148150
subnets: Optional[List[Union[str, PipelineVariable]]] = None,
149151
security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
150152
model_uri: Optional[str] = None,
@@ -269,8 +271,8 @@ def __init__(
269271
manages interactions with Amazon SageMaker APIs and any other
270272
AWS services needed. If not specified, the estimator creates one
271273
using the default AWS configuration chain.
272-
tags (list[dict[str, str] or list[dict[str, PipelineVariable]]):
273-
List of tags for labeling a training job. For more, see
274+
tags (Optional[Tags]):
275+
Tags for labeling a training job. For more, see
274276
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
275277
subnets (list[str] or list[PipelineVariable]): List of subnet ids. If not
276278
specified training job will be created without VPC config.
@@ -601,6 +603,7 @@ def __init__(
601603
else:
602604
self.sagemaker_session = sagemaker_session or Session()
603605

606+
tags = format_tags(tags)
604607
self.tags = (
605608
add_jumpstart_uri_tags(
606609
tags=tags, training_model_uri=self.model_uri, training_script_uri=self.source_dir
@@ -1347,7 +1350,7 @@ def compile_model(
13471350
framework=None,
13481351
framework_version=None,
13491352
compile_max_run=15 * 60,
1350-
tags=None,
1353+
tags: Optional[Tags] = None,
13511354
target_platform_os=None,
13521355
target_platform_arch=None,
13531356
target_platform_accelerator=None,
@@ -1373,7 +1376,7 @@ def compile_model(
13731376
compile_max_run (int): Timeout in seconds for compilation (default:
13741377
15 * 60). After this amount of time Amazon SageMaker Neo
13751378
terminates the compilation job regardless of its current status.
1376-
tags (list[dict]): List of tags for labeling a compilation job. For
1379+
tags (list[dict]): Tags for labeling a compilation job. For
13771380
more, see
13781381
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
13791382
target_platform_os (str): Target Platform OS, for example: 'LINUX'.
@@ -1415,7 +1418,7 @@ def compile_model(
14151418
input_shape,
14161419
output_path,
14171420
self.role,
1418-
tags,
1421+
format_tags(tags),
14191422
self._compilation_job_name(),
14201423
compile_max_run,
14211424
framework=framework,
@@ -1527,7 +1530,7 @@ def deploy(
15271530
model_name=None,
15281531
kms_key=None,
15291532
data_capture_config=None,
1530-
tags=None,
1533+
tags: Optional[Tags] = None,
15311534
serverless_inference_config=None,
15321535
async_inference_config=None,
15331536
volume_size=None,
@@ -1596,8 +1599,10 @@ def deploy(
15961599
empty object passed through, will use pre-defined values in
15971600
``ServerlessInferenceConfig`` class to deploy serverless endpoint. Deploy an
15981601
instance based endpoint if it's None. (default: None)
1599-
tags(List[dict[str, str]]): Optional. The list of tags to attach to this specific
1602+
tags(Optional[Tags]): Optional. Tags to attach to this specific
16001603
endpoint. Example:
1604+
>>> tags = {'tagname', 'tagvalue'}
1605+
Or
16011606
>>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
16021607
For more information about tags, see
16031608
https://boto3.amazonaws.com/v1/documentation\
@@ -1659,7 +1664,7 @@ def deploy(
16591664
model.name = model_name
16601665

16611666
tags = update_inference_tags_with_jumpstart_training_tags(
1662-
inference_tags=tags, training_tags=self.tags
1667+
inference_tags=format_tags(tags), training_tags=self.tags
16631668
)
16641669

16651670
return model.deploy(
@@ -2007,7 +2012,7 @@ def transformer(
20072012
env=None,
20082013
max_concurrent_transforms=None,
20092014
max_payload=None,
2010-
tags=None,
2015+
tags: Optional[Tags] = None,
20112016
role=None,
20122017
volume_kms_key=None,
20132018
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
@@ -2041,7 +2046,7 @@ def transformer(
20412046
to be made to each individual transform container at one time.
20422047
max_payload (int): Maximum size of the payload in a single HTTP
20432048
request to the container in MB.
2044-
tags (list[dict]): List of tags for labeling a transform job. If
2049+
tags (Optional[Tags]): Tags for labeling a transform job. If
20452050
none specified, then the tags used for the training job are used
20462051
for the transform job.
20472052
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
@@ -2068,7 +2073,7 @@ def transformer(
20682073
model. If not specified, the estimator generates a default job name
20692074
based on the training image name and current timestamp.
20702075
"""
2071-
tags = tags or self.tags
2076+
tags = format_tags(tags) or self.tags
20722077
model_name = self._get_or_create_name(model_name)
20732078

20742079
if self.latest_training_job is None:
@@ -2661,7 +2666,7 @@ def __init__(
26612666
base_job_name: Optional[str] = None,
26622667
sagemaker_session: Optional[Session] = None,
26632668
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
2664-
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
2669+
tags: Optional[Tags] = None,
26652670
subnets: Optional[List[Union[str, PipelineVariable]]] = None,
26662671
security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
26672672
model_uri: Optional[str] = None,
@@ -2790,7 +2795,7 @@ def __init__(
27902795
hyperparameters. SageMaker rejects the training job request and returns an
27912796
validation error for detected credentials, if such user input is found.
27922797
2793-
tags (list[dict[str, str] or list[dict[str, PipelineVariable]]): List of tags for
2798+
tags (Optional[Tags]): Tags for
27942799
labeling a training job. For more, see
27952800
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
27962801
subnets (list[str] or list[PipelineVariable]): List of subnet ids.
@@ -3071,7 +3076,7 @@ def __init__(
30713076
output_kms_key,
30723077
base_job_name,
30733078
sagemaker_session,
3074-
tags,
3079+
format_tags(tags),
30753080
subnets,
30763081
security_group_ids,
30773082
model_uri=model_uri,
@@ -3702,7 +3707,7 @@ def transformer(
37023707
env=None,
37033708
max_concurrent_transforms=None,
37043709
max_payload=None,
3705-
tags=None,
3710+
tags: Optional[Tags] = None,
37063711
role=None,
37073712
model_server_workers=None,
37083713
volume_kms_key=None,
@@ -3738,7 +3743,7 @@ def transformer(
37383743
to be made to each individual transform container at one time.
37393744
max_payload (int): Maximum size of the payload in a single HTTP
37403745
request to the container in MB.
3741-
tags (list[dict]): List of tags for labeling a transform job. If
3746+
tags (Optional[Tags]): Tags for labeling a transform job. If
37423747
none specified, then the tags used for the training job are used
37433748
for the transform job.
37443749
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
@@ -3777,7 +3782,7 @@ def transformer(
37773782
SageMaker Batch Transform job.
37783783
"""
37793784
role = role or self.role
3780-
tags = tags or self.tags
3785+
tags = format_tags(tags) or self.tags
37813786
model_name = self._get_or_create_name(model_name)
37823787

37833788
if self.latest_training_job is not None:

src/sagemaker/experiments/run.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@
4444
from sagemaker.utils import (
4545
get_module,
4646
unique_name_from_base,
47+
format_tags,
48+
Tags,
49+
TagsDict,
4750
)
4851

4952
from sagemaker.experiments._utils import (
@@ -97,7 +100,7 @@ def __init__(
97100
run_name: Optional[str] = None,
98101
experiment_display_name: Optional[str] = None,
99102
run_display_name: Optional[str] = None,
100-
tags: Optional[List[Dict[str, str]]] = None,
103+
tags: Optional[Tags] = None,
101104
sagemaker_session: Optional["Session"] = None,
102105
artifact_bucket: Optional[str] = None,
103106
artifact_prefix: Optional[str] = None,
@@ -152,7 +155,7 @@ def __init__(
152155
run_display_name (str): The display name of the run used in UI (default: None).
153156
This display name is used in a create run call. If a run with the
154157
specified name already exists, this display name won't take effect.
155-
tags (List[Dict[str, str]]): A list of tags to be used for all create calls,
158+
tags (Optional[Tags]): Tags to be used for all create calls,
156159
e.g. to create an experiment, a run group, etc. (default: None).
157160
sagemaker_session (sagemaker.session.Session): Session object which
158161
manages interactions with Amazon SageMaker APIs and any other
@@ -172,6 +175,8 @@ def __init__(
172175
# avoid confusion due to mis-match in casing between run name and TC name
173176
self.run_name = self.run_name.lower()
174177

178+
tags = format_tags(tags)
179+
175180
trial_component_name = Run._generate_trial_component_name(
176181
run_name=self.run_name, experiment_name=self.experiment_name
177182
)
@@ -676,11 +681,11 @@ def _extract_run_name_from_tc_name(trial_component_name: str, experiment_name: s
676681
)
677682

678683
@staticmethod
679-
def _append_run_tc_label_to_tags(tags: Optional[List[Dict[str, str]]] = None) -> list:
684+
def _append_run_tc_label_to_tags(tags: Optional[List[TagsDict]] = None) -> list:
680685
"""Append the run trial component label to tags used to create a trial component.
681686
682687
Args:
683-
tags (List[Dict[str, str]]): The tags supplied by users to initialize a Run object.
688+
tags (List[TagsDict]): The tags supplied by users to initialize a Run object.
684689
685690
Returns:
686691
list: The updated tags with the appended run trial component label.

src/sagemaker/feature_store/feature_group.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import tempfile
2929
from concurrent.futures import as_completed
3030
from concurrent.futures import ThreadPoolExecutor
31-
from typing import Sequence, List, Dict, Any, Union
31+
from typing import Optional, Sequence, List, Dict, Any, Union
3232
from urllib.parse import urlparse
3333

3434
from multiprocessing.pool import AsyncResult
@@ -65,7 +65,7 @@
6565
OnlineStoreConfigUpdate,
6666
OnlineStoreStorageTypeEnum,
6767
)
68-
from sagemaker.utils import resolve_value_from_config
68+
from sagemaker.utils import resolve_value_from_config, format_tags, Tags
6969

7070
logger = logging.getLogger(__name__)
7171

@@ -538,7 +538,7 @@ def create(
538538
disable_glue_table_creation: bool = False,
539539
data_catalog_config: DataCatalogConfig = None,
540540
description: str = None,
541-
tags: List[Dict[str, str]] = None,
541+
tags: Optional[Tags] = None,
542542
table_format: TableFormatEnum = None,
543543
online_store_storage_type: OnlineStoreStorageTypeEnum = None,
544544
) -> Dict[str, Any]:
@@ -566,7 +566,7 @@ def create(
566566
data_catalog_config (DataCatalogConfig): configuration for
567567
Metadata store (default: None).
568568
description (str): description of the FeatureGroup (default: None).
569-
tags (List[Dict[str, str]]): list of tags for labeling a FeatureGroup (default: None).
569+
tags (Optional[Tags]): Tags for labeling a FeatureGroup (default: None).
570570
table_format (TableFormatEnum): format of the offline store table (default: None).
571571
online_store_storage_type (OnlineStoreStorageTypeEnum): storage type for the
572572
online store (default: None).
@@ -602,7 +602,7 @@ def create(
602602
],
603603
role_arn=role_arn,
604604
description=description,
605-
tags=tags,
605+
tags=format_tags(tags),
606606
)
607607

608608
# online store configuration

src/sagemaker/feature_store/feature_processor/_event_bridge_rule_helper.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from sagemaker.feature_store.feature_processor._enums import (
3333
FeatureProcessorPipelineExecutionStatus,
3434
)
35+
from sagemaker.utils import TagsDict
3536

3637
logger = logging.getLogger("sagemaker")
3738

@@ -175,7 +176,7 @@ def disable_rule(self, rule_name: str) -> None:
175176
self.event_bridge_rule_client.disable_rule(Name=rule_name)
176177
logger.info("Disabled EventBridge Rule for pipeline %s.", rule_name)
177178

178-
def add_tags(self, rule_arn: str, tags: List[Dict[str, str]]) -> None:
179+
def add_tags(self, rule_arn: str, tags: List[TagsDict]) -> None:
179180
"""Adds tags to the EventBridge Rule.
180181
181182
Args:

0 commit comments

Comments
 (0)