98
98
to_string ,
99
99
check_and_get_run_experiment_config ,
100
100
resolve_value_from_config ,
101
+ format_tags ,
102
+ Tags ,
101
103
)
102
104
from sagemaker .workflow import is_pipeline_variable
103
105
from sagemaker .workflow .entities import PipelineVariable
@@ -144,7 +146,7 @@ def __init__(
144
146
output_kms_key : Optional [Union [str , PipelineVariable ]] = None ,
145
147
base_job_name : Optional [str ] = None ,
146
148
sagemaker_session : Optional [Session ] = None ,
147
- tags : Optional [List [ Dict [ str , Union [ str , PipelineVariable ]]] ] = None ,
149
+ tags : Optional [Tags ] = None ,
148
150
subnets : Optional [List [Union [str , PipelineVariable ]]] = None ,
149
151
security_group_ids : Optional [List [Union [str , PipelineVariable ]]] = None ,
150
152
model_uri : Optional [str ] = None ,
@@ -269,8 +271,8 @@ def __init__(
269
271
manages interactions with Amazon SageMaker APIs and any other
270
272
AWS services needed. If not specified, the estimator creates one
271
273
using the default AWS configuration chain.
272
- tags (list[dict[str, str] or list[dict[str, PipelineVariable] ]):
273
- List of tags for labeling a training job. For more, see
274
+ tags (Optional[Tags ]):
275
+ Tags for labeling a training job. For more, see
274
276
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
275
277
subnets (list[str] or list[PipelineVariable]): List of subnet ids. If not
276
278
specified training job will be created without VPC config.
@@ -601,6 +603,7 @@ def __init__(
601
603
else :
602
604
self .sagemaker_session = sagemaker_session or Session ()
603
605
606
+ tags = format_tags (tags )
604
607
self .tags = (
605
608
add_jumpstart_uri_tags (
606
609
tags = tags , training_model_uri = self .model_uri , training_script_uri = self .source_dir
@@ -1347,7 +1350,7 @@ def compile_model(
1347
1350
framework = None ,
1348
1351
framework_version = None ,
1349
1352
compile_max_run = 15 * 60 ,
1350
- tags = None ,
1353
+ tags : Optional [ Tags ] = None ,
1351
1354
target_platform_os = None ,
1352
1355
target_platform_arch = None ,
1353
1356
target_platform_accelerator = None ,
@@ -1373,7 +1376,7 @@ def compile_model(
1373
1376
compile_max_run (int): Timeout in seconds for compilation (default:
1374
1377
15 * 60). After this amount of time Amazon SageMaker Neo
1375
1378
terminates the compilation job regardless of its current status.
1376
- tags (list[dict]): List of tags for labeling a compilation job. For
1379
+ tags (list[dict]): Tags for labeling a compilation job. For
1377
1380
more, see
1378
1381
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
1379
1382
target_platform_os (str): Target Platform OS, for example: 'LINUX'.
@@ -1415,7 +1418,7 @@ def compile_model(
1415
1418
input_shape ,
1416
1419
output_path ,
1417
1420
self .role ,
1418
- tags ,
1421
+ format_tags ( tags ) ,
1419
1422
self ._compilation_job_name (),
1420
1423
compile_max_run ,
1421
1424
framework = framework ,
@@ -1527,7 +1530,7 @@ def deploy(
1527
1530
model_name = None ,
1528
1531
kms_key = None ,
1529
1532
data_capture_config = None ,
1530
- tags = None ,
1533
+ tags : Optional [ Tags ] = None ,
1531
1534
serverless_inference_config = None ,
1532
1535
async_inference_config = None ,
1533
1536
volume_size = None ,
@@ -1596,8 +1599,10 @@ def deploy(
1596
1599
empty object passed through, will use pre-defined values in
1597
1600
``ServerlessInferenceConfig`` class to deploy serverless endpoint. Deploy an
1598
1601
instance based endpoint if it's None. (default: None)
1599
- tags(List[dict[str, str]] ): Optional. The list of tags to attach to this specific
1602
+ tags(Optional[Tags] ): Optional. Tags to attach to this specific
1600
1603
endpoint. Example:
1604
+ >>> tags = {'tagname', 'tagvalue'}
1605
+ Or
1601
1606
>>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
1602
1607
For more information about tags, see
1603
1608
https://boto3.amazonaws.com/v1/documentation\
@@ -1659,7 +1664,7 @@ def deploy(
1659
1664
model .name = model_name
1660
1665
1661
1666
tags = update_inference_tags_with_jumpstart_training_tags (
1662
- inference_tags = tags , training_tags = self .tags
1667
+ inference_tags = format_tags ( tags ) , training_tags = self .tags
1663
1668
)
1664
1669
1665
1670
return model .deploy (
@@ -2007,7 +2012,7 @@ def transformer(
2007
2012
env = None ,
2008
2013
max_concurrent_transforms = None ,
2009
2014
max_payload = None ,
2010
- tags = None ,
2015
+ tags : Optional [ Tags ] = None ,
2011
2016
role = None ,
2012
2017
volume_kms_key = None ,
2013
2018
vpc_config_override = vpc_utils .VPC_CONFIG_DEFAULT ,
@@ -2041,7 +2046,7 @@ def transformer(
2041
2046
to be made to each individual transform container at one time.
2042
2047
max_payload (int): Maximum size of the payload in a single HTTP
2043
2048
request to the container in MB.
2044
- tags (list[dict ]): List of tags for labeling a transform job. If
2049
+ tags (Optional[Tags ]): Tags for labeling a transform job. If
2045
2050
none specified, then the tags used for the training job are used
2046
2051
for the transform job.
2047
2052
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
@@ -2068,7 +2073,7 @@ def transformer(
2068
2073
model. If not specified, the estimator generates a default job name
2069
2074
based on the training image name and current timestamp.
2070
2075
"""
2071
- tags = tags or self .tags
2076
+ tags = format_tags ( tags ) or self .tags
2072
2077
model_name = self ._get_or_create_name (model_name )
2073
2078
2074
2079
if self .latest_training_job is None :
@@ -2661,7 +2666,7 @@ def __init__(
2661
2666
base_job_name : Optional [str ] = None ,
2662
2667
sagemaker_session : Optional [Session ] = None ,
2663
2668
hyperparameters : Optional [Dict [str , Union [str , PipelineVariable ]]] = None ,
2664
- tags : Optional [List [ Dict [ str , Union [ str , PipelineVariable ]]] ] = None ,
2669
+ tags : Optional [Tags ] = None ,
2665
2670
subnets : Optional [List [Union [str , PipelineVariable ]]] = None ,
2666
2671
security_group_ids : Optional [List [Union [str , PipelineVariable ]]] = None ,
2667
2672
model_uri : Optional [str ] = None ,
@@ -2790,7 +2795,7 @@ def __init__(
2790
2795
hyperparameters. SageMaker rejects the training job request and returns an
2791
2796
validation error for detected credentials, if such user input is found.
2792
2797
2793
- tags (list[dict[str, str] or list[dict[str, PipelineVariable]] ): List of tags for
2798
+ tags (Optional[Tags] ): Tags for
2794
2799
labeling a training job. For more, see
2795
2800
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
2796
2801
subnets (list[str] or list[PipelineVariable]): List of subnet ids.
@@ -3071,7 +3076,7 @@ def __init__(
3071
3076
output_kms_key ,
3072
3077
base_job_name ,
3073
3078
sagemaker_session ,
3074
- tags ,
3079
+ format_tags ( tags ) ,
3075
3080
subnets ,
3076
3081
security_group_ids ,
3077
3082
model_uri = model_uri ,
@@ -3702,7 +3707,7 @@ def transformer(
3702
3707
env = None ,
3703
3708
max_concurrent_transforms = None ,
3704
3709
max_payload = None ,
3705
- tags = None ,
3710
+ tags : Optional [ Tags ] = None ,
3706
3711
role = None ,
3707
3712
model_server_workers = None ,
3708
3713
volume_kms_key = None ,
@@ -3738,7 +3743,7 @@ def transformer(
3738
3743
to be made to each individual transform container at one time.
3739
3744
max_payload (int): Maximum size of the payload in a single HTTP
3740
3745
request to the container in MB.
3741
- tags (list[dict ]): List of tags for labeling a transform job. If
3746
+ tags (Optional[Tags ]): Tags for labeling a transform job. If
3742
3747
none specified, then the tags used for the training job are used
3743
3748
for the transform job.
3744
3749
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
@@ -3777,7 +3782,7 @@ def transformer(
3777
3782
SageMaker Batch Transform job.
3778
3783
"""
3779
3784
role = role or self .role
3780
- tags = tags or self .tags
3785
+ tags = format_tags ( tags ) or self .tags
3781
3786
model_name = self ._get_or_create_name (model_name )
3782
3787
3783
3788
if self .latest_training_job is not None :
0 commit comments