98
98
to_string ,
99
99
check_and_get_run_experiment_config ,
100
100
resolve_value_from_config ,
101
+ format_tags ,
102
+ Tags ,
101
103
)
102
104
from sagemaker .workflow import is_pipeline_variable
103
105
from sagemaker .workflow .entities import PipelineVariable
@@ -144,7 +146,7 @@ def __init__(
144
146
output_kms_key : Optional [Union [str , PipelineVariable ]] = None ,
145
147
base_job_name : Optional [str ] = None ,
146
148
sagemaker_session : Optional [Session ] = None ,
147
- tags : Optional [List [ Dict [ str , Union [ str , PipelineVariable ]]] ] = None ,
149
+ tags : Optional [Tags ] = None ,
148
150
subnets : Optional [List [Union [str , PipelineVariable ]]] = None ,
149
151
security_group_ids : Optional [List [Union [str , PipelineVariable ]]] = None ,
150
152
model_uri : Optional [str ] = None ,
@@ -270,8 +272,8 @@ def __init__(
270
272
manages interactions with Amazon SageMaker APIs and any other
271
273
AWS services needed. If not specified, the estimator creates one
272
274
using the default AWS configuration chain.
273
- tags (list[dict[str, str] or list[dict[str, PipelineVariable] ]):
274
- List of tags for labeling a training job. For more, see
275
+ tags (Optional[Tags ]):
276
+ Tags for labeling a training job. For more, see
275
277
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
276
278
subnets (list[str] or list[PipelineVariable]): List of subnet ids. If not
277
279
specified training job will be created without VPC config.
@@ -604,6 +606,7 @@ def __init__(
604
606
else :
605
607
self .sagemaker_session = sagemaker_session or Session ()
606
608
609
+ tags = format_tags (tags )
607
610
self .tags = (
608
611
add_jumpstart_uri_tags (
609
612
tags = tags , training_model_uri = self .model_uri , training_script_uri = self .source_dir
@@ -1352,7 +1355,7 @@ def compile_model(
1352
1355
framework = None ,
1353
1356
framework_version = None ,
1354
1357
compile_max_run = 15 * 60 ,
1355
- tags = None ,
1358
+ tags : Optional [ Tags ] = None ,
1356
1359
target_platform_os = None ,
1357
1360
target_platform_arch = None ,
1358
1361
target_platform_accelerator = None ,
@@ -1378,7 +1381,7 @@ def compile_model(
1378
1381
compile_max_run (int): Timeout in seconds for compilation (default:
1379
1382
15 * 60). After this amount of time Amazon SageMaker Neo
1380
1383
terminates the compilation job regardless of its current status.
1381
- tags (list[dict]): List of tags for labeling a compilation job. For
1384
+ tags (list[dict]): Tags for labeling a compilation job. For
1382
1385
more, see
1383
1386
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
1384
1387
target_platform_os (str): Target Platform OS, for example: 'LINUX'.
@@ -1420,7 +1423,7 @@ def compile_model(
1420
1423
input_shape ,
1421
1424
output_path ,
1422
1425
self .role ,
1423
- tags ,
1426
+ format_tags ( tags ) ,
1424
1427
self ._compilation_job_name (),
1425
1428
compile_max_run ,
1426
1429
framework = framework ,
@@ -1532,7 +1535,7 @@ def deploy(
1532
1535
model_name = None ,
1533
1536
kms_key = None ,
1534
1537
data_capture_config = None ,
1535
- tags = None ,
1538
+ tags : Optional [ Tags ] = None ,
1536
1539
serverless_inference_config = None ,
1537
1540
async_inference_config = None ,
1538
1541
volume_size = None ,
@@ -1601,8 +1604,10 @@ def deploy(
1601
1604
empty object passed through, will use pre-defined values in
1602
1605
``ServerlessInferenceConfig`` class to deploy serverless endpoint. Deploy an
1603
1606
instance based endpoint if it's None. (default: None)
1604
- tags(List[dict[str, str]] ): Optional. The list of tags to attach to this specific
1607
+ tags(Optional[Tags] ): Optional. Tags to attach to this specific
1605
1608
endpoint. Example:
1609
+ >>> tags = {'tagname', 'tagvalue'}
1610
+ Or
1606
1611
>>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
1607
1612
For more information about tags, see
1608
1613
https://boto3.amazonaws.com/v1/documentation\
@@ -1664,7 +1669,7 @@ def deploy(
1664
1669
model .name = model_name
1665
1670
1666
1671
tags = update_inference_tags_with_jumpstart_training_tags (
1667
- inference_tags = tags , training_tags = self .tags
1672
+ inference_tags = format_tags ( tags ) , training_tags = self .tags
1668
1673
)
1669
1674
1670
1675
return model .deploy (
@@ -2017,7 +2022,7 @@ def transformer(
2017
2022
env = None ,
2018
2023
max_concurrent_transforms = None ,
2019
2024
max_payload = None ,
2020
- tags = None ,
2025
+ tags : Optional [ Tags ] = None ,
2021
2026
role = None ,
2022
2027
volume_kms_key = None ,
2023
2028
vpc_config_override = vpc_utils .VPC_CONFIG_DEFAULT ,
@@ -2051,7 +2056,7 @@ def transformer(
2051
2056
to be made to each individual transform container at one time.
2052
2057
max_payload (int): Maximum size of the payload in a single HTTP
2053
2058
request to the container in MB.
2054
- tags (list[dict ]): List of tags for labeling a transform job. If
2059
+ tags (Optional[Tags ]): Tags for labeling a transform job. If
2055
2060
none specified, then the tags used for the training job are used
2056
2061
for the transform job.
2057
2062
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
@@ -2078,7 +2083,7 @@ def transformer(
2078
2083
model. If not specified, the estimator generates a default job name
2079
2084
based on the training image name and current timestamp.
2080
2085
"""
2081
- tags = tags or self .tags
2086
+ tags = format_tags ( tags ) or self .tags
2082
2087
model_name = self ._get_or_create_name (model_name )
2083
2088
2084
2089
if self .latest_training_job is None :
@@ -2717,7 +2722,7 @@ def __init__(
2717
2722
base_job_name : Optional [str ] = None ,
2718
2723
sagemaker_session : Optional [Session ] = None ,
2719
2724
hyperparameters : Optional [Dict [str , Union [str , PipelineVariable ]]] = None ,
2720
- tags : Optional [List [ Dict [ str , Union [ str , PipelineVariable ]]] ] = None ,
2725
+ tags : Optional [Tags ] = None ,
2721
2726
subnets : Optional [List [Union [str , PipelineVariable ]]] = None ,
2722
2727
security_group_ids : Optional [List [Union [str , PipelineVariable ]]] = None ,
2723
2728
model_uri : Optional [str ] = None ,
@@ -2847,7 +2852,7 @@ def __init__(
2847
2852
hyperparameters. SageMaker rejects the training job request and returns an
2848
2853
validation error for detected credentials, if such user input is found.
2849
2854
2850
- tags (list[dict[str, str] or list[dict[str, PipelineVariable]] ): List of tags for
2855
+ tags (Optional[Tags] ): Tags for
2851
2856
labeling a training job. For more, see
2852
2857
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
2853
2858
subnets (list[str] or list[PipelineVariable]): List of subnet ids.
@@ -3130,7 +3135,7 @@ def __init__(
3130
3135
output_kms_key ,
3131
3136
base_job_name ,
3132
3137
sagemaker_session ,
3133
- tags ,
3138
+ format_tags ( tags ) ,
3134
3139
subnets ,
3135
3140
security_group_ids ,
3136
3141
model_uri = model_uri ,
@@ -3762,7 +3767,7 @@ def transformer(
3762
3767
env = None ,
3763
3768
max_concurrent_transforms = None ,
3764
3769
max_payload = None ,
3765
- tags = None ,
3770
+ tags : Optional [ Tags ] = None ,
3766
3771
role = None ,
3767
3772
model_server_workers = None ,
3768
3773
volume_kms_key = None ,
@@ -3798,7 +3803,7 @@ def transformer(
3798
3803
to be made to each individual transform container at one time.
3799
3804
max_payload (int): Maximum size of the payload in a single HTTP
3800
3805
request to the container in MB.
3801
- tags (list[dict ]): List of tags for labeling a transform job. If
3806
+ tags (Optional[Tags ]): Tags for labeling a transform job. If
3802
3807
none specified, then the tags used for the training job are used
3803
3808
for the transform job.
3804
3809
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
@@ -3837,7 +3842,7 @@ def transformer(
3837
3842
SageMaker Batch Transform job.
3838
3843
"""
3839
3844
role = role or self .role
3840
- tags = tags or self .tags
3845
+ tags = format_tags ( tags ) or self .tags
3841
3846
model_name = self ._get_or_create_name (model_name )
3842
3847
3843
3848
if self .latest_training_job is not None :
0 commit comments