Skip to content

Commit 9d0ffa9

Browse files
authored
Change: More pythonic tags (#4327)
* Change: More pythonic tags * Fix broken tags * More tags formatting and add a test * Fix tests
1 parent 06b3ef0 commit 9d0ffa9

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+400
-306
lines changed

src/sagemaker/algorithm.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from sagemaker.session import Session
2929
from sagemaker.workflow.entities import PipelineVariable
3030
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
31+
from sagemaker.utils import format_tags, Tags
3132

3233
from sagemaker.workflow import is_pipeline_variable
3334

@@ -58,7 +59,7 @@ def __init__(
5859
base_job_name: Optional[str] = None,
5960
sagemaker_session: Optional[Session] = None,
6061
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
61-
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
62+
tags: Optional[Tags] = None,
6263
subnets: Optional[List[Union[str, PipelineVariable]]] = None,
6364
security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
6465
model_uri: Optional[str] = None,
@@ -121,7 +122,7 @@ def __init__(
121122
interactions with Amazon SageMaker APIs and any other AWS services needed. If
122123
not specified, the estimator creates one using the default
123124
AWS configuration chain.
124-
tags (list[dict[str, str] or list[dict[str, PipelineVariable]]): List of tags for
125+
tags (Union[Tags]): Tags for
125126
labeling a training job. For more, see
126127
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
127128
subnets (list[str] or list[PipelineVariable]): List of subnet ids. If not specified
@@ -170,7 +171,7 @@ def __init__(
170171
output_kms_key=output_kms_key,
171172
base_job_name=base_job_name,
172173
sagemaker_session=sagemaker_session,
173-
tags=tags,
174+
tags=format_tags(tags),
174175
subnets=subnets,
175176
security_group_ids=security_group_ids,
176177
model_uri=model_uri,
@@ -391,7 +392,7 @@ def transformer(
391392
if self._is_marketplace():
392393
transform_env = None
393394

394-
tags = tags or self.tags
395+
tags = format_tags(tags) or self.tags
395396
else:
396397
raise RuntimeError("No finished training job found associated with this estimator")
397398

src/sagemaker/apiutils/_base_types.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.apiutils import _boto_functions, _utils
17+
from sagemaker.utils import format_tags
1718

1819

1920
class ApiObject(object):
@@ -194,13 +195,13 @@ def _set_tags(self, resource_arn=None, tags=None):
194195
195196
Args:
196197
resource_arn (str): The arn of the Record
197-
tags (dict): An array of Tag objects that set to Record
198+
tags (Optional[Tags]): An array of Tag objects that set to Record
198199
199200
Returns:
200201
A list of key, value pair objects. i.e. [{"key":"value"}]
201202
"""
202203
tag_list = self.sagemaker_session.sagemaker_client.add_tags(
203-
ResourceArn=resource_arn, Tags=tags
204+
ResourceArn=resource_arn, Tags=format_tags(tags)
204205
)["Tags"]
205206
return tag_list
206207

src/sagemaker/automl/automl.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
)
2929
from sagemaker.job import _Job
3030
from sagemaker.session import Session
31-
from sagemaker.utils import name_from_base, resolve_value_from_config
31+
from sagemaker.utils import name_from_base, resolve_value_from_config, format_tags, Tags
3232
from sagemaker.workflow.entities import PipelineVariable
3333
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
3434

@@ -127,7 +127,7 @@ def __init__(
127127
total_job_runtime_in_seconds: Optional[int] = None,
128128
job_objective: Optional[Dict[str, str]] = None,
129129
generate_candidate_definitions_only: Optional[bool] = False,
130-
tags: Optional[List[Dict[str, str]]] = None,
130+
tags: Optional[Tags] = None,
131131
content_type: Optional[str] = None,
132132
s3_data_type: Optional[str] = None,
133133
feature_specification_s3_uri: Optional[str] = None,
@@ -167,8 +167,7 @@ def __init__(
167167
In the format of: {"MetricName": str}
168168
generate_candidate_definitions_only (bool): Whether to generates
169169
possible candidates without training the models.
170-
tags (List[dict[str, str]]): The list of tags to attach to this
171-
specific endpoint.
170+
tags (Optional[Tags]): Tags to attach to this specific endpoint.
172171
content_type (str): The content type of the data from the input source.
173172
s3_data_type (str): The data type for S3 data source.
174173
Valid values: ManifestFile or S3Prefix.
@@ -203,7 +202,7 @@ def __init__(
203202
self.target_attribute_name = target_attribute_name
204203
self.job_objective = job_objective
205204
self.generate_candidate_definitions_only = generate_candidate_definitions_only
206-
self.tags = tags
205+
self.tags = format_tags(tags)
207206
self.content_type = content_type
208207
self.s3_data_type = s3_data_type
209208
self.feature_specification_s3_uri = feature_specification_s3_uri
@@ -581,7 +580,7 @@ def deploy(
581580
be selected on each ``deploy``.
582581
endpoint_name (str): The name of the endpoint to create (default:
583582
None). If not specified, a unique endpoint name will be created.
584-
tags (List[dict[str, str]]): The list of tags to attach to this
583+
tags (Optional[Tags]): The list of tags to attach to this
585584
specific endpoint.
586585
wait (bool): Whether the call should wait until the deployment of
587586
model completes (default: True).
@@ -633,7 +632,7 @@ def deploy(
633632
deserializer=deserializer,
634633
endpoint_name=endpoint_name,
635634
kms_key=model_kms_key,
636-
tags=tags,
635+
tags=format_tags(tags),
637636
wait=wait,
638637
volume_size=volume_size,
639638
model_data_download_timeout=model_data_download_timeout,

src/sagemaker/base_predictor.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
NumpySerializer,
5454
)
5555
from sagemaker.session import production_variant, Session
56-
from sagemaker.utils import name_from_base, stringify_object
56+
from sagemaker.utils import name_from_base, stringify_object, format_tags
5757

5858
from sagemaker.model_monitor.model_monitoring import DEFAULT_REPOSITORY_NAME
5959

@@ -409,7 +409,7 @@ def update_endpoint(
409409
self.sagemaker_session.create_endpoint_config_from_existing(
410410
current_endpoint_config_name,
411411
new_endpoint_config_name,
412-
new_tags=tags,
412+
new_tags=format_tags(tags),
413413
new_kms_key=kms_key,
414414
new_data_capture_config_dict=data_capture_config_dict,
415415
new_production_variants=production_variants,

src/sagemaker/clarify.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from sagemaker.session import Session
3434
from sagemaker.network import NetworkConfig
3535
from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
36+
from sagemaker.utils import format_tags, Tags
3637

3738
logger = logging.getLogger(__name__)
3839

@@ -1417,7 +1418,7 @@ def __init__(
14171418
max_runtime_in_seconds: Optional[int] = None,
14181419
sagemaker_session: Optional[Session] = None,
14191420
env: Optional[Dict[str, str]] = None,
1420-
tags: Optional[List[Dict[str, str]]] = None,
1421+
tags: Optional[Tags] = None,
14211422
network_config: Optional[NetworkConfig] = None,
14221423
job_name_prefix: Optional[str] = None,
14231424
version: Optional[str] = None,
@@ -1454,7 +1455,7 @@ def __init__(
14541455
using the default AWS configuration chain.
14551456
env (dict[str, str]): Environment variables to be passed to
14561457
the processing jobs (default: None).
1457-
tags (list[dict]): List of tags to be passed to the processing job
1458+
tags (Optional[Tags]): Tags to be passed to the processing job
14581459
(default: None). For more, see
14591460
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
14601461
network_config (:class:`~sagemaker.network.NetworkConfig`):
@@ -1482,7 +1483,7 @@ def __init__(
14821483
None, # We set method-specific job names below.
14831484
sagemaker_session,
14841485
env,
1485-
tags,
1486+
format_tags(tags),
14861487
network_config,
14871488
)
14881489

src/sagemaker/djl_inference/model.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from sagemaker.s3_utils import s3_path_join
3131
from sagemaker.serializers import JSONSerializer, BaseSerializer
3232
from sagemaker.session import Session
33-
from sagemaker.utils import _tmpdir, _create_or_update_code_dir
33+
from sagemaker.utils import _tmpdir, _create_or_update_code_dir, format_tags
3434
from sagemaker.workflow.entities import PipelineVariable
3535
from sagemaker.estimator import Estimator
3636
from sagemaker.s3 import S3Uploader
@@ -610,7 +610,7 @@ def deploy(
610610
default deserializer is set by the ``predictor_cls``.
611611
endpoint_name (str): The name of the endpoint to create (default:
612612
None). If not specified, a unique endpoint name will be created.
613-
tags (List[dict[str, str]]): The list of tags to attach to this
613+
tags (Optional[Tags]): The list of tags to attach to this
614614
specific endpoint.
615615
kms_key (str): The ARN of the KMS key that is used to encrypt the
616616
data on the storage volume attached to the instance hosting the
@@ -651,7 +651,7 @@ def deploy(
651651
serializer=serializer,
652652
deserializer=deserializer,
653653
endpoint_name=endpoint_name,
654-
tags=tags,
654+
tags=format_tags(tags),
655655
kms_key=kms_key,
656656
wait=wait,
657657
data_capture_config=data_capture_config,

src/sagemaker/estimator.py

+23-18
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@
9898
to_string,
9999
check_and_get_run_experiment_config,
100100
resolve_value_from_config,
101+
format_tags,
102+
Tags,
101103
)
102104
from sagemaker.workflow import is_pipeline_variable
103105
from sagemaker.workflow.entities import PipelineVariable
@@ -144,7 +146,7 @@ def __init__(
144146
output_kms_key: Optional[Union[str, PipelineVariable]] = None,
145147
base_job_name: Optional[str] = None,
146148
sagemaker_session: Optional[Session] = None,
147-
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
149+
tags: Optional[Tags] = None,
148150
subnets: Optional[List[Union[str, PipelineVariable]]] = None,
149151
security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
150152
model_uri: Optional[str] = None,
@@ -270,8 +272,8 @@ def __init__(
270272
manages interactions with Amazon SageMaker APIs and any other
271273
AWS services needed. If not specified, the estimator creates one
272274
using the default AWS configuration chain.
273-
tags (list[dict[str, str] or list[dict[str, PipelineVariable]]):
274-
List of tags for labeling a training job. For more, see
275+
tags (Optional[Tags]):
276+
Tags for labeling a training job. For more, see
275277
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
276278
subnets (list[str] or list[PipelineVariable]): List of subnet ids. If not
277279
specified training job will be created without VPC config.
@@ -604,6 +606,7 @@ def __init__(
604606
else:
605607
self.sagemaker_session = sagemaker_session or Session()
606608

609+
tags = format_tags(tags)
607610
self.tags = (
608611
add_jumpstart_uri_tags(
609612
tags=tags, training_model_uri=self.model_uri, training_script_uri=self.source_dir
@@ -1352,7 +1355,7 @@ def compile_model(
13521355
framework=None,
13531356
framework_version=None,
13541357
compile_max_run=15 * 60,
1355-
tags=None,
1358+
tags: Optional[Tags] = None,
13561359
target_platform_os=None,
13571360
target_platform_arch=None,
13581361
target_platform_accelerator=None,
@@ -1378,7 +1381,7 @@ def compile_model(
13781381
compile_max_run (int): Timeout in seconds for compilation (default:
13791382
15 * 60). After this amount of time Amazon SageMaker Neo
13801383
terminates the compilation job regardless of its current status.
1381-
tags (list[dict]): List of tags for labeling a compilation job. For
1384+
tags (list[dict]): Tags for labeling a compilation job. For
13821385
more, see
13831386
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
13841387
target_platform_os (str): Target Platform OS, for example: 'LINUX'.
@@ -1420,7 +1423,7 @@ def compile_model(
14201423
input_shape,
14211424
output_path,
14221425
self.role,
1423-
tags,
1426+
format_tags(tags),
14241427
self._compilation_job_name(),
14251428
compile_max_run,
14261429
framework=framework,
@@ -1532,7 +1535,7 @@ def deploy(
15321535
model_name=None,
15331536
kms_key=None,
15341537
data_capture_config=None,
1535-
tags=None,
1538+
tags: Optional[Tags] = None,
15361539
serverless_inference_config=None,
15371540
async_inference_config=None,
15381541
volume_size=None,
@@ -1601,8 +1604,10 @@ def deploy(
16011604
empty object passed through, will use pre-defined values in
16021605
``ServerlessInferenceConfig`` class to deploy serverless endpoint. Deploy an
16031606
instance based endpoint if it's None. (default: None)
1604-
tags(List[dict[str, str]]): Optional. The list of tags to attach to this specific
1607+
tags(Optional[Tags]): Optional. Tags to attach to this specific
16051608
endpoint. Example:
1609+
>>> tags = {'tagname', 'tagvalue'}
1610+
Or
16061611
>>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
16071612
For more information about tags, see
16081613
https://boto3.amazonaws.com/v1/documentation\
@@ -1664,7 +1669,7 @@ def deploy(
16641669
model.name = model_name
16651670

16661671
tags = update_inference_tags_with_jumpstart_training_tags(
1667-
inference_tags=tags, training_tags=self.tags
1672+
inference_tags=format_tags(tags), training_tags=self.tags
16681673
)
16691674

16701675
return model.deploy(
@@ -2017,7 +2022,7 @@ def transformer(
20172022
env=None,
20182023
max_concurrent_transforms=None,
20192024
max_payload=None,
2020-
tags=None,
2025+
tags: Optional[Tags] = None,
20212026
role=None,
20222027
volume_kms_key=None,
20232028
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
@@ -2051,7 +2056,7 @@ def transformer(
20512056
to be made to each individual transform container at one time.
20522057
max_payload (int): Maximum size of the payload in a single HTTP
20532058
request to the container in MB.
2054-
tags (list[dict]): List of tags for labeling a transform job. If
2059+
tags (Optional[Tags]): Tags for labeling a transform job. If
20552060
none specified, then the tags used for the training job are used
20562061
for the transform job.
20572062
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
@@ -2078,7 +2083,7 @@ def transformer(
20782083
model. If not specified, the estimator generates a default job name
20792084
based on the training image name and current timestamp.
20802085
"""
2081-
tags = tags or self.tags
2086+
tags = format_tags(tags) or self.tags
20822087
model_name = self._get_or_create_name(model_name)
20832088

20842089
if self.latest_training_job is None:
@@ -2717,7 +2722,7 @@ def __init__(
27172722
base_job_name: Optional[str] = None,
27182723
sagemaker_session: Optional[Session] = None,
27192724
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
2720-
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
2725+
tags: Optional[Tags] = None,
27212726
subnets: Optional[List[Union[str, PipelineVariable]]] = None,
27222727
security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
27232728
model_uri: Optional[str] = None,
@@ -2847,7 +2852,7 @@ def __init__(
28472852
hyperparameters. SageMaker rejects the training job request and returns an
28482853
validation error for detected credentials, if such user input is found.
28492854
2850-
tags (list[dict[str, str] or list[dict[str, PipelineVariable]]): List of tags for
2855+
tags (Optional[Tags]): Tags for
28512856
labeling a training job. For more, see
28522857
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
28532858
subnets (list[str] or list[PipelineVariable]): List of subnet ids.
@@ -3130,7 +3135,7 @@ def __init__(
31303135
output_kms_key,
31313136
base_job_name,
31323137
sagemaker_session,
3133-
tags,
3138+
format_tags(tags),
31343139
subnets,
31353140
security_group_ids,
31363141
model_uri=model_uri,
@@ -3762,7 +3767,7 @@ def transformer(
37623767
env=None,
37633768
max_concurrent_transforms=None,
37643769
max_payload=None,
3765-
tags=None,
3770+
tags: Optional[Tags] = None,
37663771
role=None,
37673772
model_server_workers=None,
37683773
volume_kms_key=None,
@@ -3798,7 +3803,7 @@ def transformer(
37983803
to be made to each individual transform container at one time.
37993804
max_payload (int): Maximum size of the payload in a single HTTP
38003805
request to the container in MB.
3801-
tags (list[dict]): List of tags for labeling a transform job. If
3806+
tags (Optional[Tags]): Tags for labeling a transform job. If
38023807
none specified, then the tags used for the training job are used
38033808
for the transform job.
38043809
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
@@ -3837,7 +3842,7 @@ def transformer(
38373842
SageMaker Batch Transform job.
38383843
"""
38393844
role = role or self.role
3840-
tags = tags or self.tags
3845+
tags = format_tags(tags) or self.tags
38413846
model_name = self._get_or_create_name(model_name)
38423847

38433848
if self.latest_training_job is not None:

0 commit comments

Comments
 (0)