Skip to content

Change: More pythonic tags #4327

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/sagemaker/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from sagemaker.session import Session
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
from sagemaker.utils import format_tags, Tags

from sagemaker.workflow import is_pipeline_variable

Expand Down Expand Up @@ -58,7 +59,7 @@ def __init__(
base_job_name: Optional[str] = None,
sagemaker_session: Optional[Session] = None,
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
tags: Optional[Tags] = None,
subnets: Optional[List[Union[str, PipelineVariable]]] = None,
security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
model_uri: Optional[str] = None,
Expand Down Expand Up @@ -121,7 +122,7 @@ def __init__(
interactions with Amazon SageMaker APIs and any other AWS services needed. If
not specified, the estimator creates one using the default
AWS configuration chain.
tags (list[dict[str, str] or list[dict[str, PipelineVariable]]): List of tags for
tags (Union[Tags]): Tags for
labeling a training job. For more, see
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
subnets (list[str] or list[PipelineVariable]): List of subnet ids. If not specified
Expand Down Expand Up @@ -170,7 +171,7 @@ def __init__(
output_kms_key=output_kms_key,
base_job_name=base_job_name,
sagemaker_session=sagemaker_session,
tags=tags,
tags=format_tags(tags),
subnets=subnets,
security_group_ids=security_group_ids,
model_uri=model_uri,
Expand Down Expand Up @@ -391,7 +392,7 @@ def transformer(
if self._is_marketplace():
transform_env = None

tags = tags or self.tags
tags = format_tags(tags) or self.tags
else:
raise RuntimeError("No finished training job found associated with this estimator")

Expand Down
5 changes: 3 additions & 2 deletions src/sagemaker/apiutils/_base_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from __future__ import absolute_import

from sagemaker.apiutils import _boto_functions, _utils
from sagemaker.utils import format_tags


class ApiObject(object):
Expand Down Expand Up @@ -194,13 +195,13 @@ def _set_tags(self, resource_arn=None, tags=None):

Args:
resource_arn (str): The arn of the Record
tags (dict): An array of Tag objects that set to Record
tags (Optional[Tags]): An array of Tag objects that set to Record

Returns:
A list of key, value pair objects. i.e. [{"key":"value"}]
"""
tag_list = self.sagemaker_session.sagemaker_client.add_tags(
ResourceArn=resource_arn, Tags=tags
ResourceArn=resource_arn, Tags=format_tags(tags)
)["Tags"]
return tag_list

Expand Down
13 changes: 6 additions & 7 deletions src/sagemaker/automl/automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
)
from sagemaker.job import _Job
from sagemaker.session import Session
from sagemaker.utils import name_from_base, resolve_value_from_config
from sagemaker.utils import name_from_base, resolve_value_from_config, format_tags, Tags
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.workflow.pipeline_context import runnable_by_pipeline

Expand Down Expand Up @@ -127,7 +127,7 @@ def __init__(
total_job_runtime_in_seconds: Optional[int] = None,
job_objective: Optional[Dict[str, str]] = None,
generate_candidate_definitions_only: Optional[bool] = False,
tags: Optional[List[Dict[str, str]]] = None,
tags: Optional[Tags] = None,
content_type: Optional[str] = None,
s3_data_type: Optional[str] = None,
feature_specification_s3_uri: Optional[str] = None,
Expand Down Expand Up @@ -167,8 +167,7 @@ def __init__(
In the format of: {"MetricName": str}
generate_candidate_definitions_only (bool): Whether to generates
possible candidates without training the models.
tags (List[dict[str, str]]): The list of tags to attach to this
specific endpoint.
tags (Optional[Tags]): Tags to attach to this specific endpoint.
content_type (str): The content type of the data from the input source.
s3_data_type (str): The data type for S3 data source.
Valid values: ManifestFile or S3Prefix.
Expand Down Expand Up @@ -203,7 +202,7 @@ def __init__(
self.target_attribute_name = target_attribute_name
self.job_objective = job_objective
self.generate_candidate_definitions_only = generate_candidate_definitions_only
self.tags = tags
self.tags = format_tags(tags)
self.content_type = content_type
self.s3_data_type = s3_data_type
self.feature_specification_s3_uri = feature_specification_s3_uri
Expand Down Expand Up @@ -581,7 +580,7 @@ def deploy(
be selected on each ``deploy``.
endpoint_name (str): The name of the endpoint to create (default:
None). If not specified, a unique endpoint name will be created.
tags (List[dict[str, str]]): The list of tags to attach to this
tags (Optional[Tags]): The list of tags to attach to this
specific endpoint.
wait (bool): Whether the call should wait until the deployment of
model completes (default: True).
Expand Down Expand Up @@ -633,7 +632,7 @@ def deploy(
deserializer=deserializer,
endpoint_name=endpoint_name,
kms_key=model_kms_key,
tags=tags,
tags=format_tags(tags),
wait=wait,
volume_size=volume_size,
model_data_download_timeout=model_data_download_timeout,
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/base_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
NumpySerializer,
)
from sagemaker.session import production_variant, Session
from sagemaker.utils import name_from_base, stringify_object
from sagemaker.utils import name_from_base, stringify_object, format_tags

from sagemaker.model_monitor.model_monitoring import DEFAULT_REPOSITORY_NAME

Expand Down Expand Up @@ -409,7 +409,7 @@ def update_endpoint(
self.sagemaker_session.create_endpoint_config_from_existing(
current_endpoint_config_name,
new_endpoint_config_name,
new_tags=tags,
new_tags=format_tags(tags),
new_kms_key=kms_key,
new_data_capture_config_dict=data_capture_config_dict,
new_production_variants=production_variants,
Expand Down
7 changes: 4 additions & 3 deletions src/sagemaker/clarify.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from sagemaker.session import Session
from sagemaker.network import NetworkConfig
from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
from sagemaker.utils import format_tags, Tags

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1417,7 +1418,7 @@ def __init__(
max_runtime_in_seconds: Optional[int] = None,
sagemaker_session: Optional[Session] = None,
env: Optional[Dict[str, str]] = None,
tags: Optional[List[Dict[str, str]]] = None,
tags: Optional[Tags] = None,
network_config: Optional[NetworkConfig] = None,
job_name_prefix: Optional[str] = None,
version: Optional[str] = None,
Expand Down Expand Up @@ -1454,7 +1455,7 @@ def __init__(
using the default AWS configuration chain.
env (dict[str, str]): Environment variables to be passed to
the processing jobs (default: None).
tags (list[dict]): List of tags to be passed to the processing job
tags (Optional[Tags]): Tags to be passed to the processing job
(default: None). For more, see
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
network_config (:class:`~sagemaker.network.NetworkConfig`):
Expand Down Expand Up @@ -1482,7 +1483,7 @@ def __init__(
None, # We set method-specific job names below.
sagemaker_session,
env,
tags,
format_tags(tags),
network_config,
)

Expand Down
6 changes: 3 additions & 3 deletions src/sagemaker/djl_inference/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from sagemaker.s3_utils import s3_path_join
from sagemaker.serializers import JSONSerializer, BaseSerializer
from sagemaker.session import Session
from sagemaker.utils import _tmpdir, _create_or_update_code_dir
from sagemaker.utils import _tmpdir, _create_or_update_code_dir, format_tags
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.estimator import Estimator
from sagemaker.s3 import S3Uploader
Expand Down Expand Up @@ -610,7 +610,7 @@ def deploy(
default deserializer is set by the ``predictor_cls``.
endpoint_name (str): The name of the endpoint to create (default:
None). If not specified, a unique endpoint name will be created.
tags (List[dict[str, str]]): The list of tags to attach to this
tags (Optional[Tags]): The list of tags to attach to this
specific endpoint.
kms_key (str): The ARN of the KMS key that is used to encrypt the
data on the storage volume attached to the instance hosting the
Expand Down Expand Up @@ -651,7 +651,7 @@ def deploy(
serializer=serializer,
deserializer=deserializer,
endpoint_name=endpoint_name,
tags=tags,
tags=format_tags(tags),
kms_key=kms_key,
wait=wait,
data_capture_config=data_capture_config,
Expand Down
41 changes: 23 additions & 18 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@
to_string,
check_and_get_run_experiment_config,
resolve_value_from_config,
format_tags,
Tags,
)
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.entities import PipelineVariable
Expand Down Expand Up @@ -144,7 +146,7 @@ def __init__(
output_kms_key: Optional[Union[str, PipelineVariable]] = None,
base_job_name: Optional[str] = None,
sagemaker_session: Optional[Session] = None,
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
tags: Optional[Tags] = None,
subnets: Optional[List[Union[str, PipelineVariable]]] = None,
security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
model_uri: Optional[str] = None,
Expand Down Expand Up @@ -270,8 +272,8 @@ def __init__(
manages interactions with Amazon SageMaker APIs and any other
AWS services needed. If not specified, the estimator creates one
using the default AWS configuration chain.
tags (list[dict[str, str] or list[dict[str, PipelineVariable]]):
List of tags for labeling a training job. For more, see
tags (Optional[Tags]):
Tags for labeling a training job. For more, see
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
subnets (list[str] or list[PipelineVariable]): List of subnet ids. If not
specified training job will be created without VPC config.
Expand Down Expand Up @@ -604,6 +606,7 @@ def __init__(
else:
self.sagemaker_session = sagemaker_session or Session()

tags = format_tags(tags)
self.tags = (
add_jumpstart_uri_tags(
tags=tags, training_model_uri=self.model_uri, training_script_uri=self.source_dir
Expand Down Expand Up @@ -1352,7 +1355,7 @@ def compile_model(
framework=None,
framework_version=None,
compile_max_run=15 * 60,
tags=None,
tags: Optional[Tags] = None,
target_platform_os=None,
target_platform_arch=None,
target_platform_accelerator=None,
Expand All @@ -1378,7 +1381,7 @@ def compile_model(
compile_max_run (int): Timeout in seconds for compilation (default:
15 * 60). After this amount of time Amazon SageMaker Neo
terminates the compilation job regardless of its current status.
tags (list[dict]): List of tags for labeling a compilation job. For
tags (list[dict]): Tags for labeling a compilation job. For
more, see
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
target_platform_os (str): Target Platform OS, for example: 'LINUX'.
Expand Down Expand Up @@ -1420,7 +1423,7 @@ def compile_model(
input_shape,
output_path,
self.role,
tags,
format_tags(tags),
self._compilation_job_name(),
compile_max_run,
framework=framework,
Expand Down Expand Up @@ -1532,7 +1535,7 @@ def deploy(
model_name=None,
kms_key=None,
data_capture_config=None,
tags=None,
tags: Optional[Tags] = None,
serverless_inference_config=None,
async_inference_config=None,
volume_size=None,
Expand Down Expand Up @@ -1601,8 +1604,10 @@ def deploy(
empty object passed through, will use pre-defined values in
``ServerlessInferenceConfig`` class to deploy serverless endpoint. Deploy an
instance based endpoint if it's None. (default: None)
tags(List[dict[str, str]]): Optional. The list of tags to attach to this specific
tags(Optional[Tags]): Optional. Tags to attach to this specific
endpoint. Example:
>>> tags = {'tagname', 'tagvalue'}
Or
>>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
For more information about tags, see
https://boto3.amazonaws.com/v1/documentation\
Expand Down Expand Up @@ -1664,7 +1669,7 @@ def deploy(
model.name = model_name

tags = update_inference_tags_with_jumpstart_training_tags(
inference_tags=tags, training_tags=self.tags
inference_tags=format_tags(tags), training_tags=self.tags
)

return model.deploy(
Expand Down Expand Up @@ -2017,7 +2022,7 @@ def transformer(
env=None,
max_concurrent_transforms=None,
max_payload=None,
tags=None,
tags: Optional[Tags] = None,
role=None,
volume_kms_key=None,
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
Expand Down Expand Up @@ -2051,7 +2056,7 @@ def transformer(
to be made to each individual transform container at one time.
max_payload (int): Maximum size of the payload in a single HTTP
request to the container in MB.
tags (list[dict]): List of tags for labeling a transform job. If
tags (Optional[Tags]): Tags for labeling a transform job. If
none specified, then the tags used for the training job are used
for the transform job.
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
Expand All @@ -2078,7 +2083,7 @@ def transformer(
model. If not specified, the estimator generates a default job name
based on the training image name and current timestamp.
"""
tags = tags or self.tags
tags = format_tags(tags) or self.tags
model_name = self._get_or_create_name(model_name)

if self.latest_training_job is None:
Expand Down Expand Up @@ -2717,7 +2722,7 @@ def __init__(
base_job_name: Optional[str] = None,
sagemaker_session: Optional[Session] = None,
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
tags: Optional[Tags] = None,
subnets: Optional[List[Union[str, PipelineVariable]]] = None,
security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
model_uri: Optional[str] = None,
Expand Down Expand Up @@ -2847,7 +2852,7 @@ def __init__(
hyperparameters. SageMaker rejects the training job request and returns an
validation error for detected credentials, if such user input is found.

tags (list[dict[str, str] or list[dict[str, PipelineVariable]]): List of tags for
tags (Optional[Tags]): Tags for
labeling a training job. For more, see
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
subnets (list[str] or list[PipelineVariable]): List of subnet ids.
Expand Down Expand Up @@ -3130,7 +3135,7 @@ def __init__(
output_kms_key,
base_job_name,
sagemaker_session,
tags,
format_tags(tags),
subnets,
security_group_ids,
model_uri=model_uri,
Expand Down Expand Up @@ -3762,7 +3767,7 @@ def transformer(
env=None,
max_concurrent_transforms=None,
max_payload=None,
tags=None,
tags: Optional[Tags] = None,
role=None,
model_server_workers=None,
volume_kms_key=None,
Expand Down Expand Up @@ -3798,7 +3803,7 @@ def transformer(
to be made to each individual transform container at one time.
max_payload (int): Maximum size of the payload in a single HTTP
request to the container in MB.
tags (list[dict]): List of tags for labeling a transform job. If
tags (Optional[Tags]): Tags for labeling a transform job. If
none specified, then the tags used for the training job are used
for the transform job.
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
Expand Down Expand Up @@ -3837,7 +3842,7 @@ def transformer(
SageMaker Batch Transform job.
"""
role = role or self.role
tags = tags or self.tags
tags = format_tags(tags) or self.tags
model_name = self._get_or_create_name(model_name)

if self.latest_training_job is not None:
Expand Down
Loading