|
28 | 28 | )
|
29 | 29 | from sagemaker.job import _Job
|
30 | 30 | from sagemaker.session import Session
|
31 |
| -from sagemaker.utils import name_from_base, resolve_value_from_config |
| 31 | +from sagemaker.utils import name_from_base, resolve_value_from_config, format_tags, Tags |
32 | 32 | from sagemaker.workflow.entities import PipelineVariable
|
33 | 33 | from sagemaker.workflow.pipeline_context import runnable_by_pipeline
|
34 | 34 |
|
@@ -127,7 +127,7 @@ def __init__(
|
127 | 127 | total_job_runtime_in_seconds: Optional[int] = None,
|
128 | 128 | job_objective: Optional[Dict[str, str]] = None,
|
129 | 129 | generate_candidate_definitions_only: Optional[bool] = False,
|
130 |
| - tags: Optional[List[Dict[str, str]]] = None, |
| 130 | + tags: Optional[Tags] = None, |
131 | 131 | content_type: Optional[str] = None,
|
132 | 132 | s3_data_type: Optional[str] = None,
|
133 | 133 | feature_specification_s3_uri: Optional[str] = None,
|
@@ -167,8 +167,7 @@ def __init__(
|
167 | 167 | In the format of: {"MetricName": str}
|
168 | 168 | generate_candidate_definitions_only (bool): Whether to generates
|
169 | 169 | possible candidates without training the models.
|
170 |
| - tags (List[dict[str, str]]): The list of tags to attach to this |
171 |
| - specific endpoint. |
| 170 | + tags (Optional[Tags]): Tags to attach to this specific endpoint. |
172 | 171 | content_type (str): The content type of the data from the input source.
|
173 | 172 | s3_data_type (str): The data type for S3 data source.
|
174 | 173 | Valid values: ManifestFile or S3Prefix.
|
@@ -203,7 +202,7 @@ def __init__(
|
203 | 202 | self.target_attribute_name = target_attribute_name
|
204 | 203 | self.job_objective = job_objective
|
205 | 204 | self.generate_candidate_definitions_only = generate_candidate_definitions_only
|
206 |
| - self.tags = tags |
| 205 | + self.tags = format_tags(tags) |
207 | 206 | self.content_type = content_type
|
208 | 207 | self.s3_data_type = s3_data_type
|
209 | 208 | self.feature_specification_s3_uri = feature_specification_s3_uri
|
@@ -332,7 +331,7 @@ def attach(cls, auto_ml_job_name, sagemaker_session=None):
|
332 | 331 | total_job_runtime_in_seconds=auto_ml_job_desc.get("AutoMLJobConfig", {})
|
333 | 332 | .get("CompletionCriteria", {})
|
334 | 333 | .get("MaxAutoMLJobRuntimeInSeconds"),
|
335 |
| - job_objective=auto_ml_job_desc.get("AutoMLJobObjective", {}).get("MetricName"), |
| 334 | + job_objective=auto_ml_job_desc.get("AutoMLJobObjective", {}), |
336 | 335 | generate_candidate_definitions_only=auto_ml_job_desc.get(
|
337 | 336 | "GenerateCandidateDefinitionsOnly", False
|
338 | 337 | ),
|
@@ -581,7 +580,7 @@ def deploy(
|
581 | 580 | be selected on each ``deploy``.
|
582 | 581 | endpoint_name (str): The name of the endpoint to create (default:
|
583 | 582 | None). If not specified, a unique endpoint name will be created.
|
584 |
| - tags (List[dict[str, str]]): The list of tags to attach to this |
| 583 | + tags (Optional[Tags]): The list of tags to attach to this |
585 | 584 | specific endpoint.
|
586 | 585 | wait (bool): Whether the call should wait until the deployment of
|
587 | 586 | model completes (default: True).
|
@@ -633,7 +632,7 @@ def deploy(
|
633 | 632 | deserializer=deserializer,
|
634 | 633 | endpoint_name=endpoint_name,
|
635 | 634 | kms_key=model_kms_key,
|
636 |
| - tags=tags, |
| 635 | + tags=format_tags(tags), |
637 | 636 | wait=wait,
|
638 | 637 | volume_size=volume_size,
|
639 | 638 | model_data_download_timeout=model_data_download_timeout,
|
|
0 commit comments