diff --git a/src/sagemaker/algorithm.py b/src/sagemaker/algorithm.py index f4124fff2a..a177b93f03 100644 --- a/src/sagemaker/algorithm.py +++ b/src/sagemaker/algorithm.py @@ -28,6 +28,7 @@ from sagemaker.session import Session from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.pipeline_context import runnable_by_pipeline +from sagemaker.utils import format_tags, Tags from sagemaker.workflow import is_pipeline_variable @@ -58,7 +59,7 @@ def __init__( base_job_name: Optional[str] = None, sagemaker_session: Optional[Session] = None, hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, - tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + tags: Optional[Tags] = None, subnets: Optional[List[Union[str, PipelineVariable]]] = None, security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None, model_uri: Optional[str] = None, @@ -121,7 +122,7 @@ def __init__( interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain. - tags (list[dict[str, str] or list[dict[str, PipelineVariable]]): List of tags for + tags (Union[Tags]): Tags for labeling a training job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. subnets (list[str] or list[PipelineVariable]): List of subnet ids. If not specified @@ -170,7 +171,7 @@ def __init__( output_kms_key=output_kms_key, base_job_name=base_job_name, sagemaker_session=sagemaker_session, - tags=tags, + tags=format_tags(tags), subnets=subnets, security_group_ids=security_group_ids, model_uri=model_uri, @@ -391,7 +392,7 @@ def transformer( if self._is_marketplace(): transform_env = None - tags = tags or self.tags + tags = format_tags(tags) or self.tags else: raise RuntimeError("No finished training job found associated with this estimator") diff --git a/src/sagemaker/apiutils/_base_types.py b/src/sagemaker/apiutils/_base_types.py index 9a7359e12b..acee3d4d67 100644 --- a/src/sagemaker/apiutils/_base_types.py +++ b/src/sagemaker/apiutils/_base_types.py @@ -14,6 +14,7 @@ from __future__ import absolute_import from sagemaker.apiutils import _boto_functions, _utils +from sagemaker.utils import format_tags class ApiObject(object): @@ -194,13 +195,13 @@ def _set_tags(self, resource_arn=None, tags=None): Args: resource_arn (str): The arn of the Record - tags (dict): An array of Tag objects that set to Record + tags (Optional[Tags]): An array of Tag objects that set to Record Returns: A list of key, value pair objects. i.e. [{"key":"value"}] """ tag_list = self.sagemaker_session.sagemaker_client.add_tags( - ResourceArn=resource_arn, Tags=tags + ResourceArn=resource_arn, Tags=format_tags(tags) )["Tags"] return tag_list diff --git a/src/sagemaker/automl/automl.py b/src/sagemaker/automl/automl.py index ce71d50977..1413f3aa29 100644 --- a/src/sagemaker/automl/automl.py +++ b/src/sagemaker/automl/automl.py @@ -28,7 +28,7 @@ ) from sagemaker.job import _Job from sagemaker.session import Session -from sagemaker.utils import name_from_base, resolve_value_from_config +from sagemaker.utils import name_from_base, resolve_value_from_config, format_tags, Tags from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.pipeline_context import runnable_by_pipeline @@ -127,7 +127,7 @@ def __init__( total_job_runtime_in_seconds: Optional[int] = None, job_objective: Optional[Dict[str, str]] = None, generate_candidate_definitions_only: Optional[bool] = False, - tags: Optional[List[Dict[str, str]]] = None, + tags: Optional[Tags] = None, content_type: Optional[str] = None, s3_data_type: Optional[str] = None, feature_specification_s3_uri: Optional[str] = None, @@ -167,8 +167,7 @@ def __init__( In the format of: {"MetricName": str} generate_candidate_definitions_only (bool): Whether to generates possible candidates without training the models. - tags (List[dict[str, str]]): The list of tags to attach to this - specific endpoint. + tags (Optional[Tags]): Tags to attach to this specific endpoint. content_type (str): The content type of the data from the input source. s3_data_type (str): The data type for S3 data source. Valid values: ManifestFile or S3Prefix. @@ -203,7 +202,7 @@ def __init__( self.target_attribute_name = target_attribute_name self.job_objective = job_objective self.generate_candidate_definitions_only = generate_candidate_definitions_only - self.tags = tags + self.tags = format_tags(tags) self.content_type = content_type self.s3_data_type = s3_data_type self.feature_specification_s3_uri = feature_specification_s3_uri @@ -581,7 +580,7 @@ def deploy( be selected on each ``deploy``. endpoint_name (str): The name of the endpoint to create (default: None). If not specified, a unique endpoint name will be created. - tags (List[dict[str, str]]): The list of tags to attach to this + tags (Optional[Tags]): The list of tags to attach to this specific endpoint. wait (bool): Whether the call should wait until the deployment of model completes (default: True). @@ -633,7 +632,7 @@ def deploy( deserializer=deserializer, endpoint_name=endpoint_name, kms_key=model_kms_key, - tags=tags, + tags=format_tags(tags), wait=wait, volume_size=volume_size, model_data_download_timeout=model_data_download_timeout, diff --git a/src/sagemaker/base_predictor.py b/src/sagemaker/base_predictor.py index 99ef6ef55f..882cfafc39 100644 --- a/src/sagemaker/base_predictor.py +++ b/src/sagemaker/base_predictor.py @@ -53,7 +53,7 @@ NumpySerializer, ) from sagemaker.session import production_variant, Session -from sagemaker.utils import name_from_base, stringify_object +from sagemaker.utils import name_from_base, stringify_object, format_tags from sagemaker.model_monitor.model_monitoring import DEFAULT_REPOSITORY_NAME @@ -409,7 +409,7 @@ def update_endpoint( self.sagemaker_session.create_endpoint_config_from_existing( current_endpoint_config_name, new_endpoint_config_name, - new_tags=tags, + new_tags=format_tags(tags), new_kms_key=kms_key, new_data_capture_config_dict=data_capture_config_dict, new_production_variants=production_variants, diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 9421d0e419..11bc43c43a 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -33,6 +33,7 @@ from sagemaker.session import Session from sagemaker.network import NetworkConfig from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor +from sagemaker.utils import format_tags, Tags logger = logging.getLogger(__name__) @@ -1417,7 +1418,7 @@ def __init__( max_runtime_in_seconds: Optional[int] = None, sagemaker_session: Optional[Session] = None, env: Optional[Dict[str, str]] = None, - tags: Optional[List[Dict[str, str]]] = None, + tags: Optional[Tags] = None, network_config: Optional[NetworkConfig] = None, job_name_prefix: Optional[str] = None, version: Optional[str] = None, @@ -1454,7 +1455,7 @@ def __init__( using the default AWS configuration chain. env (dict[str, str]): Environment variables to be passed to the processing jobs (default: None). - tags (list[dict]): List of tags to be passed to the processing job + tags (Optional[Tags]): Tags to be passed to the processing job (default: None). For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. network_config (:class:`~sagemaker.network.NetworkConfig`): @@ -1482,7 +1483,7 @@ def __init__( None, # We set method-specific job names below. sagemaker_session, env, - tags, + format_tags(tags), network_config, ) diff --git a/src/sagemaker/djl_inference/model.py b/src/sagemaker/djl_inference/model.py index 118a4af5a0..8308215e81 100644 --- a/src/sagemaker/djl_inference/model.py +++ b/src/sagemaker/djl_inference/model.py @@ -30,7 +30,7 @@ from sagemaker.s3_utils import s3_path_join from sagemaker.serializers import JSONSerializer, BaseSerializer from sagemaker.session import Session -from sagemaker.utils import _tmpdir, _create_or_update_code_dir +from sagemaker.utils import _tmpdir, _create_or_update_code_dir, format_tags from sagemaker.workflow.entities import PipelineVariable from sagemaker.estimator import Estimator from sagemaker.s3 import S3Uploader @@ -610,7 +610,7 @@ def deploy( default deserializer is set by the ``predictor_cls``. endpoint_name (str): The name of the endpoint to create (default: None). If not specified, a unique endpoint name will be created. - tags (List[dict[str, str]]): The list of tags to attach to this + tags (Optional[Tags]): The list of tags to attach to this specific endpoint. kms_key (str): The ARN of the KMS key that is used to encrypt the data on the storage volume attached to the instance hosting the @@ -651,7 +651,7 @@ def deploy( serializer=serializer, deserializer=deserializer, endpoint_name=endpoint_name, - tags=tags, + tags=format_tags(tags), kms_key=kms_key, wait=wait, data_capture_config=data_capture_config, diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 551a42ad55..f899570775 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -98,6 +98,8 @@ to_string, check_and_get_run_experiment_config, resolve_value_from_config, + format_tags, + Tags, ) from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import PipelineVariable @@ -144,7 +146,7 @@ def __init__( output_kms_key: Optional[Union[str, PipelineVariable]] = None, base_job_name: Optional[str] = None, sagemaker_session: Optional[Session] = None, - tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + tags: Optional[Tags] = None, subnets: Optional[List[Union[str, PipelineVariable]]] = None, security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None, model_uri: Optional[str] = None, @@ -270,8 +272,8 @@ def __init__( manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain. - tags (list[dict[str, str] or list[dict[str, PipelineVariable]]): - List of tags for labeling a training job. For more, see + tags (Optional[Tags]): + Tags for labeling a training job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. subnets (list[str] or list[PipelineVariable]): List of subnet ids. If not specified training job will be created without VPC config. @@ -604,6 +606,7 @@ def __init__( else: self.sagemaker_session = sagemaker_session or Session() + tags = format_tags(tags) self.tags = ( add_jumpstart_uri_tags( tags=tags, training_model_uri=self.model_uri, training_script_uri=self.source_dir @@ -1352,7 +1355,7 @@ def compile_model( framework=None, framework_version=None, compile_max_run=15 * 60, - tags=None, + tags: Optional[Tags] = None, target_platform_os=None, target_platform_arch=None, target_platform_accelerator=None, @@ -1378,7 +1381,7 @@ def compile_model( compile_max_run (int): Timeout in seconds for compilation (default: 15 * 60). After this amount of time Amazon SageMaker Neo terminates the compilation job regardless of its current status. - tags (list[dict]): List of tags for labeling a compilation job. For + tags (list[dict]): Tags for labeling a compilation job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. target_platform_os (str): Target Platform OS, for example: 'LINUX'. @@ -1420,7 +1423,7 @@ def compile_model( input_shape, output_path, self.role, - tags, + format_tags(tags), self._compilation_job_name(), compile_max_run, framework=framework, @@ -1532,7 +1535,7 @@ def deploy( model_name=None, kms_key=None, data_capture_config=None, - tags=None, + tags: Optional[Tags] = None, serverless_inference_config=None, async_inference_config=None, volume_size=None, @@ -1601,8 +1604,10 @@ def deploy( empty object passed through, will use pre-defined values in ``ServerlessInferenceConfig`` class to deploy serverless endpoint. Deploy an instance based endpoint if it's None. (default: None) - tags(List[dict[str, str]]): Optional. The list of tags to attach to this specific + tags(Optional[Tags]): Optional. Tags to attach to this specific endpoint. Example: + >>> tags = {'tagname', 'tagvalue'} + Or >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}] For more information about tags, see https://boto3.amazonaws.com/v1/documentation\ @@ -1664,7 +1669,7 @@ def deploy( model.name = model_name tags = update_inference_tags_with_jumpstart_training_tags( - inference_tags=tags, training_tags=self.tags + inference_tags=format_tags(tags), training_tags=self.tags ) return model.deploy( @@ -2017,7 +2022,7 @@ def transformer( env=None, max_concurrent_transforms=None, max_payload=None, - tags=None, + tags: Optional[Tags] = None, role=None, volume_kms_key=None, vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT, @@ -2051,7 +2056,7 @@ def transformer( to be made to each individual transform container at one time. max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB. - tags (list[dict]): List of tags for labeling a transform job. If + tags (Optional[Tags]): Tags for labeling a transform job. If none specified, then the tags used for the training job are used for the transform job. role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, @@ -2078,7 +2083,7 @@ def transformer( model. If not specified, the estimator generates a default job name based on the training image name and current timestamp. """ - tags = tags or self.tags + tags = format_tags(tags) or self.tags model_name = self._get_or_create_name(model_name) if self.latest_training_job is None: @@ -2717,7 +2722,7 @@ def __init__( base_job_name: Optional[str] = None, sagemaker_session: Optional[Session] = None, hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, - tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + tags: Optional[Tags] = None, subnets: Optional[List[Union[str, PipelineVariable]]] = None, security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None, model_uri: Optional[str] = None, @@ -2847,7 +2852,7 @@ def __init__( hyperparameters. SageMaker rejects the training job request and returns an validation error for detected credentials, if such user input is found. - tags (list[dict[str, str] or list[dict[str, PipelineVariable]]): List of tags for + tags (Optional[Tags]): Tags for labeling a training job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. subnets (list[str] or list[PipelineVariable]): List of subnet ids. @@ -3130,7 +3135,7 @@ def __init__( output_kms_key, base_job_name, sagemaker_session, - tags, + format_tags(tags), subnets, security_group_ids, model_uri=model_uri, @@ -3762,7 +3767,7 @@ def transformer( env=None, max_concurrent_transforms=None, max_payload=None, - tags=None, + tags: Optional[Tags] = None, role=None, model_server_workers=None, volume_kms_key=None, @@ -3798,7 +3803,7 @@ def transformer( to be made to each individual transform container at one time. max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB. - tags (list[dict]): List of tags for labeling a transform job. If + tags (Optional[Tags]): Tags for labeling a transform job. If none specified, then the tags used for the training job are used for the transform job. role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, @@ -3837,7 +3842,7 @@ def transformer( SageMaker Batch Transform job. """ role = role or self.role - tags = tags or self.tags + tags = format_tags(tags) or self.tags model_name = self._get_or_create_name(model_name) if self.latest_training_job is not None: diff --git a/src/sagemaker/experiments/experiment.py b/src/sagemaker/experiments/experiment.py index 584fbed27e..6f33fafb0f 100644 --- a/src/sagemaker/experiments/experiment.py +++ b/src/sagemaker/experiments/experiment.py @@ -20,6 +20,7 @@ from sagemaker.apiutils import _base_types from sagemaker.experiments.trial import _Trial from sagemaker.experiments.trial_component import _TrialComponent +from sagemaker.utils import format_tags class Experiment(_base_types.Record): @@ -111,7 +112,7 @@ def create( manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, one is created using the default AWS configuration chain. - tags (List[Dict[str, str]]): A list of tags to associate with the experiment + tags (Optional[Tags]): A list of tags to associate with the experiment (default: None). Returns: @@ -122,7 +123,7 @@ def create( experiment_name=experiment_name, display_name=display_name, description=description, - tags=tags, + tags=format_tags(tags), sagemaker_session=sagemaker_session, ) @@ -149,7 +150,7 @@ def _load_or_create( manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, one is created using the default AWS configuration chain. - tags (List[Dict[str, str]]): A list of tags to associate with the experiment + tags (Optional[Tags]): A list of tags to associate with the experiment (default: None). This is used only when the given `experiment_name` does not exist and a new experiment has to be created. @@ -161,7 +162,7 @@ def _load_or_create( experiment_name=experiment_name, display_name=display_name, description=description, - tags=tags, + tags=format_tags(tags), sagemaker_session=sagemaker_session, ) except ClientError as ce: diff --git a/src/sagemaker/experiments/run.py b/src/sagemaker/experiments/run.py index bfef1191c3..6068880844 100644 --- a/src/sagemaker/experiments/run.py +++ b/src/sagemaker/experiments/run.py @@ -44,6 +44,9 @@ from sagemaker.utils import ( get_module, unique_name_from_base, + format_tags, + Tags, + TagsDict, ) from sagemaker.experiments._utils import ( @@ -97,7 +100,7 @@ def __init__( run_name: Optional[str] = None, experiment_display_name: Optional[str] = None, run_display_name: Optional[str] = None, - tags: Optional[List[Dict[str, str]]] = None, + tags: Optional[Tags] = None, sagemaker_session: Optional["Session"] = None, artifact_bucket: Optional[str] = None, artifact_prefix: Optional[str] = None, @@ -152,7 +155,7 @@ def __init__( run_display_name (str): The display name of the run used in UI (default: None). This display name is used in a create run call. If a run with the specified name already exists, this display name won't take effect. - tags (List[Dict[str, str]]): A list of tags to be used for all create calls, + tags (Optional[Tags]): Tags to be used for all create calls, e.g. to create an experiment, a run group, etc. (default: None). sagemaker_session (sagemaker.session.Session): Session object which manages interactions with Amazon SageMaker APIs and any other @@ -172,6 +175,8 @@ def __init__( # avoid confusion due to mis-match in casing between run name and TC name self.run_name = self.run_name.lower() + tags = format_tags(tags) + trial_component_name = Run._generate_trial_component_name( run_name=self.run_name, experiment_name=self.experiment_name ) @@ -676,11 +681,11 @@ def _extract_run_name_from_tc_name(trial_component_name: str, experiment_name: s ) @staticmethod - def _append_run_tc_label_to_tags(tags: Optional[List[Dict[str, str]]] = None) -> list: + def _append_run_tc_label_to_tags(tags: Optional[List[TagsDict]] = None) -> list: """Append the run trial component label to tags used to create a trial component. Args: - tags (List[Dict[str, str]]): The tags supplied by users to initialize a Run object. + tags (List[TagsDict]): The tags supplied by users to initialize a Run object. Returns: list: The updated tags with the appended run trial component label. diff --git a/src/sagemaker/experiments/trial.py b/src/sagemaker/experiments/trial.py index ce8deb4862..466ba39158 100644 --- a/src/sagemaker/experiments/trial.py +++ b/src/sagemaker/experiments/trial.py @@ -18,6 +18,7 @@ from sagemaker.apiutils import _base_types from sagemaker.experiments import _api_types from sagemaker.experiments.trial_component import _TrialComponent +from sagemaker.utils import format_tags class _Trial(_base_types.Record): @@ -101,7 +102,7 @@ def create( trial_name: (str): Name of the Trial. display_name (str): Name of the trial that will appear in UI, such as SageMaker Studio (default: None). - tags (List[dict]): A list of tags to associate with the trial (default: None). + tags (Optional[Tags]): A list of tags to associate with the trial (default: None). sagemaker_session (sagemaker.session.Session): Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, one is created using the @@ -115,7 +116,7 @@ def create( trial_name=trial_name, experiment_name=experiment_name, display_name=display_name, - tags=tags, + tags=format_tags(tags), sagemaker_session=sagemaker_session, ) return trial @@ -259,7 +260,7 @@ def _load_or_create( display_name (str): Name of the trial that will appear in UI, such as SageMaker Studio (default: None). This is used only when the given `trial_name` does not exist and a new trial has to be created. - tags (List[dict]): A list of tags to associate with the trial (default: None). + tags (Optional[Tags]): A list of tags to associate with the trial (default: None). This is used only when the given `trial_name` does not exist and a new trial has to be created. sagemaker_session (sagemaker.session.Session): Session object which @@ -275,7 +276,7 @@ def _load_or_create( experiment_name=experiment_name, trial_name=trial_name, display_name=display_name, - tags=tags, + tags=format_tags(tags), sagemaker_session=sagemaker_session, ) except ClientError as ce: diff --git a/src/sagemaker/experiments/trial_component.py b/src/sagemaker/experiments/trial_component.py index 061948a9d2..bdd5cd0634 100644 --- a/src/sagemaker/experiments/trial_component.py +++ b/src/sagemaker/experiments/trial_component.py @@ -20,6 +20,7 @@ from sagemaker.apiutils import _base_types from sagemaker.experiments import _api_types from sagemaker.experiments._api_types import TrialComponentSearchResult +from sagemaker.utils import format_tags class _TrialComponent(_base_types.Record): @@ -191,7 +192,7 @@ def create(cls, trial_component_name, display_name=None, tags=None, sagemaker_se Args: trial_component_name (str): The name of the trial component. display_name (str): Display name of the trial component used by Studio (default: None). - tags (List[Dict[str, str]]): Tags to add to the trial component (default: None). + tags (Optional[Tags]): Tags to add to the trial component (default: None). sagemaker_session (sagemaker.session.Session): Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, one is created using the @@ -204,7 +205,7 @@ def create(cls, trial_component_name, display_name=None, tags=None, sagemaker_se cls._boto_create_method, trial_component_name=trial_component_name, display_name=display_name, - tags=tags, + tags=format_tags(tags), sagemaker_session=sagemaker_session, ) @@ -316,7 +317,7 @@ def _load_or_create( display_name (str): Display name of the trial component used by Studio (default: None). This is used only when the given `trial_component_name` does not exist and a new trial component has to be created. - tags (List[Dict[str, str]]): Tags to add to the trial component (default: None). + tags (Optional[Tags]): Tags to add to the trial component (default: None). This is used only when the given `trial_component_name` does not exist and a new trial component has to be created. sagemaker_session (sagemaker.session.Session): Session object which @@ -333,7 +334,7 @@ def _load_or_create( run_tc = _TrialComponent.create( trial_component_name=trial_component_name, display_name=display_name, - tags=tags, + tags=format_tags(tags), sagemaker_session=sagemaker_session, ) except ClientError as ce: diff --git a/src/sagemaker/feature_store/feature_group.py b/src/sagemaker/feature_store/feature_group.py index 977fc302e0..0e503e192d 100644 --- a/src/sagemaker/feature_store/feature_group.py +++ b/src/sagemaker/feature_store/feature_group.py @@ -28,7 +28,7 @@ import tempfile from concurrent.futures import as_completed from concurrent.futures import ThreadPoolExecutor -from typing import Sequence, List, Dict, Any, Union +from typing import Optional, Sequence, List, Dict, Any, Union from urllib.parse import urlparse from multiprocessing.pool import AsyncResult @@ -65,7 +65,7 @@ OnlineStoreConfigUpdate, OnlineStoreStorageTypeEnum, ) -from sagemaker.utils import resolve_value_from_config +from sagemaker.utils import resolve_value_from_config, format_tags, Tags logger = logging.getLogger(__name__) @@ -538,7 +538,7 @@ def create( disable_glue_table_creation: bool = False, data_catalog_config: DataCatalogConfig = None, description: str = None, - tags: List[Dict[str, str]] = None, + tags: Optional[Tags] = None, table_format: TableFormatEnum = None, online_store_storage_type: OnlineStoreStorageTypeEnum = None, ) -> Dict[str, Any]: @@ -566,7 +566,7 @@ def create( data_catalog_config (DataCatalogConfig): configuration for Metadata store (default: None). description (str): description of the FeatureGroup (default: None). - tags (List[Dict[str, str]]): list of tags for labeling a FeatureGroup (default: None). + tags (Optional[Tags]): Tags for labeling a FeatureGroup (default: None). table_format (TableFormatEnum): format of the offline store table (default: None). online_store_storage_type (OnlineStoreStorageTypeEnum): storage type for the online store (default: None). @@ -602,7 +602,7 @@ def create( ], role_arn=role_arn, description=description, - tags=tags, + tags=format_tags(tags), ) # online store configuration diff --git a/src/sagemaker/feature_store/feature_processor/_event_bridge_rule_helper.py b/src/sagemaker/feature_store/feature_processor/_event_bridge_rule_helper.py index 8f47a2e712..d47a37f5cb 100644 --- a/src/sagemaker/feature_store/feature_processor/_event_bridge_rule_helper.py +++ b/src/sagemaker/feature_store/feature_processor/_event_bridge_rule_helper.py @@ -32,6 +32,7 @@ from sagemaker.feature_store.feature_processor._enums import ( FeatureProcessorPipelineExecutionStatus, ) +from sagemaker.utils import TagsDict logger = logging.getLogger("sagemaker") @@ -175,7 +176,7 @@ def disable_rule(self, rule_name: str) -> None: self.event_bridge_rule_client.disable_rule(Name=rule_name) logger.info("Disabled EventBridge Rule for pipeline %s.", rule_name) - def add_tags(self, rule_arn: str, tags: List[Dict[str, str]]) -> None: + def add_tags(self, rule_arn: str, tags: List[TagsDict]) -> None: """Adds tags to the EventBridge Rule. Args: diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index da294c89e2..efe6a85288 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -29,7 +29,7 @@ from sagemaker.predictor import Predictor from sagemaker.serializers import JSONSerializer from sagemaker.session import Session -from sagemaker.utils import to_string +from sagemaker.utils import to_string, format_tags from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import PipelineVariable @@ -255,7 +255,7 @@ def deploy( https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html endpoint_name (str): The name of the endpoint to create (default: None). If not specified, a unique endpoint name will be created. - tags (List[dict[str, str]]): The list of tags to attach to this + tags (Optional[Tags]): The list of tags to attach to this specific endpoint. kms_key (str): The ARN of the KMS key that is used to encrypt the data on the storage volume attached to the instance hosting the @@ -319,7 +319,7 @@ def deploy( deserializer, accelerator_type, endpoint_name, - tags, + format_tags(tags), kms_key, wait, data_capture_config, diff --git a/src/sagemaker/huggingface/processing.py b/src/sagemaker/huggingface/processing.py index 332148891f..b8721928f0 100644 --- a/src/sagemaker/huggingface/processing.py +++ b/src/sagemaker/huggingface/processing.py @@ -25,6 +25,7 @@ from sagemaker.huggingface.estimator import HuggingFace from sagemaker.workflow.entities import PipelineVariable +from sagemaker.utils import format_tags, Tags class HuggingFaceProcessor(FrameworkProcessor): @@ -51,7 +52,7 @@ def __init__( base_job_name: Optional[str] = None, sagemaker_session: Optional[Session] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, - tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + tags: Optional[Tags] = None, network_config: Optional[NetworkConfig] = None, ): """This processor executes a Python script in a HuggingFace execution environment. @@ -101,7 +102,7 @@ def __init__( base_job_name, sagemaker_session, env, - tags, + format_tags(tags), network_config, ) diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index e6047e9009..36a188ed55 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -37,7 +37,7 @@ is_valid_model_id, resolve_model_sagemaker_config_field, ) -from sagemaker.utils import stringify_object +from sagemaker.utils import stringify_object, format_tags, Tags from sagemaker.model_monitor.data_capture_config import DataCaptureConfig from sagemaker.predictor import PredictorBase @@ -73,7 +73,7 @@ def __init__( base_job_name: Optional[str] = None, sagemaker_session: Optional[session.Session] = None, hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, - tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + tags: Optional[Tags] = None, subnets: Optional[List[Union[str, PipelineVariable]]] = None, security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None, model_uri: Optional[str] = None, @@ -225,8 +225,8 @@ def __init__( validation error for detected credentials, if such user input is found. (Default: None). - tags (Optional[Union[list[dict[str, str], list[dict[str, PipelineVariable]]]]): - List of tags for labeling a training job. For more, see + tags (Optional[Tags]): + Tags for labeling a training job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. (Default: None). subnets (Optional[Union[list[str], list[PipelineVariable]]]): List of subnet ids. @@ -535,7 +535,7 @@ def _is_valid_model_id_hook(): output_kms_key=output_kms_key, base_job_name=base_job_name, sagemaker_session=sagemaker_session, - tags=tags, + tags=format_tags(tags), subnets=subnets, security_group_ids=security_group_ids, model_uri=model_uri, @@ -728,7 +728,7 @@ def deploy( deserializer: Optional[BaseDeserializer] = None, accelerator_type: Optional[str] = None, endpoint_name: Optional[str] = None, - tags: List[Dict[str, str]] = None, + tags: Optional[Tags] = None, kms_key: Optional[str] = None, wait: Optional[bool] = True, data_capture_config: Optional[DataCaptureConfig] = None, @@ -794,7 +794,7 @@ def deploy( endpoint_name (Optional[str]): The name of the endpoint to create (default: None). If not specified, a unique endpoint name will be created. (Default: None). - tags (Optional[List[dict[str, str]]]): The list of tags to attach to this + tags (Optional[Tags]): Tags to attach to this specific endpoint. (Default: None). kms_key (Optional[str]): The ARN of the KMS key that is used to encrypt the data on the storage volume attached to the instance hosting the @@ -1014,7 +1014,7 @@ def deploy( deserializer=deserializer, accelerator_type=accelerator_type, endpoint_name=endpoint_name, - tags=tags, + tags=format_tags(tags), kms_key=kms_key, wait=wait, data_capture_config=data_capture_config, diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 7479c23832..7ccf57983b 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -70,7 +70,7 @@ from sagemaker.model_monitor.data_capture_config import DataCaptureConfig from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig -from sagemaker.utils import name_from_base +from sagemaker.utils import name_from_base, format_tags, Tags from sagemaker.workflow.entities import PipelineVariable @@ -94,7 +94,7 @@ def get_init_kwargs( base_job_name: Optional[str] = None, sagemaker_session: Optional[Session] = None, hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, - tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + tags: Optional[Tags] = None, subnets: Optional[List[Union[str, PipelineVariable]]] = None, security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None, model_uri: Optional[str] = None, @@ -149,7 +149,7 @@ def get_init_kwargs( output_kms_key=output_kms_key, base_job_name=base_job_name, sagemaker_session=sagemaker_session, - tags=tags, + tags=format_tags(tags), subnets=subnets, security_group_ids=security_group_ids, model_uri=model_uri, @@ -253,7 +253,7 @@ def get_deploy_kwargs( deserializer: Optional[BaseDeserializer] = None, accelerator_type: Optional[str] = None, endpoint_name: Optional[str] = None, - tags: List[Dict[str, str]] = None, + tags: Optional[Tags] = None, kms_key: Optional[str] = None, wait: Optional[bool] = None, data_capture_config: Optional[DataCaptureConfig] = None, @@ -297,7 +297,7 @@ def get_deploy_kwargs( deserializer=deserializer, accelerator_type=accelerator_type, endpoint_name=endpoint_name, - tags=tags, + tags=format_tags(tags), kms_key=kms_key, wait=wait, data_capture_config=data_capture_config, diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 185beefc59..64e4727116 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -56,7 +56,7 @@ from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig from sagemaker.session import Session -from sagemaker.utils import name_from_base +from sagemaker.utils import name_from_base, format_tags, Tags from sagemaker.workflow.entities import PipelineVariable from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from sagemaker import resource_requirements @@ -496,7 +496,7 @@ def get_deploy_kwargs( deserializer: Optional[BaseDeserializer] = None, accelerator_type: Optional[str] = None, endpoint_name: Optional[str] = None, - tags: List[Dict[str, str]] = None, + tags: Optional[Tags] = None, kms_key: Optional[str] = None, wait: Optional[bool] = None, data_capture_config: Optional[DataCaptureConfig] = None, @@ -528,7 +528,7 @@ def get_deploy_kwargs( deserializer=deserializer, accelerator_type=accelerator_type, endpoint_name=endpoint_name, - tags=tags, + tags=format_tags(tags), kms_key=kms_key, wait=wait, data_capture_config=data_capture_config, diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index e921add6d7..1742f860e4 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -31,7 +31,7 @@ ) from sagemaker.jumpstart.types import JumpStartSerializablePayload from sagemaker.jumpstart.utils import is_valid_model_id -from sagemaker.utils import stringify_object +from sagemaker.utils import stringify_object, format_tags, Tags from sagemaker.model import ( Model, ModelPackage, @@ -388,7 +388,7 @@ def _create_sagemaker_model( attach to an endpoint for model loading and inference, for example, 'ml.eia1.medium'. If not specified, no Elastic Inference accelerator will be attached to the endpoint. (Default: None). - tags (List[dict[str, str]]): Optional. The list of tags to add to + tags (Optional[Tags]): Optional. The list of tags to add to the model. Example: >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}] For more information about tags, see https://boto3.amazonaws.com/v1/documentation @@ -402,6 +402,8 @@ def _create_sagemaker_model( any so they are ignored. """ + tags = format_tags(tags) + # if the user inputs a model artifact uri, do not use model package arn to create # inference endpoint. if self.model_package_arn and not self._model_data_is_set: @@ -446,7 +448,7 @@ def deploy( deserializer: Optional[BaseDeserializer] = None, accelerator_type: Optional[str] = None, endpoint_name: Optional[str] = None, - tags: List[Dict[str, str]] = None, + tags: Optional[Tags] = None, kms_key: Optional[str] = None, wait: Optional[bool] = True, data_capture_config: Optional[DataCaptureConfig] = None, @@ -502,7 +504,7 @@ def deploy( endpoint_name (Optional[str]): The name of the endpoint to create (default: None). If not specified, a unique endpoint name will be created. (Default: None). - tags (Optional[List[dict[str, str]]]): The list of tags to attach to this + tags (Optional[Tags]): Tags to attach to this specific endpoint. (Default: None). kms_key (Optional[str]): The ARN of the KMS key that is used to encrypt the data on the storage volume attached to the instance hosting the @@ -570,7 +572,7 @@ def deploy( deserializer=deserializer, accelerator_type=accelerator_type, endpoint_name=endpoint_name, - tags=tags, + tags=format_tags(tags), kms_key=kms_key, wait=wait, data_capture_config=data_capture_config, diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 7c06282894..21b624d7a4 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -15,7 +15,7 @@ from copy import deepcopy from enum import Enum from typing import Any, Dict, List, Optional, Set, Union -from sagemaker.utils import get_instance_type_family +from sagemaker.utils import get_instance_type_family, format_tags, Tags from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines @@ -1172,7 +1172,7 @@ def __init__( deserializer: Optional[Any] = None, accelerator_type: Optional[str] = None, endpoint_name: Optional[str] = None, - tags: List[Dict[str, str]] = None, + tags: Optional[Tags] = None, kms_key: Optional[str] = None, wait: Optional[bool] = None, data_capture_config: Optional[Any] = None, @@ -1203,7 +1203,7 @@ def __init__( self.deserializer = deserializer self.accelerator_type = accelerator_type self.endpoint_name = endpoint_name - self.tags = deepcopy(tags) + self.tags = format_tags(tags) self.kms_key = kms_key self.wait = wait self.data_capture_config = data_capture_config @@ -1310,7 +1310,7 @@ def __init__( base_job_name: Optional[str] = None, sagemaker_session: Optional[Any] = None, hyperparameters: Optional[Dict[str, Union[str, Any]]] = None, - tags: Optional[List[Dict[str, Union[str, Any]]]] = None, + tags: Optional[Tags] = None, subnets: Optional[List[Union[str, Any]]] = None, security_group_ids: Optional[List[Union[str, Any]]] = None, model_uri: Optional[str] = None, @@ -1370,7 +1370,7 @@ def __init__( self.output_kms_key = output_kms_key self.base_job_name = base_job_name self.sagemaker_session = sagemaker_session - self.tags = deepcopy(tags) + self.tags = format_tags(tags) self.subnets = subnets self.security_group_ids = security_group_ids self.model_channel_name = model_channel_name @@ -1526,7 +1526,7 @@ def __init__( deserializer: Optional[Any] = None, accelerator_type: Optional[str] = None, endpoint_name: Optional[str] = None, - tags: List[Dict[str, str]] = None, + tags: Optional[Tags] = None, kms_key: Optional[str] = None, wait: Optional[bool] = None, data_capture_config: Optional[Any] = None, @@ -1573,7 +1573,7 @@ def __init__( self.deserializer = deserializer self.accelerator_type = accelerator_type self.endpoint_name = endpoint_name - self.tags = deepcopy(tags) + self.tags = format_tags(tags) self.kms_key = kms_key self.wait = wait self.data_capture_config = data_capture_config diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 0003081e99..7d84caa0d4 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -41,7 +41,7 @@ ) from sagemaker.session import Session from sagemaker.config import load_sagemaker_config -from sagemaker.utils import resolve_value_from_config +from sagemaker.utils import resolve_value_from_config, TagsDict from sagemaker.workflow import is_pipeline_variable @@ -345,10 +345,10 @@ def get_jumpstart_base_name_if_jumpstart_model( def add_jumpstart_model_id_version_tags( - tags: Optional[List[Dict[str, str]]], + tags: Optional[List[TagsDict]], model_id: str, model_version: str, -) -> List[Dict[str, str]]: +) -> List[TagsDict]: """Add custom model ID and version tags to JumpStart related resources.""" if model_id is None or model_version is None: return tags @@ -368,12 +368,12 @@ def add_jumpstart_model_id_version_tags( def add_jumpstart_uri_tags( - tags: Optional[List[Dict[str, str]]] = None, + tags: Optional[List[TagsDict]] = None, inference_model_uri: Optional[Union[str, dict]] = None, inference_script_uri: Optional[str] = None, training_model_uri: Optional[str] = None, training_script_uri: Optional[str] = None, -) -> Optional[List[Dict[str, str]]]: +) -> Optional[List[TagsDict]]: """Add custom uri tags to JumpStart models, return the updated tags. No-op if this is not a JumpStart model related resource. diff --git a/src/sagemaker/lineage/action.py b/src/sagemaker/lineage/action.py index 9046a3ccf2..57b7fca5bc 100644 --- a/src/sagemaker/lineage/action.py +++ b/src/sagemaker/lineage/action.py @@ -21,6 +21,7 @@ from sagemaker.lineage import _api_types, _utils from sagemaker.lineage._api_types import ActionSource, ActionSummary from sagemaker.lineage.artifact import Artifact +from sagemaker.utils import format_tags from sagemaker.lineage.query import ( LineageQuery, @@ -159,12 +160,12 @@ def set_tags(self, tags=None): """Add tags to the object. Args: - tags ([{key:value}]): list of key value pairs. + tags (Optional[Tags]): list of key value pairs. Returns: list({str:str}): a list of key value pairs """ - return self._set_tags(resource_arn=self.action_arn, tags=tags) + return self._set_tags(resource_arn=self.action_arn, tags=format_tags(tags)) @classmethod def create( diff --git a/src/sagemaker/lineage/artifact.py b/src/sagemaker/lineage/artifact.py index 718344095a..e693313dbc 100644 --- a/src/sagemaker/lineage/artifact.py +++ b/src/sagemaker/lineage/artifact.py @@ -31,7 +31,7 @@ ) from sagemaker.lineage._utils import _disassociate, get_resource_name_from_arn from sagemaker.lineage.association import Association -from sagemaker.utils import get_module +from sagemaker.utils import get_module, format_tags LOGGER = logging.getLogger("sagemaker") @@ -288,12 +288,12 @@ def set_tags(self, tags=None): """Add tags to the object. Args: - tags ([{key:value}]): list of key value pairs. + tags (Optional[Tags]): list of key value pairs. Returns: list({str:str}): a list of key value pairs """ - return self._set_tags(resource_arn=self.artifact_arn, tags=tags) + return self._set_tags(resource_arn=self.artifact_arn, tags=format_tags(tags)) @classmethod def create( diff --git a/src/sagemaker/lineage/association.py b/src/sagemaker/lineage/association.py index fef79e2f8f..6ad08eb928 100644 --- a/src/sagemaker/lineage/association.py +++ b/src/sagemaker/lineage/association.py @@ -20,6 +20,7 @@ from sagemaker.apiutils import _base_types from sagemaker.lineage import _api_types from sagemaker.lineage._api_types import AssociationSummary +from sagemaker.utils import format_tags logger = logging.getLogger(__name__) @@ -95,7 +96,7 @@ def set_tags(self, tags=None): "set_tags on Association is deprecated. Use set_tags on the source or destination\ entity instead." ) - return self._set_tags(resource_arn=self.source_arn, tags=tags) + return self._set_tags(resource_arn=self.source_arn, tags=format_tags(tags)) @classmethod def create( diff --git a/src/sagemaker/lineage/context.py b/src/sagemaker/lineage/context.py index aef919e876..46d7693ecf 100644 --- a/src/sagemaker/lineage/context.py +++ b/src/sagemaker/lineage/context.py @@ -33,6 +33,7 @@ from sagemaker.lineage.artifact import Artifact from sagemaker.lineage.action import Action from sagemaker.lineage.lineage_trial_component import LineageTrialComponent +from sagemaker.utils import format_tags class Context(_base_types.Record): @@ -126,7 +127,7 @@ def set_tags(self, tags=None): Returns: list({str:str}): a list of key value pairs """ - return self._set_tags(resource_arn=self.context_arn, tags=tags) + return self._set_tags(resource_arn=self.context_arn, tags=format_tags(tags)) @classmethod def load(cls, context_name: str, sagemaker_session=None) -> "Context": diff --git a/src/sagemaker/local/entities.py b/src/sagemaker/local/entities.py index 3eb4ab2b34..8431d8154a 100644 --- a/src/sagemaker/local/entities.py +++ b/src/sagemaker/local/entities.py @@ -28,7 +28,7 @@ from sagemaker.local.image import _SageMakerContainer from sagemaker.local.utils import copy_directory_structure, move_to_destination, get_docker_host -from sagemaker.utils import DeferredError, get_config_value +from sagemaker.utils import DeferredError, get_config_value, format_tags from sagemaker.local.exceptions import StepExecutionException logger = logging.getLogger(__name__) @@ -552,7 +552,7 @@ class _LocalEndpointConfig(object): def __init__(self, config_name, production_variants, tags=None): self.name = config_name self.production_variants = production_variants - self.tags = tags + self.tags = format_tags(tags) self.creation_time = datetime.datetime.now() def describe(self): @@ -584,7 +584,7 @@ def __init__(self, endpoint_name, endpoint_config_name, tags=None, local_session self.name = endpoint_name self.endpoint_config = local_client.describe_endpoint_config(endpoint_config_name) self.production_variant = self.endpoint_config["ProductionVariants"][0] - self.tags = tags + self.tags = format_tags(tags) model_name = self.production_variant["ModelName"] self.primary_container = local_client.describe_model(model_name)["PrimaryContainer"] diff --git a/src/sagemaker/local/local_session.py b/src/sagemaker/local/local_session.py index f09d64b9be..7d48850077 100644 --- a/src/sagemaker/local/local_session.py +++ b/src/sagemaker/local/local_session.py @@ -42,7 +42,12 @@ _LocalPipeline, ) from sagemaker.session import Session -from sagemaker.utils import get_config_value, _module_import_error, resolve_value_from_config +from sagemaker.utils import ( + get_config_value, + _module_import_error, + resolve_value_from_config, + format_tags, +) logger = logging.getLogger(__name__) @@ -336,7 +341,7 @@ def create_endpoint_config(self, EndpointConfigName, ProductionVariants, Tags=No """ LocalSagemakerClient._endpoint_configs[EndpointConfigName] = _LocalEndpointConfig( - EndpointConfigName, ProductionVariants, Tags + EndpointConfigName, ProductionVariants, format_tags(Tags) ) def describe_endpoint(self, EndpointName): @@ -366,7 +371,12 @@ def create_endpoint(self, EndpointName, EndpointConfigName, Tags=None): Returns: """ - endpoint = _LocalEndpoint(EndpointName, EndpointConfigName, Tags, self.sagemaker_session) + endpoint = _LocalEndpoint( + EndpointName, + EndpointConfigName, + format_tags(Tags), + self.sagemaker_session, + ) LocalSagemakerClient._endpoints[EndpointName] = endpoint endpoint.serve() diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 56f68372ae..ff340b58e9 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -64,6 +64,8 @@ to_string, resolve_value_from_config, resolve_nested_dict_value_from_config, + format_tags, + Tags, ) from sagemaker.async_inference import AsyncInferenceConfig from sagemaker.predictor_async import AsyncPredictor @@ -554,7 +556,7 @@ def create( instance_type: Optional[str] = None, accelerator_type: Optional[str] = None, serverless_inference_config: Optional[ServerlessInferenceConfig] = None, - tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + tags: Optional[Tags] = None, accept_eula: Optional[bool] = None, ): """Create a SageMaker Model Entity @@ -571,10 +573,11 @@ def create( Specifies configuration related to serverless endpoint. Instance type is not provided in serverless inference. So this is used to find image URIs (default: None). - tags (list[dict[str, str] or list[dict[str, PipelineVariable]]): The list of - tags to add to the model (default: None). Example:: + tags (Optional[Tags]): Tags to add to the model (default: None). Example:: tags = [{'Key': 'tagname', 'Value':'tagvalue'}] + # Or + tags = {'tagname', 'tagvalue'} For more information about tags, see `boto3 documentation >> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}] For more information about tags, see https://boto3.amazonaws.com/v1/documentation @@ -843,7 +846,7 @@ def _create_sagemaker_model( model_package._create_sagemaker_model( instance_type=instance_type, accelerator_type=accelerator_type, - tags=tags, + tags=format_tags(tags), serverless_inference_config=serverless_inference_config, ) if self._base_name is None and model_package._base_name is not None: @@ -898,7 +901,7 @@ def _create_sagemaker_model( container_defs=container_def, vpc_config=self.vpc_config, enable_network_isolation=self._enable_network_isolation, - tags=tags, + tags=format_tags(tags), ) self.sagemaker_session.create_model(**create_model_args) @@ -956,7 +959,7 @@ def _edge_packaging_job_config( compilation_job_name (str): what compilation job to source the model from resource_key (str): the kms key to encrypt the disk with s3_kms_key (str): the kms key to encrypt the output with - tags (list[dict]): List of tags for labeling an edge packaging job. For + tags (Optional[Tags]): Tags for labeling an edge packaging job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. Returns: @@ -971,7 +974,7 @@ def _edge_packaging_job_config( return { "output_model_config": output_model_config, "role": role, - "tags": tags, + "tags": format_tags(tags), "model_name": model_name, "model_version": model_version, "job_name": packaging_job_name, @@ -1063,7 +1066,7 @@ def multi_version_compilation_supported( "output_model_config": output_model_config, "role": role, "stop_condition": {"MaxRuntimeInSeconds": compile_max_run}, - "tags": tags, + "tags": format_tags(tags), "job_name": job_name, } @@ -1091,7 +1094,7 @@ def package_for_edge( job_name (str): The name of the edge packaging job resource_key (str): the kms key to encrypt the disk with s3_kms_key (str): the kms key to encrypt the output with - tags (list[dict]): List of tags for labeling an edge packaging job. For + tags (Optional[Tags]): Tags for labeling an edge packaging job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. @@ -1126,7 +1129,7 @@ def package_for_edge( self._compilation_job_name, resource_key, s3_kms_key, - tags, + format_tags(tags), ) self.sagemaker_session.package_model_for_edge(**config) job_status = self.sagemaker_session.wait_for_edge_packaging_job(job_name) @@ -1169,7 +1172,7 @@ def compile( https://docs.aws.amazon.com/sagemaker/latest/dg/neo-compilation-preparing-model.html output_path (str): Specifies where to store the compiled model role (str): Execution role - tags (list[dict]): List of tags for labeling a compilation job. For + tags (Optional[Tags]): Tags for labeling a compilation job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. job_name (str): The name of the compilation job @@ -1242,7 +1245,7 @@ def compile( compile_max_run, job_name, framework, - tags, + format_tags(tags), target_platform_os, target_platform_arch, target_platform_accelerator, @@ -1342,7 +1345,7 @@ def deploy( https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html endpoint_name (str): The name of the endpoint to create (default: None). If not specified, a unique endpoint name will be created. - tags (List[dict[str, str]]): The list of tags to attach to this + tags (Optional[Tags]): Tags to attach to this specific endpoint. kms_key (str): The ARN of the KMS key that is used to encrypt the data on the storage volume attached to the instance hosting the @@ -1430,6 +1433,8 @@ def deploy( sagemaker_session=self.sagemaker_session, ) + tags = format_tags(tags) + if ( getattr(self.sagemaker_session, "settings", None) is not None and self.sagemaker_session.settings.include_jumpstart_tags @@ -1733,7 +1738,7 @@ def transformer( to be made to each individual transform container at one time. max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB. - tags (list[dict]): List of tags for labeling a transform job. If + tags (Optional[Tags]): Tags for labeling a transform job. If none specified, then the tags used for the training job are used for the transform job. volume_kms_key (str): Optional. KMS key ID for encrypting the volume @@ -1741,6 +1746,8 @@ def transformer( """ self._init_sagemaker_session_if_does_not_exist(instance_type) + tags = format_tags(tags) + self._create_sagemaker_model(instance_type, tags=tags) if self.enable_network_isolation(): env = None @@ -2165,7 +2172,7 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar container_def, vpc_config=self.vpc_config, enable_network_isolation=self.enable_network_isolation(), - tags=kwargs.get("tags"), + tags=format_tags(kwargs.get("tags")), ) def _ensure_base_name_if_needed(self, base_name): diff --git a/src/sagemaker/model_monitor/clarify_model_monitoring.py b/src/sagemaker/model_monitor/clarify_model_monitoring.py index bc572827cd..77f27b37f0 100644 --- a/src/sagemaker/model_monitor/clarify_model_monitoring.py +++ b/src/sagemaker/model_monitor/clarify_model_monitoring.py @@ -25,7 +25,7 @@ from sagemaker.model_monitor import model_monitoring as mm from sagemaker import image_uris, s3 from sagemaker.session import Session -from sagemaker.utils import name_from_base +from sagemaker.utils import name_from_base, format_tags from sagemaker.clarify import SageMakerClarifyProcessor, ModelPredictedLabelConfig from sagemaker.lineage._utils import get_resource_name_from_arn @@ -81,7 +81,7 @@ def __init__( AWS services needed. If not specified, one is created using the default AWS configuration chain. env (dict): Environment variables to be passed to the job. - tags ([dict]): List of tags to be passed to the job. + tags (Optional[Tags]): List of tags to be passed to the job. network_config (sagemaker.network.NetworkConfig): A NetworkConfig object that configures network isolation, encryption of inter-container traffic, security group IDs, and subnets. @@ -108,7 +108,7 @@ def __init__( base_job_name=base_job_name, sagemaker_session=session, env=env, - tags=tags, + tags=format_tags(tags), network_config=network_config, ) self.latest_baselining_job_config = None @@ -296,7 +296,7 @@ def _build_create_job_definition_request( time, Amazon SageMaker terminates the job regardless of its current status. Default: 3600 env (dict): Environment variables to be passed to the job. - tags ([dict]): List of tags to be passed to the job. + tags (Optional[Tags]): List of tags to be passed to the job. network_config (sagemaker.network.NetworkConfig): A NetworkConfig object that configures network isolation, encryption of inter-container traffic, security group IDs, and subnets. @@ -458,7 +458,7 @@ def _build_create_job_definition_request( request_dict["StoppingCondition"] = stop_condition if tags is not None: - request_dict["Tags"] = tags + request_dict["Tags"] = format_tags(tags) return request_dict diff --git a/src/sagemaker/model_monitor/model_monitoring.py b/src/sagemaker/model_monitor/model_monitoring.py index b949c6538b..2800082df4 100644 --- a/src/sagemaker/model_monitor/model_monitoring.py +++ b/src/sagemaker/model_monitor/model_monitoring.py @@ -62,6 +62,7 @@ retries, resolve_value_from_config, resolve_class_attribute_from_config, + format_tags, ) from sagemaker.lineage._utils import get_resource_name_from_arn from sagemaker.model_monitor.cron_expression_generator import CronExpressionGenerator @@ -163,7 +164,7 @@ def __init__( AWS services needed. If not specified, one is created using the default AWS configuration chain. env (dict): Environment variables to be passed to the job. - tags ([dict]): List of tags to be passed to the job. + tags (Optional[Tags]): List of tags to be passed to the job. network_config (sagemaker.network.NetworkConfig): A NetworkConfig object that configures network isolation, encryption of inter-container traffic, security group IDs, and subnets. @@ -177,7 +178,7 @@ def __init__( self.max_runtime_in_seconds = max_runtime_in_seconds self.base_job_name = base_job_name self.sagemaker_session = sagemaker_session or Session() - self.tags = tags + self.tags = format_tags(tags) self.baselining_jobs = [] self.latest_baselining_job = None @@ -1738,7 +1739,7 @@ def __init__( AWS services needed. If not specified, one is created using the default AWS configuration chain. env (dict): Environment variables to be passed to the job. - tags ([dict]): List of tags to be passed to the job. + tags (Optional[Tags]): List of tags to be passed to the job. network_config (sagemaker.network.NetworkConfig): A NetworkConfig object that configures network isolation, encryption of inter-container traffic, security group IDs, and subnets. @@ -1757,7 +1758,7 @@ def __init__( base_job_name=base_job_name, sagemaker_session=sagemaker_session, env=env, - tags=tags, + tags=format_tags(tags), network_config=network_config, ) @@ -2685,7 +2686,7 @@ def _build_create_data_quality_job_definition_request( time, Amazon SageMaker terminates the job regardless of its current status. Default: 3600 env (dict): Environment variables to be passed to the job. - tags ([dict]): List of tags to be passed to the job. + tags (Optional[Tags]): List of tags to be passed to the job. network_config (sagemaker.network.NetworkConfig): A NetworkConfig object that configures network isolation, encryption of inter-container traffic, security group IDs, and subnets. @@ -2817,7 +2818,7 @@ def _build_create_data_quality_job_definition_request( request_dict["StoppingCondition"] = stop_condition if tags is not None: - request_dict["Tags"] = tags + request_dict["Tags"] = format_tags(tags) return request_dict @@ -2871,7 +2872,7 @@ def __init__( AWS services needed. If not specified, one is created using the default AWS configuration chain. env (dict): Environment variables to be passed to the job. - tags ([dict]): List of tags to be passed to the job. + tags (Optional[Tags]): List of tags to be passed to the job. network_config (sagemaker.network.NetworkConfig): A NetworkConfig object that configures network isolation, encryption of inter-container traffic, security group IDs, and subnets. @@ -2890,7 +2891,7 @@ def __init__( base_job_name=base_job_name, sagemaker_session=session, env=env, - tags=tags, + tags=format_tags(tags), network_config=network_config, ) @@ -3462,7 +3463,7 @@ def _build_create_model_quality_job_definition_request( time, Amazon SageMaker terminates the job regardless of its current status. Default: 3600 env (dict): Environment variables to be passed to the job. - tags ([dict]): List of tags to be passed to the job. + tags (Optional[Tags]): List of tags to be passed to the job. network_config (sagemaker.network.NetworkConfig): A NetworkConfig object that configures network isolation, encryption of inter-container traffic, security group IDs, and subnets. @@ -3594,7 +3595,7 @@ def _build_create_model_quality_job_definition_request( request_dict["StoppingCondition"] = stop_condition if tags is not None: - request_dict["Tags"] = tags + request_dict["Tags"] = format_tags(tags) return request_dict diff --git a/src/sagemaker/multidatamodel.py b/src/sagemaker/multidatamodel.py index b656b4c671..9c1e6ac4f4 100644 --- a/src/sagemaker/multidatamodel.py +++ b/src/sagemaker/multidatamodel.py @@ -23,7 +23,7 @@ from sagemaker.deprecations import removed_kwargs from sagemaker.model import Model from sagemaker.session import Session -from sagemaker.utils import pop_out_unused_kwarg +from sagemaker.utils import pop_out_unused_kwarg, format_tags from sagemaker.workflow.entities import PipelineVariable MULTI_MODEL_CONTAINER_MODE = "MultiModel" @@ -245,6 +245,8 @@ def deploy( if instance_type == "local" and not isinstance(self.sagemaker_session, local.LocalSession): self.sagemaker_session = local.LocalSession() + tags = format_tags(tags) + container_def = self.prepare_container_def(instance_type, accelerator_type=accelerator_type) self.sagemaker_session.create_model( self.name, diff --git a/src/sagemaker/mxnet/processing.py b/src/sagemaker/mxnet/processing.py index d85ab5b526..bb50de2014 100644 --- a/src/sagemaker/mxnet/processing.py +++ b/src/sagemaker/mxnet/processing.py @@ -24,6 +24,7 @@ from sagemaker.mxnet.estimator import MXNet from sagemaker.processing import FrameworkProcessor from sagemaker.workflow.entities import PipelineVariable +from sagemaker.utils import format_tags, Tags class MXNetProcessor(FrameworkProcessor): @@ -48,7 +49,7 @@ def __init__( base_job_name: Optional[str] = None, sagemaker_session: Optional[Session] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, - tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + tags: Optional[Tags] = None, network_config: Optional[NetworkConfig] = None, ): """This processor executes a Python script in a managed MXNet execution environment. @@ -81,6 +82,6 @@ def __init__( base_job_name, sagemaker_session, env, - tags, + format_tags(tags), network_config, ) diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index b9405f568c..a4b7feac69 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -31,6 +31,7 @@ name_from_image, update_container_with_inference_params, resolve_value_from_config, + format_tags, ) from sagemaker.transformer import Transformer from sagemaker.workflow.entities import PipelineVariable @@ -263,6 +264,8 @@ def deploy( if data_capture_config is not None: data_capture_config_dict = data_capture_config._to_request_dict() + tags = format_tags(tags) + if update_endpoint: endpoint_config_name = self.sagemaker_session.create_endpoint_config( name=self.name, @@ -516,7 +519,7 @@ def transformer( max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload, env=env, - tags=tags, + tags=format_tags(tags), base_transform_job_name=self.name, volume_kms_key=volume_kms_key, sagemaker_session=self.sagemaker_session, diff --git a/src/sagemaker/predictor_async.py b/src/sagemaker/predictor_async.py index 1adfce4c7c..cdf9b141b3 100644 --- a/src/sagemaker/predictor_async.py +++ b/src/sagemaker/predictor_async.py @@ -22,7 +22,7 @@ from sagemaker.async_inference import WaiterConfig, AsyncInferenceResponse from sagemaker.s3 import parse_s3_url from sagemaker.session import Session -from sagemaker.utils import name_from_base, sagemaker_timestamp +from sagemaker.utils import name_from_base, sagemaker_timestamp, format_tags class AsyncPredictor: @@ -375,7 +375,7 @@ def update_endpoint( instance_type=instance_type, accelerator_type=accelerator_type, model_name=model_name, - tags=tags, + tags=format_tags(tags), kms_key=kms_key, data_capture_config_dict=data_capture_config_dict, wait=wait, diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index a020abc140..7b16e3cba3 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -50,6 +50,8 @@ check_and_get_run_experiment_config, resolve_value_from_config, resolve_class_attribute_from_config, + Tags, + format_tags, ) from sagemaker.session import Session from sagemaker.workflow import is_pipeline_variable @@ -83,7 +85,7 @@ def __init__( base_job_name: Optional[str] = None, sagemaker_session: Optional[Session] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, - tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + tags: Optional[Tags] = None, network_config: Optional[NetworkConfig] = None, ): """Initializes a ``Processor`` instance. @@ -122,9 +124,8 @@ def __init__( one using the default AWS configuration chain. env (dict[str, str] or dict[str, PipelineVariable]): Environment variables to be passed to the processing jobs (default: None). - tags (list[dict[str, str] or list[dict[str, PipelineVariable]]): List of tags - to be passed to the processing job (default: None). For more, see - https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. + tags (Optional[Tags]): Tags to be passed to the processing job (default: None). + For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. network_config (:class:`~sagemaker.network.NetworkConfig`): A :class:`~sagemaker.network.NetworkConfig` object that configures network isolation, encryption of @@ -137,7 +138,7 @@ def __init__( self.volume_size_in_gb = volume_size_in_gb self.max_runtime_in_seconds = max_runtime_in_seconds self.base_job_name = base_job_name - self.tags = tags + self.tags = format_tags(tags) self.jobs = [] self.latest_job = None @@ -515,7 +516,7 @@ def __init__( base_job_name: Optional[str] = None, sagemaker_session: Optional[Session] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, - tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + tags: Optional[Tags] = None, network_config: Optional[NetworkConfig] = None, ): """Initializes a ``ScriptProcessor`` instance. @@ -555,9 +556,8 @@ def __init__( one using the default AWS configuration chain. env (dict[str, str] or dict[str, PipelineVariable])): Environment variables to be passed to the processing jobs (default: None). - tags (list[dict[str, str] or list[dict[str, PipelineVariable]]): List of tags to - be passed to the processing job (default: None). For more, see - https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. + tags (Optional[Tags]): Tags to be passed to the processing job (default: None). + For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. network_config (:class:`~sagemaker.network.NetworkConfig`): A :class:`~sagemaker.network.NetworkConfig` object that configures network isolation, encryption of @@ -579,7 +579,7 @@ def __init__( base_job_name=base_job_name, sagemaker_session=sagemaker_session, env=env, - tags=tags, + tags=format_tags(tags), network_config=network_config, ) @@ -1442,7 +1442,7 @@ def __init__( base_job_name: Optional[str] = None, sagemaker_session: Optional[Session] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, - tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + tags: Optional[Tags] = None, network_config: Optional[NetworkConfig] = None, ): """Initializes a ``FrameworkProcessor`` instance. @@ -1494,9 +1494,8 @@ def __init__( one using the default AWS configuration chain (default: None). env (dict[str, str] or dict[str, PipelineVariable]): Environment variables to be passed to the processing jobs (default: None). - tags (list[dict[str, str] or list[dict[str, PipelineVariable]]): List of tags to - be passed to the processing job (default: None). For more, see - https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. + tags (Optional[Tags]): Tags to be passed to the processing job (default: None). + For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. network_config (:class:`~sagemaker.network.NetworkConfig`): A :class:`~sagemaker.network.NetworkConfig` object that configures network isolation, encryption of @@ -1531,7 +1530,7 @@ def __init__( base_job_name=base_job_name, sagemaker_session=sagemaker_session, env=env, - tags=tags, + tags=format_tags(tags), network_config=network_config, ) diff --git a/src/sagemaker/pytorch/processing.py b/src/sagemaker/pytorch/processing.py index 70fc96497e..e04e4ba65a 100644 --- a/src/sagemaker/pytorch/processing.py +++ b/src/sagemaker/pytorch/processing.py @@ -24,6 +24,7 @@ from sagemaker.processing import FrameworkProcessor from sagemaker.pytorch.estimator import PyTorch from sagemaker.workflow.entities import PipelineVariable +from sagemaker.utils import format_tags, Tags class PyTorchProcessor(FrameworkProcessor): @@ -48,7 +49,7 @@ def __init__( base_job_name: Optional[str] = None, sagemaker_session: Optional[Session] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, - tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + tags: Optional[Tags] = None, network_config: Optional[NetworkConfig] = None, ): """This processor executes a Python script in a PyTorch execution environment. @@ -81,6 +82,6 @@ def __init__( base_job_name, sagemaker_session, env, - tags, + format_tags(tags), network_config, ) diff --git a/src/sagemaker/remote_function/job.py b/src/sagemaker/remote_function/job.py index c4570da463..205a2adf41 100644 --- a/src/sagemaker/remote_function/job.py +++ b/src/sagemaker/remote_function/job.py @@ -49,7 +49,13 @@ from sagemaker import image_uris from sagemaker.remote_function.checkpoint_location import CheckpointLocation from sagemaker.session import get_execution_role, _logs_for_job, Session -from sagemaker.utils import name_from_base, _tmpdir, resolve_value_from_config +from sagemaker.utils import ( + name_from_base, + _tmpdir, + resolve_value_from_config, + format_tags, + Tags, +) from sagemaker.s3 import s3_path_join, S3Uploader from sagemaker import vpc_utils from sagemaker.remote_function.core.stored_function import StoredFunction, _SerializedData @@ -200,7 +206,7 @@ def __init__( sagemaker_session: Session = None, security_group_ids: List[Union[str, "PipelineVariable"]] = None, subnets: List[Union[str, "PipelineVariable"]] = None, - tags: List[Dict[str, Union[str, "PipelineVariable"]]] = None, + tags: Optional[Tags] = None, volume_kms_key: Union[str, "PipelineVariable"] = None, volume_size: Union[int, "PipelineVariable"] = 30, encrypt_inter_container_traffic: Union[bool, "PipelineVariable"] = None, @@ -362,9 +368,8 @@ def __init__( subnets (List[str, PipelineVariable]): A list of subnet IDs. Defaults to ``None`` and the job is created without VPC config. - tags (list[dict[str, str] or list[dict[str, PipelineVariable]]): A list of tags - attached to the job. Defaults to ``None`` and the training job is created - without tags. + tags (Optional[Tags]): Tags attached to the job. Defaults to ``None`` + and the training job is created without tags. volume_kms_key (str, PipelineVariable): An Amazon Key Management Service (KMS) key used to encrypt an Amazon Elastic Block Storage (EBS) volume attached to the @@ -544,9 +549,8 @@ def __init__( vpc_config = vpc_utils.to_dict(subnets=_subnets, security_group_ids=_security_group_ids) self.vpc_config = vpc_utils.sanitize(vpc_config) - self.tags = self.sagemaker_session._append_sagemaker_config_tags( - [{"Key": k, "Value": v} for k, v in tags] if tags else None, REMOTE_FUNCTION_TAGS - ) + tags = format_tags(tags) + self.tags = self.sagemaker_session._append_sagemaker_config_tags(tags, REMOTE_FUNCTION_TAGS) @staticmethod def _get_default_image(session): diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index fe0d259428..2cf7e78f41 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -131,6 +131,9 @@ resolve_nested_dict_value_from_config, update_nested_dictionary_with_values_from_config, update_list_of_dicts_with_values_from_config, + format_tags, + Tags, + TagsDict, ) from sagemaker import exceptions from sagemaker.session_settings import SessionSettings @@ -677,7 +680,7 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region): ) raise - def _append_sagemaker_config_tags(self, tags: list, config_path_to_tags: str): + def _append_sagemaker_config_tags(self, tags: List[TagsDict], config_path_to_tags: str): """Appends tags specified in the sagemaker_config to the given list of tags. To minimize the chance of duplicate tags being applied, this is intended to be used @@ -787,7 +790,7 @@ def train( # noqa: C901 called to convert them before training. stop_condition (dict): Defines when training shall finish. Contains entries that can be understood by the service like ``MaxRuntimeInSeconds``. - tags (list[dict]): List of tags for labeling a training job. For more, see + tags (Optional[Tags]): Tags for labeling a training job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. metric_definitions (list[dict]): A list of dictionaries that defines the metric(s) used to evaluate the training jobs. Each dictionary contains two keys: 'Name' for @@ -886,7 +889,7 @@ def train( # noqa: C901 Returns: str: ARN of the training job, if it is created. """ - tags = _append_project_tags(tags) + tags = _append_project_tags(format_tags(tags)) tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, TRAINING_JOB, TAGS) ) @@ -1369,7 +1372,7 @@ def process( jobs. role_arn (str): The Amazon Resource Name (ARN) of an IAM role that Amazon SageMaker can assume to perform tasks on your behalf. - tags ([dict[str,str]]): A list of dictionaries containing key-value + tags (Optional[Tags]): A list of dictionaries containing key-value pairs. experiment_config (dict[str, str]): Experiment management configuration. Optionally, the dict can contain three keys: @@ -1383,7 +1386,7 @@ def process( will be unassociated. * `TrialComponentDisplayName` is used for display in Studio. """ - tags = _append_project_tags(tags) + tags = _append_project_tags(format_tags(tags)) tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, PROCESSING_JOB, TAGS) ) @@ -1597,7 +1600,7 @@ def create_monitoring_schedule( jobs. role_arn (str): The Amazon Resource Name (ARN) of an IAM role that Amazon SageMaker can assume to perform tasks on your behalf. - tags ([dict[str,str]]): A list of dictionaries containing key-value + tags (Optional[Tags]): A list of dictionaries containing key-value pairs. data_analysis_start_time (str): Start time for the data analysis window for the one time monitoring schedule (NOW), e.g. "-PT1H" @@ -1717,7 +1720,7 @@ def create_monitoring_schedule( "NetworkConfig" ] = inferred_network_config_from_config - tags = _append_project_tags(tags) + tags = _append_project_tags(format_tags(tags)) tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, MONITORING_SCHEDULE, TAGS) ) @@ -2367,7 +2370,7 @@ def auto_ml( "MetricName" and "Value". generate_candidate_definitions_only (bool): Indicates whether to only generate candidate definitions. If True, AutoML.list_candidates() cannot be called. Default: False. - tags ([dict[str,str]]): A list of dictionaries containing key-value + tags (Optional[Tags]): A list of dictionaries containing key-value pairs. model_deploy_config (dict): Specifies how to generate the endpoint name for an automatic one-click Autopilot model deployment. @@ -2390,7 +2393,7 @@ def auto_ml( problem_type=problem_type, job_objective=job_objective, generate_candidate_definitions_only=generate_candidate_definitions_only, - tags=tags, + tags=format_tags(tags), model_deploy_config=model_deploy_config, ) @@ -2435,7 +2438,7 @@ def _get_auto_ml_request( "MetricName" and "Value". generate_candidate_definitions_only (bool): Indicates whether to only generate candidate definitions. If True, AutoML.list_candidates() cannot be called. Default: False. - tags ([dict[str,str]]): A list of dictionaries containing key-value + tags (Optional[Tags]): A list of dictionaries containing key-value pairs. model_deploy_config (dict): Specifies how to generate the endpoint name for an automatic one-click Autopilot model deployment. @@ -2460,7 +2463,7 @@ def _get_auto_ml_request( if problem_type is not None: auto_ml_job_request["ProblemType"] = problem_type - tags = _append_project_tags(tags) + tags = _append_project_tags(format_tags(tags)) tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, AUTO_ML_JOB, TAGS) ) @@ -2650,7 +2653,7 @@ def compile_model( job_name (str): Name of the compilation job being created. stop_condition (dict): Defines when compilation job shall finish. Contains entries that can be understood by the service like ``MaxRuntimeInSeconds``. - tags (list[dict]): List of tags for labeling a compile model job. For more, see + tags (Optional[Tags]): List of tags for labeling a compile model job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. Returns: @@ -2675,7 +2678,7 @@ def compile_model( if vpc_config: compilation_job_request["VpcConfig"] = vpc_config - tags = _append_project_tags(tags) + tags = _append_project_tags(format_tags(tags)) tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, COMPILATION_JOB, TAGS) ) @@ -2707,7 +2710,7 @@ def package_model_for_edge( job_name (str): Name of the edge packaging job being created. compilation_job_name (str): Name of the compilation job being created. resource_key (str): KMS key to encrypt the disk used to package the job - tags (list[dict]): List of tags for labeling a compile model job. For more, see + tags (Optional[Tags]): List of tags for labeling a compile model job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. """ role = resolve_value_from_config(role, EDGE_PACKAGING_ROLE_ARN_PATH, sagemaker_session=self) @@ -2725,7 +2728,7 @@ def package_model_for_edge( resource_key = resolve_value_from_config( resource_key, EDGE_PACKAGING_RESOURCE_KEY_PATH, sagemaker_session=self ) - tags = _append_project_tags(tags) + tags = _append_project_tags(format_tags(tags)) tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, EDGE_PACKAGING_JOB, TAGS) ) @@ -2963,7 +2966,7 @@ def create_tuning_job( or training_config_list should be provided, but not both. warm_start_config (dict): Configuration defining the type of warm start and other required configurations. - tags (list[dict]): List of tags for labeling the tuning job. For more, see + tags (Optional[Tags]): List of tags for labeling the tuning job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. autotune (bool): Whether the parameter ranges or other unset settings of a tuning job should be chosen automatically. @@ -2982,7 +2985,7 @@ def create_tuning_job( training_config=training_config, training_config_list=training_config_list, warm_start_config=warm_start_config, - tags=tags, + tags=format_tags(tags), autotune=autotune, ) @@ -3015,7 +3018,7 @@ def _get_tuning_request( or training_config_list should be provided, but not both. warm_start_config (dict): Configuration defining the type of warm start and other required configurations. - tags (list[dict]): List of tags for labeling the tuning job. For more, see + tags (Optional[Tags]): List of tags for labeling the tuning job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. autotune (bool): Whether the parameter ranges or other unset settings of a tuning job should be chosen automatically. @@ -3040,7 +3043,7 @@ def _get_tuning_request( if warm_start_config is not None: tune_request["WarmStartConfig"] = warm_start_config - tags = _append_project_tags(tags) + tags = _append_project_tags(format_tags(tags)) if tags is not None: tune_request["Tags"] = tags @@ -3497,7 +3500,7 @@ def transform( * If both `ExperimentName` and `TrialName` are not supplied the trial component will be unassociated. * `TrialComponentDisplayName` is used for display in Studio. - tags (list[dict]): List of tags for labeling a transform job. + tags (Optional[Tags]): List of tags for labeling a transform job. data_processing(dict): A dictionary describing config for combining the input data and transformed data. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. @@ -3507,7 +3510,7 @@ def transform( batch_data_capture_config (BatchDataCaptureConfig): Configuration object which specifies the configurations related to the batch data capture for the transform job """ - tags = _append_project_tags(tags) + tags = _append_project_tags(format_tags(tags)) tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, TRANSFORM_JOB, TAGS) ) @@ -3603,7 +3606,7 @@ def _create_model_request( request["PrimaryContainer"] = container_definition if tags: - request["Tags"] = tags + request["Tags"] = format_tags(tags) if vpc_config: request["VpcConfig"] = vpc_config @@ -3655,7 +3658,7 @@ def create_model( which is used to create more advanced container configurations, including model containers which need artifacts from S3. This field is deprecated, please use container_defs instead. - tags(List[dict[str, str]]): Optional. The list of tags to add to the model. + tags(Optional[Tags]): Optional. The list of tags to add to the model. Example: >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}] @@ -3665,7 +3668,7 @@ def create_model( Returns: str: Name of the Amazon SageMaker ``Model`` created. """ - tags = _append_project_tags(tags) + tags = _append_project_tags(format_tags(tags)) tags = self._append_sagemaker_config_tags(tags, "{}.{}.{}".format(SAGEMAKER, MODEL, TAGS)) role = resolve_value_from_config( role, MODEL_EXECUTION_ROLE_ARN_PATH, sagemaker_session=self @@ -3745,7 +3748,7 @@ def create_model_from_job( Default: use VpcConfig from training job. * 'Subnets' (list[str]): List of subnet ids. * 'SecurityGroupIds' (list[str]): List of security group ids. - tags(List[dict[str, str]]): Optional. The list of tags to add to the model. + tags(Optional[Tags]): Optional. The list of tags to add to the model. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. Returns: @@ -3786,7 +3789,7 @@ def create_model_from_job( primary_container, enable_network_isolation=enable_network_isolation, vpc_config=vpc_config, - tags=tags, + tags=format_tags(tags), ) def create_model_package_from_algorithm(self, name, description, algorithm_arn, model_data): @@ -4027,7 +4030,7 @@ def create_endpoint_config( accelerator_type (str): Type of Elastic Inference accelerator to attach to the instance. For example, 'ml.eia1.medium'. For more information: https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html - tags(List[dict[str, str]]): Optional. The list of tags to add to the endpoint config. + tags(Optional[Tags]): Optional. The list of tags to add to the endpoint config. kms_key (str): The KMS key that is used to encrypt the data on the storage volume attached to the instance hosting the endpoint. data_capture_config_dict (dict): Specifies configuration related to Endpoint data @@ -4059,7 +4062,7 @@ def create_endpoint_config( """ logger.info("Creating endpoint-config with name %s", name) - tags = tags or [] + tags = format_tags(tags) or [] provided_production_variant = production_variant( model_name, instance_type, @@ -4136,7 +4139,7 @@ def create_endpoint_config_from_existing( new_config_name (str): Name of the Amazon SageMaker endpoint configuration to create. existing_config_name (str): Name of the existing Amazon SageMaker endpoint configuration. - new_tags (list[dict[str, str]]): Optional. The list of tags to add to the endpoint + new_tags (Optional[Tags]): Optional. The list of tags to add to the endpoint config. If not specified, the tags of the existing endpoint configuration are used. If any of the existing tags are reserved AWS ones (i.e. begin with "aws"), they are not carried over to the new endpoint configuration. @@ -4196,7 +4199,7 @@ def create_endpoint_config_from_existing( if "ModelName" not in pv or not pv["ModelName"]: request["ExecutionRoleArn"] = self.get_caller_identity_arn() - request_tags = new_tags or self.list_tags( + request_tags = format_tags(new_tags) or self.list_tags( existing_endpoint_config_desc["EndpointConfigArn"] ) request_tags = _append_project_tags(request_tags) @@ -4267,7 +4270,7 @@ def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True, live config_name (str): Name of the Amazon SageMaker endpoint configuration to deploy. wait (bool): Whether to wait for the endpoint deployment to complete before returning (default: True). - tags (list[dict[str, str]]): A list of key-value pairs for tagging the endpoint + tags (Optional[Tags]): A list of key-value pairs for tagging the endpoint (default: None). Returns: @@ -4275,7 +4278,7 @@ def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True, live """ logger.info("Creating endpoint with name %s", endpoint_name) - tags = tags or [] + tags = format_tags(tags) or [] tags = _append_project_tags(tags) tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT, TAGS) @@ -4384,7 +4387,7 @@ def create_inference_component( variant_name: str, specification: Dict[str, Any], runtime_config: Optional[Dict[str, Any]] = None, - tags: Optional[Dict[str, str]] = None, + tags: Optional[Tags] = None, wait: bool = True, ): """Create an Amazon SageMaker Inference Component. @@ -4399,8 +4402,8 @@ def create_inference_component( specification (Dict[str, Any]): The inference component specification. runtime_config (Optional[Dict[str, Any]]): Optional. The inference component runtime configuration. (Default: None). - tags (Optional[Dict[str, str]]): Optional. A list of dictionaries containing key-value - pairs. (Default: None). + tags (Optional[Tags]): Optional. Either a dictionary or a list + of dictionaries containing key-value pairs. (Default: None). wait (bool) : Optional. Wait for the inference component to finish being created before returning a value. (Default: True). @@ -4424,7 +4427,7 @@ def create_inference_component( "RuntimeConfig": runtime_config, } - tags = tags or [] + tags = format_tags(tags) tags = _append_project_tags(tags) tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, INFERENCE_COMPONENT, TAGS) @@ -5182,7 +5185,7 @@ def endpoint_from_model_data( data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies configuration related to Endpoint data capture for use with Amazon SageMaker Model Monitoring. Default: None. - tags (list[dict[str, str]]): A list of key-value pairs for tagging the endpoint + tags (Optional[Tags]): A list of key-value pairs for tagging the endpoint (default: None). Returns: @@ -5191,8 +5194,8 @@ def endpoint_from_model_data( model_environment_vars = model_environment_vars or {} name = name or name_from_image(image_uri) model_vpc_config = vpc_utils.sanitize(model_vpc_config) - endpoint_config_tags = _append_project_tags(tags) - endpoint_tags = _append_project_tags(tags) + endpoint_config_tags = _append_project_tags(format_tags(tags)) + endpoint_tags = _append_project_tags(format_tags(tags)) endpoint_config_tags = self._append_sagemaker_config_tags( endpoint_config_tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT_CONFIG, TAGS) ) @@ -5255,7 +5258,7 @@ def endpoint_from_production_variants( Args: name (str): The name of the ``Endpoint`` to create. production_variants (list[dict[str, str]]): The list of production variants to deploy. - tags (list[dict[str, str]]): A list of key-value pairs for tagging the endpoint + tags (Optional[Tags]): A list of key-value pairs for tagging the endpoint (default: None). kms_key (str): The KMS key that is used to encrypt the data on the storage volume attached to the instance hosting the endpoint. @@ -5340,8 +5343,8 @@ def endpoint_from_production_variants( # Use expand_role method to handle this situation. role = self.expand_role(role) config_options["ExecutionRoleArn"] = role - endpoint_config_tags = _append_project_tags(tags) - endpoint_tags = _append_project_tags(tags) + endpoint_config_tags = _append_project_tags(format_tags(tags)) + endpoint_tags = _append_project_tags(format_tags(tags)) endpoint_config_tags = self._append_sagemaker_config_tags( endpoint_config_tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT_CONFIG, TAGS) @@ -5679,7 +5682,7 @@ def create_feature_group( online_store_config: Dict[str, str] = None, offline_store_config: Dict[str, str] = None, description: str = None, - tags: List[Dict[str, str]] = None, + tags: Optional[Tags] = None, ) -> Dict[str, Any]: """Creates a FeatureGroup in the FeatureStore service. @@ -5694,11 +5697,12 @@ def create_feature_group( offline_store_config (Dict[str, str]): dict contains configuration of the feature offline store. description (str): description of the FeatureGroup. - tags (List[Dict[str, str]]): list of tags for labeling a FeatureGroup. + tags (Optional[Tags]): tags for labeling a FeatureGroup. Returns: Response dict from service. """ + tags = format_tags(tags) tags = _append_project_tags(tags) tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, FEATURE_GROUP, TAGS) @@ -6155,7 +6159,7 @@ def _create_inference_recommendations_job_request( framework: str, sample_payload_url: str, supported_content_types: List[str], - tags: Dict[str, str], + tags: Optional[Tags], model_name: str = None, model_package_version_arn: str = None, job_duration_in_seconds: int = None, @@ -6191,8 +6195,8 @@ def _create_inference_recommendations_job_request( benchmarked by Amazon SageMaker Inference Recommender that matches your model. supported_instance_types (List[str]): A list of the instance types that are used to generate inferences in real-time. - tags (Dict[str, str]): Tags used to identify where the Inference Recommendatons Call - was made from. + tags (Optional[Tags]): Tags used to identify where + the Inference Recommendatons Call was made from. endpoint_configurations (List[Dict[str, any]]): Specifies the endpoint configurations to use for a job. Will be used for `Advanced` jobs. traffic_pattern (Dict[str, any]): Specifies the traffic pattern for the job. @@ -6231,7 +6235,7 @@ def _create_inference_recommendations_job_request( "InputConfig": { "ContainerConfig": containerConfig, }, - "Tags": tags, + "Tags": format_tags(tags), } request.get("InputConfig").update( @@ -6443,7 +6447,7 @@ def get_model_package_args( approval_status (str): Model Approval Status, values can be "Approved", "Rejected", or "PendingManualApproval" (default: "PendingManualApproval"). description (str): Model Package description (default: None). - tags (List[dict[str, str]]): A list of dictionaries containing key-value pairs + tags (Optional[Tags]): A list of dictionaries containing key-value pairs (default: None). container_def_list (list): A list of container defintiions (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). @@ -6498,7 +6502,7 @@ def get_model_package_args( if description is not None: model_package_args["description"] = description if tags is not None: - model_package_args["tags"] = tags + model_package_args["tags"] = format_tags(tags) if customer_metadata_properties is not None: model_package_args["customer_metadata_properties"] = customer_metadata_properties if validation_specification is not None: @@ -6558,7 +6562,7 @@ def get_create_model_package_request( approval_status (str): Model Approval Status, values can be "Approved", "Rejected", or "PendingManualApproval" (default: "PendingManualApproval"). description (str): Model Package description (default: None). - tags (List[dict[str, str]]): A list of dictionaries containing key-value pairs + tags (Optional[Tags]): A list of dictionaries containing key-value pairs (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). customer_metadata_properties (dict[str, str]): A dictionary of key-value paired @@ -6585,7 +6589,7 @@ def get_create_model_package_request( if description is not None: request_dict["ModelPackageDescription"] = description if tags is not None: - request_dict["Tags"] = tags + request_dict["Tags"] = format_tags(tags) if model_metrics: request_dict["ModelMetrics"] = model_metrics if drift_check_baselines: diff --git a/src/sagemaker/sklearn/processing.py b/src/sagemaker/sklearn/processing.py index 86d0df9113..ff209b3740 100644 --- a/src/sagemaker/sklearn/processing.py +++ b/src/sagemaker/sklearn/processing.py @@ -24,6 +24,7 @@ from sagemaker.processing import ScriptProcessor from sagemaker.sklearn import defaults from sagemaker.workflow.entities import PipelineVariable +from sagemaker.utils import format_tags, Tags class SKLearnProcessor(ScriptProcessor): @@ -43,7 +44,7 @@ def __init__( base_job_name: Optional[str] = None, sagemaker_session: Optional[Session] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, - tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + tags: Optional[Tags] = None, network_config: Optional[NetworkConfig] = None, ): """Initialize an ``SKLearnProcessor`` instance. @@ -81,8 +82,7 @@ def __init__( using the default AWS configuration chain. env (dict[str, str] or dict[str, PipelineVariable]): Environment variables to be passed to the processing job. - tags (list[dict[str, str] or list[dict[str, PipelineVariable]]): List of tags - to be passed to the processing job. + tags (Optional[Tags]): Tags to be passed to the processing job. network_config (sagemaker.network.NetworkConfig): A NetworkConfig object that configures network isolation, encryption of inter-container traffic, security group IDs, and subnets. @@ -110,6 +110,6 @@ def __init__( base_job_name=base_job_name, sagemaker_session=session, env=env, - tags=tags, + tags=format_tags(tags), network_config=network_config, ) diff --git a/src/sagemaker/spark/processing.py b/src/sagemaker/spark/processing.py index 293b61f835..82634071cc 100644 --- a/src/sagemaker/spark/processing.py +++ b/src/sagemaker/spark/processing.py @@ -41,6 +41,7 @@ from sagemaker.session import Session from sagemaker.network import NetworkConfig from sagemaker.spark import defaults +from sagemaker.utils import format_tags, Tags from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.pipeline_context import runnable_by_pipeline @@ -135,7 +136,7 @@ def __init__( SageMaker APIs and any other AWS services needed. If not specified, the processor creates one using the default AWS configuration chain. env (dict): Environment variables to be passed to the processing job. - tags ([dict]): List of tags to be passed to the processing job. + tags (Optional[Tags]): List of tags to be passed to the processing job. network_config (sagemaker.network.NetworkConfig): A NetworkConfig object that configures network isolation, encryption of inter-container traffic, security group IDs, and subnets. @@ -168,7 +169,7 @@ def __init__( base_job_name=base_job_name, sagemaker_session=session, env=env, - tags=tags, + tags=format_tags(tags), network_config=network_config, ) @@ -703,7 +704,7 @@ def __init__( base_job_name: Optional[str] = None, sagemaker_session: Optional[Session] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, - tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + tags: Optional[Tags] = None, network_config: Optional[NetworkConfig] = None, ): """Initialize an ``PySparkProcessor`` instance. @@ -747,7 +748,7 @@ def __init__( using the default AWS configuration chain. env (dict[str, str] or dict[str, PipelineVariable]): Environment variables to be passed to the processing job. - tags (list[dict[str, str] or list[dict[str, PipelineVariable]]): List of tags to + tags (Optional[Tags]): List of tags to be passed to the processing job. network_config (sagemaker.network.NetworkConfig): A NetworkConfig object that configures network isolation, encryption of @@ -771,7 +772,7 @@ def __init__( base_job_name=base_job_name, sagemaker_session=sagemaker_session, env=env, - tags=tags, + tags=format_tags(tags), network_config=network_config, ) @@ -980,7 +981,7 @@ def __init__( base_job_name: Optional[str] = None, sagemaker_session: Optional[Session] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, - tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + tags: Optional[Tags] = None, network_config: Optional[NetworkConfig] = None, ): """Initialize a ``SparkJarProcessor`` instance. @@ -1024,8 +1025,7 @@ def __init__( using the default AWS configuration chain. env (dict[str, str] or dict[str, PipelineVariable]): Environment variables to be passed to the processing job. - tags (list[dict[str, str] or list[dict[str, PipelineVariable]]): List of tags to - be passed to the processing job. + tags (Optional[Tags]): Tags to be passed to the processing job. network_config (sagemaker.network.NetworkConfig): A NetworkConfig object that configures network isolation, encryption of inter-container traffic, security group IDs, and subnets. @@ -1048,7 +1048,7 @@ def __init__( base_job_name=base_job_name, sagemaker_session=sagemaker_session, env=env, - tags=tags, + tags=format_tags(tags), network_config=network_config, ) diff --git a/src/sagemaker/tensorflow/estimator.py b/src/sagemaker/tensorflow/estimator.py index 523b70ec38..df2bc74935 100644 --- a/src/sagemaker/tensorflow/estimator.py +++ b/src/sagemaker/tensorflow/estimator.py @@ -29,6 +29,7 @@ from sagemaker.workflow import is_pipeline_variable from sagemaker.tensorflow.training_compiler.config import TrainingCompilerConfig from sagemaker.workflow.entities import PipelineVariable +from sagemaker.utils import format_tags logger = logging.getLogger("sagemaker") @@ -474,7 +475,7 @@ def transformer( each individual transform container at one time. max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB. - tags (list[dict]): List of tags for labeling a transform job. If none specified, then + tags (Optional[Tags]): Tags for labeling a transform job. If none specified, then the tags used for the training job are used for the transform job. role (str): The IAM Role ARN for the ``TensorFlowModel``, which is also used during transform jobs. If not specified, the role from the Estimator is used. @@ -525,7 +526,7 @@ def transformer( max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload, env=env or {}, - tags=tags, + tags=format_tags(tags), base_transform_job_name=self.base_job_name, volume_kms_key=volume_kms_key, sagemaker_session=self.sagemaker_session, @@ -553,6 +554,6 @@ def transformer( env=env, max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload, - tags=tags, + tags=format_tags(tags), volume_kms_key=volume_kms_key, ) diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index 375a2ea7e5..1b35afbe7c 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -27,6 +27,7 @@ from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.pipeline_context import PipelineSession +from sagemaker.utils import format_tags logger = logging.getLogger(__name__) @@ -355,7 +356,7 @@ def deploy( deserializer=deserializer, accelerator_type=accelerator_type, endpoint_name=endpoint_name, - tags=tags, + tags=format_tags(tags), kms_key=kms_key, wait=wait, data_capture_config=data_capture_config, diff --git a/src/sagemaker/tensorflow/processing.py b/src/sagemaker/tensorflow/processing.py index e4495a39fd..529920a374 100644 --- a/src/sagemaker/tensorflow/processing.py +++ b/src/sagemaker/tensorflow/processing.py @@ -24,6 +24,7 @@ from sagemaker.processing import FrameworkProcessor from sagemaker.tensorflow.estimator import TensorFlow from sagemaker.workflow.entities import PipelineVariable +from sagemaker.utils import format_tags, Tags class TensorFlowProcessor(FrameworkProcessor): @@ -48,7 +49,7 @@ def __init__( base_job_name: Optional[str] = None, sagemaker_session: Optional[Session] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, - tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + tags: Optional[Tags] = None, network_config: Optional[NetworkConfig] = None, ): """This processor executes a Python script in a TensorFlow execution environment. @@ -81,6 +82,6 @@ def __init__( base_job_name, sagemaker_session, env, - tags, + format_tags(tags), network_config, ) diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index be8511f570..4ddbbc5451 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -13,7 +13,7 @@ """Placeholder docstring""" from __future__ import absolute_import -from typing import Union, Optional, List, Dict +from typing import Union, Optional, Dict import logging import copy import time @@ -42,6 +42,8 @@ check_and_get_run_experiment_config, resolve_value_from_config, resolve_class_attribute_from_config, + format_tags, + Tags, ) @@ -62,7 +64,7 @@ def __init__( accept: Optional[Union[str, PipelineVariable]] = None, max_concurrent_transforms: Optional[Union[int, PipelineVariable]] = None, max_payload: Optional[Union[int, PipelineVariable]] = None, - tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + tags: Optional[Tags] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, base_transform_job_name: Optional[str] = None, sagemaker_session: Optional[Session] = None, @@ -92,9 +94,9 @@ def __init__( to be made to each individual transform container at one time. max_payload (int or PipelineVariable): Maximum size of the payload in a single HTTP request to the container in MB. - tags (list[dict[str, str] or list[dict[str, PipelineVariable]]): List of tags for - labeling a transform job (default: None). For more, see the SageMaker API - documentation for `Tag `_. + tags (Optional[Tags]): Tags for labeling a transform job (default: None). + For more, see the SageMaker API documentation for + `Tag `_. env (dict[str, str] or dict[str, PipelineVariable]): Environment variables to be set for use during the transform job (default: None). base_transform_job_name (str): Prefix for the transform job when the @@ -121,7 +123,7 @@ def __init__( self.max_concurrent_transforms = max_concurrent_transforms self.max_payload = max_payload - self.tags = tags + self.tags = format_tags(tags) self.base_transform_job_name = base_transform_job_name self._current_job_name = None diff --git a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py index 02f7bd8e79..571f84761f 100644 --- a/src/sagemaker/tuner.py +++ b/src/sagemaker/tuner.py @@ -52,6 +52,8 @@ base_name_from_image, name_from_base, to_string, + format_tags, + Tags, ) AMAZON_ESTIMATOR_MODULE = "sagemaker" @@ -603,7 +605,7 @@ def __init__( max_jobs: Union[int, PipelineVariable] = None, max_parallel_jobs: Union[int, PipelineVariable] = 1, max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None, - tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + tags: Optional[Tags] = None, base_tuning_job_name: Optional[str] = None, warm_start_config: Optional[WarmStartConfig] = None, strategy_config: Optional[StrategyConfig] = None, @@ -651,9 +653,8 @@ def __init__( start (default: 1). max_runtime_in_seconds (int or PipelineVariable): The maximum time in seconds that a hyperparameter tuning job can run. - tags (list[dict[str, str] or list[dict[str, PipelineVariable]]): List of tags for - labeling the tuning job (default: None). For more, see - https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. + tags (Optional[Tags]): Tags for labeling the tuning job (default: None). + For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. base_tuning_job_name (str): Prefix for the hyperparameter tuning job name when the :meth:`~sagemaker.tuner.HyperparameterTuner.fit` method launches. If not specified, a default job name is @@ -746,7 +747,7 @@ def __init__( self.max_parallel_jobs = max_parallel_jobs self.max_runtime_in_seconds = max_runtime_in_seconds - self.tags = tags + self.tags = format_tags(tags) self.base_tuning_job_name = base_tuning_job_name self._current_job_name = None self.latest_tuning_job = None @@ -1924,7 +1925,8 @@ def create( (default: 1). max_runtime_in_seconds (int): The maximum time in seconds that a hyperparameter tuning job can run. - tags (list[dict]): List of tags for labeling the tuning job (default: None). For more, + tags (Optional[Tags]): List of tags for labeling the tuning job (default: None). + For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. warm_start_config (sagemaker.tuner.WarmStartConfig): A ``WarmStartConfig`` object that has been initialized with the configuration defining the nature of warm start @@ -1988,7 +1990,7 @@ def create( max_jobs=max_jobs, max_parallel_jobs=max_parallel_jobs, max_runtime_in_seconds=max_runtime_in_seconds, - tags=tags, + tags=format_tags(tags), warm_start_config=warm_start_config, early_stopping_type=early_stopping_type, random_seed=random_seed, diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 31850a290e..e203693f84 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -25,7 +25,7 @@ import tarfile import tempfile import time -from typing import Any, List, Optional, Dict +from typing import Union, Any, List, Optional, Dict import json import abc import uuid @@ -44,6 +44,7 @@ ) from sagemaker.session_settings import SessionSettings from sagemaker.workflow import is_pipeline_variable, is_pipeline_parameter_string +from sagemaker.workflow.entities import PipelineVariable ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$" MAX_BUCKET_PATHS_COUNT = 5 @@ -57,6 +58,9 @@ logger = logging.getLogger(__name__) +TagsDict = Dict[str, Union[str, PipelineVariable]] +Tags = Union[List[TagsDict], TagsDict] + # Use the base name of the image as the job name if the user doesn't give us one def name_from_image(image, max_length=63): @@ -1477,3 +1481,11 @@ def create_paginator_config(max_items: int = None, page_size: int = None) -> Dic "MaxItems": max_items if max_items else MAX_ITEMS, "PageSize": page_size if page_size else PAGE_SIZE, } + + +def format_tags(tags: Tags) -> List[TagsDict]: + """Process tags to turn them into the expected format for Sagemaker.""" + if isinstance(tags, dict): + return [{"Key": str(k), "Value": str(v)} for k, v in tags.items()] + + return tags diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index 8177f6eed4..4993493513 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -32,7 +32,7 @@ Step, ConfigurableRetryStep, ) -from sagemaker.utils import _save_model, download_file_from_url +from sagemaker.utils import _save_model, download_file_from_url, format_tags from sagemaker.workflow.retry import RetryPolicy from sagemaker.workflow.utilities import trim_request_dict @@ -359,7 +359,7 @@ def __init__( depends on (default: None). retry_policies (List[RetryPolicy]): The list of retry policies for the current step (default: None). - tags (List[dict[str, str]]): A list of dictionaries containing key-value pairs used to + tags (Optional[Tags]): A list of dictionaries containing key-value pairs used to configure the create model package request (default: None). container_def_list (list): A list of container definitions (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). @@ -395,7 +395,7 @@ def __init__( self.inference_instances = inference_instances self.transform_instances = transform_instances self.model_package_group_name = model_package_group_name - self.tags = tags + self.tags = format_tags(tags) self.model_metrics = model_metrics self.drift_check_baselines = drift_check_baselines self.customer_metadata_properties = customer_metadata_properties @@ -407,7 +407,6 @@ def __init__( self.image_uri = image_uri self.compile_model_family = compile_model_family self.description = description - self.tags = tags self.kwargs = kwargs self.container_def_list = container_def_list self.skip_model_validation = skip_model_validation diff --git a/src/sagemaker/workflow/airflow.py b/src/sagemaker/workflow/airflow.py index cb4951d6e4..793849ff93 100644 --- a/src/sagemaker/workflow/airflow.py +++ b/src/sagemaker/workflow/airflow.py @@ -22,6 +22,7 @@ from sagemaker.tensorflow import TensorFlow from sagemaker.estimator import EstimatorBase from sagemaker.processing import Processor +from sagemaker.utils import format_tags def prepare_framework(estimator, s3_operations): @@ -898,7 +899,7 @@ def transform_config_from_estimator( be made to each individual transform container at one time. max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB. - tags (list[dict]): List of tags for labeling a transform job. If none + tags (Optional[Tags]): List of tags for labeling a transform job. If none specified, then the tags used for the training job are used for the transform job. role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, @@ -969,7 +970,7 @@ def transform_config_from_estimator( env, max_concurrent_transforms, max_payload, - tags, + format_tags(tags), role, model_server_workers, volume_kms_key, @@ -986,7 +987,7 @@ def transform_config_from_estimator( env, max_concurrent_transforms, max_payload, - tags, + format_tags(tags), role, volume_kms_key, ) diff --git a/src/sagemaker/workflow/check_job_config.py b/src/sagemaker/workflow/check_job_config.py index eaba149823..a8e4082c8e 100644 --- a/src/sagemaker/workflow/check_job_config.py +++ b/src/sagemaker/workflow/check_job_config.py @@ -24,6 +24,7 @@ ModelBiasMonitor, ModelExplainabilityMonitor, ) +from sagemaker.utils import format_tags class CheckJobConfig: @@ -66,7 +67,7 @@ def __init__( AWS services needed (default: None). If not specified, one is created using the default AWS configuration chain. env (dict): Environment variables to be passed to the job (default: None). - tags ([dict]): List of tags to be passed to the job (default: None). + tags (Optional[Tags]): List of tags to be passed to the job (default: None). network_config (sagemaker.network.NetworkConfig): A NetworkConfig object that configures network isolation, encryption of inter-container traffic, security group IDs, and subnets (default: None). @@ -82,7 +83,7 @@ def __init__( self.base_job_name = base_job_name self.sagemaker_session = sagemaker_session or Session() self.env = env - self.tags = tags + self.tags = format_tags(tags) self.network_config = network_config def _generate_model_monitor(self, mm_type: str) -> Optional[ModelMonitor]: diff --git a/src/sagemaker/workflow/function_step.py b/src/sagemaker/workflow/function_step.py index 4fee8ef269..a55955b4eb 100644 --- a/src/sagemaker/workflow/function_step.py +++ b/src/sagemaker/workflow/function_step.py @@ -41,7 +41,7 @@ from sagemaker.workflow.utilities import trim_request_dict, load_step_compilation_context from sagemaker.s3_utils import s3_path_join -from sagemaker.utils import unique_name_from_base_uuid4 +from sagemaker.utils import unique_name_from_base_uuid4, format_tags, Tags if TYPE_CHECKING: from sagemaker.remote_function.spark_config import SparkConfig @@ -374,7 +374,7 @@ def step( role: str = None, security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None, subnets: Optional[List[Union[str, PipelineVariable]]] = None, - tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + tags: Optional[Tags] = None, volume_kms_key: Optional[Union[str, PipelineVariable]] = None, volume_size: Union[int, PipelineVariable] = 30, encrypt_inter_container_traffic: Optional[Union[bool, PipelineVariable]] = None, @@ -513,8 +513,8 @@ def step( subnets (List[str, PipelineVariable]): A list of subnet IDs. Defaults to ``None`` and the job is created without a VPC config. - tags (list[dict[str, str] or list[dict[str, PipelineVariable]]): A list of tags attached - to the job. Defaults to ``None`` and the training job is created without tags. + tags (Optional[Tags]): Tags attached to the job. Defaults to ``None`` + and the training job is created without tags. volume_kms_key (str, PipelineVariable): An Amazon Key Management Service (KMS) key used to encrypt an Amazon Elastic Block Storage (EBS) volume attached to the training instance. @@ -598,7 +598,7 @@ def wrapper(*args, **kwargs): role=role, security_group_ids=security_group_ids, subnets=subnets, - tags=tags, + tags=format_tags(tags), volume_kms_key=volume_kms_key, volume_size=volume_size, encrypt_inter_container_traffic=encrypt_inter_container_traffic, diff --git a/src/sagemaker/workflow/notebook_job_step.py b/src/sagemaker/workflow/notebook_job_step.py index e535457db6..8a1dd6bc53 100644 --- a/src/sagemaker/workflow/notebook_job_step.py +++ b/src/sagemaker/workflow/notebook_job_step.py @@ -45,7 +45,7 @@ from sagemaker.s3_utils import s3_path_join from sagemaker.s3 import S3Uploader -from sagemaker.utils import _tmpdir, name_from_base, resolve_value_from_config +from sagemaker.utils import _tmpdir, name_from_base, resolve_value_from_config, format_tags, Tags from sagemaker import vpc_utils from sagemaker.config.config_schema import ( @@ -93,7 +93,7 @@ def __init__( subnets: Optional[List[Union[str, PipelineVariable]]] = None, max_retry_attempts: int = 1, max_runtime_in_seconds: int = 2 * 24 * 60 * 60, - tags: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + tags: Optional[Tags] = None, additional_dependencies: Optional[List[str]] = None, # pylint: enable=W0613 retry_policies: Optional[List[RetryPolicy]] = None, @@ -187,10 +187,9 @@ def __init__( time and max retry attempts, the run time applies to each retry. If a job does not complete in this time, its status is set to ``Failed``. Defaults to ``172800 seconds(2 days)``. - tags (dict[str, str] or dict[str, PipelineVariable]): A list of tags attached to the - job. Defaults to ``None`` and the training job is created without tags. Your tags - control how the Studio UI captures and displays the job created by - the pipeline in the following ways: + tags (Optional[Tags]): Tags attached to the job. Defaults to ``None`` and the training + job is created without tags. Your tags control how the Studio UI captures and + displays the job created by the pipeline in the following ways: * If you only attach the domain tag, then the notebook job is displayed to all user profiles and spaces. @@ -359,7 +358,7 @@ def _prepare_tags(self): This function converts the custom tags into training API required format and also attach the system tags. """ - custom_tags = [{"Key": k, "Value": v} for k, v in self.tags.items()] if self.tags else [] + custom_tags = format_tags(self.tags) or [] system_tags = [ {"Key": "sagemaker:name", "Value": self.notebook_job_name}, {"Key": "sagemaker:notebook-name", "Value": os.path.basename(self.input_notebook)}, diff --git a/src/sagemaker/workflow/pipeline.py b/src/sagemaker/workflow/pipeline.py index e65a2f5e05..0645e58386 100644 --- a/src/sagemaker/workflow/pipeline.py +++ b/src/sagemaker/workflow/pipeline.py @@ -33,7 +33,7 @@ from sagemaker.remote_function.job import JOBS_CONTAINER_ENTRYPOINT from sagemaker.s3_utils import s3_path_join from sagemaker.session import Session -from sagemaker.utils import resolve_value_from_config, retry_with_backoff +from sagemaker.utils import resolve_value_from_config, retry_with_backoff, format_tags, Tags from sagemaker.workflow.callback_step import CallbackOutput, CallbackStep from sagemaker.workflow._event_bridge_client_helper import ( EventBridgeSchedulerHelper, @@ -130,7 +130,7 @@ def create( self, role_arn: str = None, description: str = None, - tags: List[Dict[str, str]] = None, + tags: Optional[Tags] = None, parallelism_config: ParallelismConfiguration = None, ) -> Dict[str, Any]: """Creates a Pipeline in the Pipelines service. @@ -138,8 +138,7 @@ def create( Args: role_arn (str): The role arn that is assumed by the pipeline to create step artifacts. description (str): A description of the pipeline. - tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as - tags. + tags (Optional[Tags]): Tags to be passed to the pipeline. parallelism_config (Optional[ParallelismConfiguration]): Parallelism configuration that is applied to each of the executions of the pipeline. It takes precedence over the parallelism configuration of the parent pipeline. @@ -160,6 +159,7 @@ def create( if parallelism_config: logger.warning("Pipeline parallelism config is not supported in the local mode.") return self.sagemaker_session.sagemaker_client.create_pipeline(self, description) + tags = format_tags(tags) tags = _append_project_tags(tags) tags = self.sagemaker_session._append_sagemaker_config_tags(tags, PIPELINE_TAGS_PATH) kwargs = self._create_args(role_arn, description, parallelism_config) @@ -264,7 +264,7 @@ def upsert( self, role_arn: str = None, description: str = None, - tags: List[Dict[str, str]] = None, + tags: Optional[Tags] = None, parallelism_config: ParallelismConfiguration = None, ) -> Dict[str, Any]: """Creates a pipeline or updates it, if it already exists. @@ -272,8 +272,7 @@ def upsert( Args: role_arn (str): The role arn that is assumed by workflow to create step artifacts. description (str): A description of the pipeline. - tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as - tags. + tags (Optional[Tags]): Tags to be passed. parallelism_config (Optional[Config for parallel steps, Parallelism configuration that is applied to each of the executions @@ -283,6 +282,7 @@ def upsert( role_arn = resolve_value_from_config( role_arn, PIPELINE_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session ) + tags = format_tags(tags) if not role_arn: # Originally IAM role was a required parameter. # Now we marked that as Optional because we can fetch it from SageMakerConfig diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index 5afac7b519..d48bf7c307 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -28,7 +28,7 @@ from sagemaker.workflow.steps import Step, CreateModelStep, TransformStep from sagemaker.workflow._utils import _RegisterModelStep, _RepackModelStep from sagemaker.workflow.retry import RetryPolicy -from sagemaker.utils import update_container_with_inference_params +from sagemaker.utils import update_container_with_inference_params, format_tags @attr.s @@ -128,7 +128,7 @@ def __init__( compile_model_family (str): The instance family for the compiled model. If specified, a compiled model is used (default: None). description (str): Model Package description (default: None). - tags (List[dict[str, str]]): The list of tags to attach to the model package group. Note + tags (Optional[Tags]): The list of tags to attach to the model package group. Note that tags will only be applied to newly created model package groups; if the name of an existing group is passed to "model_package_group_name", tags will not be applied. @@ -163,6 +163,7 @@ def __init__( self.container_def_list = None subnets = None security_group_ids = None + tags = format_tags(tags) if estimator is not None: subnets = estimator.subnets @@ -390,6 +391,7 @@ def __init__( """ super().__init__(name=name, depends_on=depends_on) steps = [] + tags = format_tags(tags) if "entry_point" in kwargs: entry_point = kwargs.get("entry_point", None) source_dir = kwargs.get("source_dir", None) diff --git a/src/sagemaker/wrangler/processing.py b/src/sagemaker/wrangler/processing.py index fe38b670a0..3853fe8ef9 100644 --- a/src/sagemaker/wrangler/processing.py +++ b/src/sagemaker/wrangler/processing.py @@ -14,7 +14,7 @@ from __future__ import absolute_import -from typing import Dict, List +from typing import Dict, Optional from sagemaker.network import NetworkConfig from sagemaker.processing import ( @@ -23,6 +23,7 @@ ) from sagemaker import image_uris from sagemaker.session import Session +from sagemaker.utils import format_tags, Tags class DataWranglerProcessor(Processor): @@ -41,7 +42,7 @@ def __init__( base_job_name: str = None, sagemaker_session: Session = None, env: Dict[str, str] = None, - tags: List[dict] = None, + tags: Optional[Tags] = None, network_config: NetworkConfig = None, ): """Initializes a ``Processor`` instance. @@ -78,7 +79,7 @@ def __init__( one using the default AWS configuration chain. env (dict[str, str]): Environment variables to be passed to the processing jobs (default: None). - tags (list[dict]): List of tags to be passed to the processing job + tags (Optional[Tags]): Tags to be passed to the processing job (default: None). For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. network_config (:class:`~sagemaker.network.NetworkConfig`): @@ -103,7 +104,7 @@ def __init__( base_job_name=base_job_name, sagemaker_session=sagemaker_session, env=env, - tags=tags, + tags=format_tags(tags), network_config=network_config, ) diff --git a/src/sagemaker/xgboost/processing.py b/src/sagemaker/xgboost/processing.py index d840bfd960..1df32df37a 100644 --- a/src/sagemaker/xgboost/processing.py +++ b/src/sagemaker/xgboost/processing.py @@ -24,6 +24,7 @@ from sagemaker.processing import FrameworkProcessor from sagemaker.xgboost.estimator import XGBoost from sagemaker.workflow.entities import PipelineVariable +from sagemaker.utils import format_tags, Tags class XGBoostProcessor(FrameworkProcessor): @@ -48,7 +49,7 @@ def __init__( base_job_name: Optional[str] = None, sagemaker_session: Optional[Session] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, - tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + tags: Optional[Tags] = None, network_config: Optional[NetworkConfig] = None, ): """This processor executes a Python script in an XGBoost execution environment. @@ -81,6 +82,6 @@ def __init__( base_job_name, sagemaker_session, env, - tags, + format_tags(tags), network_config, ) diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index bfd5af977d..de86fcf99a 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -361,7 +361,7 @@ def test_create_sagemaker_model_tags(prepare_container_def, sagemaker_session): model = Model(MODEL_IMAGE, MODEL_DATA, name=MODEL_NAME, sagemaker_session=sagemaker_session) - tags = {"Key": "foo", "Value": "bar"} + tags = [{"Key": "foo", "Value": "bar"}] model._create_sagemaker_model(INSTANCE_TYPE, tags=tags) sagemaker_session.create_model.assert_called_with( diff --git a/tests/unit/sagemaker/model/test_model_package.py b/tests/unit/sagemaker/model/test_model_package.py index 8be561030e..def7ddf5e3 100644 --- a/tests/unit/sagemaker/model/test_model_package.py +++ b/tests/unit/sagemaker/model/test_model_package.py @@ -201,7 +201,7 @@ def test_create_sagemaker_model_include_tags(sagemaker_session): env_key = "env_key" env_value = "env_value" environment = {env_key: env_value} - tags = {"Key": "foo", "Value": "bar"} + tags = [{"Key": "foo", "Value": "bar"}] model_package = ModelPackage( role="role", @@ -314,7 +314,7 @@ def test_model_package_create_transformer_with_product_id(sagemaker_session): @patch("sagemaker.model.ModelPackage.update_approval_status") def test_model_package_auto_approve_on_deploy(update_approval_status, sagemaker_session): - tags = {"Key": "foo", "Value": "bar"} + tags = [{"Key": "foo", "Value": "bar"}] model_package = ModelPackage( role="role", model_package_arn=MODEL_PACKAGE_VERSIONED_ARN, diff --git a/tests/unit/sagemaker/tensorflow/test_estimator.py b/tests/unit/sagemaker/tensorflow/test_estimator.py index 6654a04202..d6eaf74012 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator.py @@ -324,7 +324,7 @@ def test_transformer_creation_with_optional_args( env = {"foo": "bar"} max_concurrent_transforms = 3 max_payload = 100 - tags = {"Key": "foo", "Value": "bar"} + tags = [{"Key": "foo", "Value": "bar"}] new_role = "role" vpc_config = {"Subnets": ["1234"], "SecurityGroupIds": ["5678"]} model_name = "model-name" diff --git a/tests/unit/test_predictor.py b/tests/unit/test_predictor.py index 003c57ac04..1ee9babdf7 100644 --- a/tests/unit/test_predictor.py +++ b/tests/unit/test_predictor.py @@ -307,7 +307,7 @@ def test_update_endpoint_all_args(name_from_base, production_variant): new_instance_type = "ml.c4.xlarge" new_accelerator_type = "ml.eia1.medium" new_model_name = "new-model" - new_tags = {"Key": "foo", "Value": "bar"} + new_tags = [{"Key": "foo", "Value": "bar"}] new_kms_key = "new-key" new_data_capture_config_dict = {} diff --git a/tests/unit/test_predictor_async.py b/tests/unit/test_predictor_async.py index 1af21a36ff..fa2d6da6c7 100644 --- a/tests/unit/test_predictor_async.py +++ b/tests/unit/test_predictor_async.py @@ -404,7 +404,7 @@ def test_update_endpoint_all_args(): new_instance_type = "ml.c4.xlarge" new_accelerator_type = "ml.eia1.medium" new_model_name = "new-model" - new_tags = {"Key": "foo", "Value": "bar"} + new_tags = [{"Key": "foo", "Value": "bar"}] new_kms_key = "new-key" new_data_capture_config_dict = {} diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index d08a155c7c..c51dcaaea5 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -3653,7 +3653,7 @@ def test_endpoint_from_production_variants_with_sagemaker_config_injection_no_km ) -def test_create_endpoint_config_with_tags(sagemaker_session): +def test_create_endpoint_config_with_tags_list(sagemaker_session): tags = [{"Key": "TagtestKey", "Value": "TagtestValue"}] sagemaker_session.create_endpoint_config( @@ -3669,6 +3669,23 @@ def test_create_endpoint_config_with_tags(sagemaker_session): ) +def test_create_endpoint_config_with_tags_dict(sagemaker_session): + tags = {"TagtestKey": "TagtestValue"} + call_tags = [{"Key": "TagtestKey", "Value": "TagtestValue"}] + + sagemaker_session.create_endpoint_config( + name="endpoint-test", + initial_instance_count=1, + instance_type="local", + model_name="simple-model", + tags=tags, + ) + + sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with( + EndpointConfigName="endpoint-test", ProductionVariants=ANY, Tags=call_tags + ) + + def test_create_endpoint_config_with_explainer_config(sagemaker_session): explainer_config = ExplainerConfig diff --git a/tests/unit/test_transformer.py b/tests/unit/test_transformer.py index 138cc3e171..8497bc7ea0 100644 --- a/tests/unit/test_transformer.py +++ b/tests/unit/test_transformer.py @@ -249,7 +249,7 @@ def test_transformer_init_optional_params(sagemaker_session): accept = "text/csv" max_concurrent_transforms = 100 max_payload = 100 - tags = {"Key": "foo", "Value": "bar"} + tags = [{"Key": "foo", "Value": "bar"}] env = {"FOO": "BAR"} transformer = Transformer( @@ -573,7 +573,7 @@ def test_start_new(prepare_data_processing, load_config, sagemaker_session): strategy = "MultiRecord" max_concurrent_transforms = 100 max_payload = 100 - tags = {"Key": "foo", "Value": "bar"} + tags = [{"Key": "foo", "Value": "bar"}] env = {"FOO": "BAR"} transformer = Transformer(