Skip to content

Commit 6c5fb56

Browse files
author
Dewen Qi
committed
change: Add Pipeline annotation in model base class and tensorflow estimator
Model annotate update
1 parent 5cf83df commit 6c5fb56

9 files changed

+177
-131
lines changed

src/sagemaker/drift_check_baselines.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,25 @@
1313
"""This file contains code related to drift check baselines"""
1414
from __future__ import absolute_import
1515

16+
from typing import Optional
17+
18+
from sagemaker.model_metrics import MetricsSource, FileSource
19+
1620

1721
class DriftCheckBaselines(object):
1822
"""Accepts drift check baselines parameters for conversion to request dict."""
1923

2024
def __init__(
2125
self,
22-
model_statistics=None,
23-
model_constraints=None,
24-
model_data_statistics=None,
25-
model_data_constraints=None,
26-
bias_config_file=None,
27-
bias_pre_training_constraints=None,
28-
bias_post_training_constraints=None,
29-
explainability_constraints=None,
30-
explainability_config_file=None,
26+
model_statistics: Optional[MetricsSource] = None,
27+
model_constraints: Optional[MetricsSource] = None,
28+
model_data_statistics: Optional[MetricsSource] = None,
29+
model_data_constraints: Optional[MetricsSource] = None,
30+
bias_config_file: Optional[FileSource] = None,
31+
bias_pre_training_constraints: Optional[MetricsSource] = None,
32+
bias_post_training_constraints: Optional[MetricsSource] = None,
33+
explainability_constraints: Optional[MetricsSource] = None,
34+
explainability_config_file: Optional[FileSource] = None,
3135
):
3236
"""Initialize a ``DriftCheckBaselines`` instance and turn parameters into dict.
3337

src/sagemaker/estimator.py

Lines changed: 52 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import os
1919
import uuid
2020
from abc import ABCMeta, abstractmethod
21-
from typing import Any, Dict
21+
from typing import Any, Dict, Optional, Union, List
2222

2323
from six import string_types, with_metaclass
2424
from six.moves.urllib.parse import urlparse
@@ -36,6 +36,7 @@
3636
TensorBoardOutputConfig,
3737
get_default_profiler_rule,
3838
get_rule_container_image_uri,
39+
RuleBase,
3940
)
4041
from sagemaker.deprecations import removed_function, removed_kwargs, renamed_kwargs
4142
from sagemaker.fw_utils import (
@@ -75,6 +76,7 @@
7576
name_from_base,
7677
)
7778
from sagemaker.workflow import is_pipeline_variable
79+
from sagemaker.workflow.entities import PipelineVariable
7880
from sagemaker.workflow.pipeline_context import (
7981
PipelineSession,
8082
runnable_by_pipeline,
@@ -105,44 +107,44 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): # pylint: disable=too-man
105107

106108
def __init__(
107109
self,
108-
role,
109-
instance_count=None,
110-
instance_type=None,
111-
volume_size=30,
112-
volume_kms_key=None,
113-
max_run=24 * 60 * 60,
114-
input_mode="File",
115-
output_path=None,
116-
output_kms_key=None,
117-
base_job_name=None,
118-
sagemaker_session=None,
119-
tags=None,
120-
subnets=None,
121-
security_group_ids=None,
122-
model_uri=None,
123-
model_channel_name="model",
124-
metric_definitions=None,
125-
encrypt_inter_container_traffic=False,
126-
use_spot_instances=False,
127-
max_wait=None,
128-
checkpoint_s3_uri=None,
129-
checkpoint_local_path=None,
130-
rules=None,
131-
debugger_hook_config=None,
132-
tensorboard_output_config=None,
133-
enable_sagemaker_metrics=None,
134-
enable_network_isolation=False,
135-
profiler_config=None,
136-
disable_profiler=False,
137-
environment=None,
138-
max_retry_attempts=None,
139-
source_dir=None,
140-
git_config=None,
141-
hyperparameters=None,
142-
container_log_level=logging.INFO,
143-
code_location=None,
144-
entry_point=None,
145-
dependencies=None,
110+
role: str,
111+
instance_count: Optional[Union[int, PipelineVariable]] = None,
112+
instance_type: Optional[Union[str, PipelineVariable]] = None,
113+
volume_size: Union[int, PipelineVariable] = 30,
114+
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
115+
max_run: Union[int, PipelineVariable] = 24 * 60 * 60,
116+
input_mode: Union[str, PipelineVariable] = "File",
117+
output_path: Optional[Union[str, PipelineVariable]] = None,
118+
output_kms_key: Optional[Union[str, PipelineVariable]] = None,
119+
base_job_name: Optional[str] = None,
120+
sagemaker_session: Optional[Session] = None,
121+
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
122+
subnets: Optional[List[Union[str, PipelineVariable]]] = None,
123+
security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
124+
model_uri: Optional[str] = None,
125+
model_channel_name: Union[str, PipelineVariable] = "model",
126+
metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
127+
encrypt_inter_container_traffic: Union[bool, PipelineVariable] = False,
128+
use_spot_instances: Union[bool, PipelineVariable] = False,
129+
max_wait: Optional[Union[int, PipelineVariable]] = None,
130+
checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None,
131+
checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None,
132+
rules: Optional[List[RuleBase]] = None,
133+
debugger_hook_config: Optional[Union[bool, DebuggerHookConfig]] = None,
134+
tensorboard_output_config: Optional[TensorBoardOutputConfig] = None,
135+
enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None,
136+
enable_network_isolation: Union[bool, PipelineVariable] = False,
137+
profiler_config: Optional[ProfilerConfig] = None,
138+
disable_profiler: bool = False,
139+
environment: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
140+
max_retry_attempts: Optional[Union[int, PipelineVariable]] = None,
141+
source_dir: Optional[str] = None,
142+
git_config: Optional[Dict[str, str]] = None,
143+
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
144+
container_log_level: Union[int, PipelineVariable] = logging.INFO,
145+
code_location: Optional[str] = None,
146+
entry_point: Optional[str] = None,
147+
dependencies: Optional[List[Union[str]]] = None,
146148
**kwargs,
147149
):
148150
"""Initialize an ``EstimatorBase`` instance.
@@ -2730,15 +2732,25 @@ def _validate_and_set_debugger_configs(self):
27302732
# Disable debugger if checkpointing is enabled by the customer
27312733
if self.checkpoint_s3_uri and self.checkpoint_local_path and self.debugger_hook_config:
27322734
if self._framework_name in {"mxnet", "pytorch", "tensorflow"}:
2733-
if self.instance_count > 1 or (
2735+
disable_debugger_hook_config = False
2736+
if not is_pipeline_variable(self.instance_count) and self.instance_count > 1:
2737+
disable_debugger_hook_config = True
2738+
if (
27342739
hasattr(self, "distribution")
27352740
and self.distribution is not None # pylint: disable=no-member
27362741
):
2742+
disable_debugger_hook_config = True
2743+
if disable_debugger_hook_config:
27372744
logger.info(
27382745
"SMDebug Does Not Currently Support \
27392746
Distributed Training Jobs With Checkpointing Enabled"
27402747
)
27412748
self.debugger_hook_config = False
2749+
elif is_pipeline_variable(self.instance_count):
2750+
logger.warning(
2751+
"SMDebug does not currently support when parameterized "
2752+
"instance_count with value > 1 in pipeline execution."
2753+
)
27422754

27432755
if self.debugger_hook_config is False:
27442756
if self.environment is None:

src/sagemaker/metadata_properties.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,20 @@
1313
"""This file contains code related to metadata properties."""
1414
from __future__ import absolute_import
1515

16+
from typing import Optional, Union
17+
18+
from sagemaker.workflow.entities import PipelineVariable
19+
1620

1721
class MetadataProperties(object):
1822
"""Accepts metadata properties parameters for conversion to request dict."""
1923

2024
def __init__(
2125
self,
22-
commit_id=None,
23-
repository=None,
24-
generated_by=None,
25-
project_id=None,
26+
commit_id: Optional[Union[str, PipelineVariable]] = None,
27+
repository: Optional[Union[str, PipelineVariable]] = None,
28+
generated_by: Optional[Union[str, PipelineVariable]] = None,
29+
project_id: Optional[Union[str, PipelineVariable]] = None,
2630
):
2731
"""Initialize a ``MetadataProperties`` instance and turn parameters into dict.
2832

src/sagemaker/model.py

Lines changed: 58 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import os
2020
import re
2121
import copy
22-
from typing import List, Dict
22+
from typing import List, Dict, Optional, Union
2323

2424
import sagemaker
2525
from sagemaker import (
@@ -30,15 +30,20 @@
3030
utils,
3131
git_utils,
3232
)
33+
from sagemaker.session import Session
34+
from sagemaker.model_metrics import ModelMetrics
3335
from sagemaker.deprecations import removed_kwargs
36+
from sagemaker.drift_check_baselines import DriftCheckBaselines
37+
from sagemaker.metadata_properties import MetadataProperties
3438
from sagemaker.predictor import PredictorBase
3539
from sagemaker.serverless import ServerlessInferenceConfig
3640
from sagemaker.transformer import Transformer
3741
from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model
38-
from sagemaker.utils import unique_name_from_base
42+
from sagemaker.utils import unique_name_from_base, to_string
3943
from sagemaker.async_inference import AsyncInferenceConfig
4044
from sagemaker.predictor_async import AsyncPredictor
4145
from sagemaker.workflow import is_pipeline_variable
46+
from sagemaker.workflow.entities import PipelineVariable
4247
from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession
4348

4449
LOGGER = logging.getLogger("sagemaker")
@@ -78,23 +83,23 @@ class Model(ModelBase):
7883

7984
def __init__(
8085
self,
81-
image_uri,
82-
model_data=None,
83-
role=None,
84-
predictor_cls=None,
85-
env=None,
86-
name=None,
87-
vpc_config=None,
88-
sagemaker_session=None,
89-
enable_network_isolation=False,
90-
model_kms_key=None,
91-
image_config=None,
92-
source_dir=None,
93-
code_location=None,
94-
entry_point=None,
95-
container_log_level=logging.INFO,
96-
dependencies=None,
97-
git_config=None,
86+
image_uri: Union[str, PipelineVariable],
87+
model_data: Optional[Union[str, PipelineVariable]] = None,
88+
role: Optional[str] = None,
89+
predictor_cls: Optional[callable] = None,
90+
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
91+
name: Optional[str] = None,
92+
vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None,
93+
sagemaker_session: Optional[Session] = None,
94+
enable_network_isolation: Union[bool, PipelineVariable] = False,
95+
model_kms_key: Optional[str] = None,
96+
image_config: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
97+
source_dir: Optional[str] = None,
98+
code_location: Optional[str] = None,
99+
entry_point: Optional[str] = None,
100+
container_log_level: Union[int, PipelineVariable] = logging.INFO,
101+
dependencies: Optional[List[str]] = None,
102+
git_config: Optional[Dict[str, str]] = None,
98103
):
99104
"""Initialize an SageMaker ``Model``.
100105
@@ -294,22 +299,22 @@ def __init__(
294299
@runnable_by_pipeline
295300
def register(
296301
self,
297-
content_types,
298-
response_types,
299-
inference_instances=None,
300-
transform_instances=None,
301-
model_package_name=None,
302-
model_package_group_name=None,
303-
image_uri=None,
304-
model_metrics=None,
305-
metadata_properties=None,
306-
marketplace_cert=False,
307-
approval_status=None,
308-
description=None,
309-
drift_check_baselines=None,
310-
customer_metadata_properties=None,
311-
validation_specification=None,
312-
domain=None,
302+
content_types: List[Union[str, PipelineVariable]],
303+
response_types: List[Union[str, PipelineVariable]],
304+
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
305+
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
306+
model_package_name: Optional[Union[str, PipelineVariable]] = None,
307+
model_package_group_name: Optional[Union[str, PipelineVariable]] = None,
308+
image_uri: Optional[Union[str, PipelineVariable]] = None,
309+
model_metrics: Optional[ModelMetrics] = None,
310+
metadata_properties: Optional[MetadataProperties] = None,
311+
marketplace_cert: bool = False,
312+
approval_status: Optional[Union[str, PipelineVariable]] = None,
313+
description: Optional[str] = None,
314+
drift_check_baselines: Optional[DriftCheckBaselines] = None,
315+
customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
316+
validation_specification: Optional[Union[str, PipelineVariable]] = None,
317+
domain: Optional[Union[str, PipelineVariable]] = None,
313318
):
314319
"""Creates a model package for creating SageMaker models or listing on Marketplace.
315320
@@ -385,10 +390,10 @@ def register(
385390
@runnable_by_pipeline
386391
def create(
387392
self,
388-
instance_type: str = None,
389-
accelerator_type: str = None,
390-
serverless_inference_config: ServerlessInferenceConfig = None,
391-
tags: List[Dict[str, str]] = None,
393+
instance_type: Optional[str] = None,
394+
accelerator_type: Optional[str] = None,
395+
serverless_inference_config: Optional[ServerlessInferenceConfig] = None,
396+
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
392397
):
393398
"""Create a SageMaker Model Entity
394399
@@ -570,7 +575,7 @@ def _script_mode_env_vars(self):
570575
return {
571576
SCRIPT_PARAM_NAME.upper(): script_name or str(),
572577
DIR_PARAM_NAME.upper(): dir_name or str(),
573-
CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): str(self.container_log_level),
578+
CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): to_string(self.container_log_level),
574579
SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_region_name,
575580
}
576581

@@ -1239,19 +1244,19 @@ class FrameworkModel(Model):
12391244

12401245
def __init__(
12411246
self,
1242-
model_data,
1243-
image_uri,
1244-
role,
1245-
entry_point,
1246-
source_dir=None,
1247-
predictor_cls=None,
1248-
env=None,
1249-
name=None,
1250-
container_log_level=logging.INFO,
1251-
code_location=None,
1252-
sagemaker_session=None,
1253-
dependencies=None,
1254-
git_config=None,
1247+
model_data: Union[str, PipelineVariable],
1248+
image_uri: Union[str, PipelineVariable],
1249+
role: str,
1250+
entry_point: str,
1251+
source_dir: Optional[str] = None,
1252+
predictor_cls: Optional[callable] = None,
1253+
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
1254+
name: Optional[str] = None,
1255+
container_log_level: Union[int, PipelineVariable] = logging.INFO,
1256+
code_location: Optional[str] = None,
1257+
sagemaker_session: Optional[Session] = None,
1258+
dependencies: Optional[List[str]] = None,
1259+
git_config: Optional[Dict[str, str]] = None,
12551260
**kwargs,
12561261
):
12571262
"""Initialize a ``FrameworkModel``.

0 commit comments

Comments
 (0)