Skip to content

Commit 46c68a9

Browse files
qidewenwhen (Dewen Qi)
and
Dewen Qi
authored
change: Add PipelineVariable annotation in estimator, processing, tuner, transformer base classes (#3182)
Co-authored-by: Dewen Qi <[email protected]>
1 parent 7ac80e2 commit 46c68a9

File tree

7 files changed

+291
-241
lines changed

7 files changed

+291
-241
lines changed

src/sagemaker/estimator.py

+109-94
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@
1818
import os
1919
import uuid
2020
from abc import ABCMeta, abstractmethod
21-
from typing import Any, Dict
21+
from typing import Any, Dict, Union, Optional, List
2222

2323
from six import string_types, with_metaclass
2424
from six.moves.urllib.parse import urlparse
@@ -36,6 +36,7 @@
3636
TensorBoardOutputConfig,
3737
get_default_profiler_rule,
3838
get_rule_container_image_uri,
39+
RuleBase,
3940
)
4041
from sagemaker.deprecations import removed_function, removed_kwargs, renamed_kwargs
4142
from sagemaker.fw_utils import (
@@ -46,7 +47,7 @@
4647
tar_and_upload_dir,
4748
validate_source_dir,
4849
)
49-
from sagemaker.inputs import TrainingInput
50+
from sagemaker.inputs import TrainingInput, FileSystemInput
5051
from sagemaker.job import _Job
5152
from sagemaker.jumpstart.utils import (
5253
add_jumpstart_tags,
@@ -75,6 +76,7 @@
7576
name_from_base,
7677
)
7778
from sagemaker.workflow import is_pipeline_variable
79+
from sagemaker.workflow.entities import PipelineVariable
7880
from sagemaker.workflow.pipeline_context import (
7981
PipelineSession,
8082
runnable_by_pipeline,
@@ -105,44 +107,44 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): # pylint: disable=too-man
105107

106108
def __init__(
107109
self,
108-
role,
109-
instance_count=None,
110-
instance_type=None,
111-
volume_size=30,
112-
volume_kms_key=None,
113-
max_run=24 * 60 * 60,
114-
input_mode="File",
115-
output_path=None,
116-
output_kms_key=None,
117-
base_job_name=None,
118-
sagemaker_session=None,
119-
tags=None,
120-
subnets=None,
121-
security_group_ids=None,
122-
model_uri=None,
123-
model_channel_name="model",
124-
metric_definitions=None,
125-
encrypt_inter_container_traffic=False,
126-
use_spot_instances=False,
127-
max_wait=None,
128-
checkpoint_s3_uri=None,
129-
checkpoint_local_path=None,
130-
rules=None,
131-
debugger_hook_config=None,
132-
tensorboard_output_config=None,
133-
enable_sagemaker_metrics=None,
134-
enable_network_isolation=False,
135-
profiler_config=None,
136-
disable_profiler=False,
137-
environment=None,
138-
max_retry_attempts=None,
139-
source_dir=None,
140-
git_config=None,
141-
hyperparameters=None,
142-
container_log_level=logging.INFO,
143-
code_location=None,
144-
entry_point=None,
145-
dependencies=None,
110+
role: str,
111+
instance_count: Optional[Union[int, PipelineVariable]] = None,
112+
instance_type: Optional[Union[str, PipelineVariable]] = None,
113+
volume_size: Union[int, PipelineVariable] = 30,
114+
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
115+
max_run: Union[int, PipelineVariable] = 24 * 60 * 60,
116+
input_mode: Union[str, PipelineVariable] = "File",
117+
output_path: Optional[Union[str, PipelineVariable]] = None,
118+
output_kms_key: Optional[Union[str, PipelineVariable]] = None,
119+
base_job_name: Optional[str] = None,
120+
sagemaker_session: Optional[Session] = None,
121+
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
122+
subnets: Optional[List[Union[str, PipelineVariable]]] = None,
123+
security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
124+
model_uri: Optional[str] = None,
125+
model_channel_name: Union[str, PipelineVariable] = "model",
126+
metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
127+
encrypt_inter_container_traffic: Union[bool, PipelineVariable] = False,
128+
use_spot_instances: Union[bool, PipelineVariable] = False,
129+
max_wait: Optional[Union[int, PipelineVariable]] = None,
130+
checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None,
131+
checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None,
132+
rules: Optional[List[RuleBase]] = None,
133+
debugger_hook_config: Optional[Union[bool, DebuggerHookConfig]] = None,
134+
tensorboard_output_config: Optional[TensorBoardOutputConfig] = None,
135+
enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None,
136+
enable_network_isolation: Union[bool, PipelineVariable] = False,
137+
profiler_config: Optional[ProfilerConfig] = None,
138+
disable_profiler: bool = False,
139+
environment: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
140+
max_retry_attempts: Optional[Union[int, PipelineVariable]] = None,
141+
source_dir: Optional[str] = None,
142+
git_config: Optional[Dict[str, str]] = None,
143+
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
144+
container_log_level: Union[int, PipelineVariable] = logging.INFO,
145+
code_location: Optional[str] = None,
146+
entry_point: Optional[str] = None,
147+
dependencies: Optional[List[Union[str]]] = None,
146148
**kwargs,
147149
):
148150
"""Initialize an ``EstimatorBase`` instance.
@@ -922,7 +924,14 @@ def latest_job_profiler_artifacts_path(self):
922924
return None
923925

924926
@runnable_by_pipeline
925-
def fit(self, inputs=None, wait=True, logs="All", job_name=None, experiment_config=None):
927+
def fit(
928+
self,
929+
inputs: Optional[Union[str, Dict, TrainingInput, FileSystemInput]] = None,
930+
wait: bool = True,
931+
logs: str = "All",
932+
job_name: Optional[str] = None,
933+
experiment_config: Optional[Dict[str, str]] = None,
934+
):
926935
"""Train a model using the input training dataset.
927936
928937
The API calls the Amazon SageMaker CreateTrainingJob API to start
@@ -1870,16 +1879,22 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
18701879
)
18711880
train_args["input_mode"] = inputs.config["InputMode"]
18721881

1882+
# enable_network_isolation may be a pipeline variable placeholder object
1883+
# which is parsed at execution time
18731884
if estimator.enable_network_isolation():
1874-
train_args["enable_network_isolation"] = True
1885+
train_args["enable_network_isolation"] = estimator.enable_network_isolation()
18751886

18761887
if estimator.max_retry_attempts is not None:
18771888
train_args["retry_strategy"] = {"MaximumRetryAttempts": estimator.max_retry_attempts}
18781889
else:
18791890
train_args["retry_strategy"] = None
18801891

1892+
# encrypt_inter_container_traffic may be a pipeline variable placeholder object
1893+
# which is parsed at execution time
18811894
if estimator.encrypt_inter_container_traffic:
1882-
train_args["encrypt_inter_container_traffic"] = True
1895+
train_args[
1896+
"encrypt_inter_container_traffic"
1897+
] = estimator.encrypt_inter_container_traffic
18831898

18841899
if isinstance(estimator, sagemaker.algorithm.AlgorithmEstimator):
18851900
train_args["algorithm_arn"] = estimator.algorithm_arn
@@ -2025,45 +2040,45 @@ class Estimator(EstimatorBase):
20252040

20262041
def __init__(
20272042
self,
2028-
image_uri,
2029-
role,
2030-
instance_count=None,
2031-
instance_type=None,
2032-
volume_size=30,
2033-
volume_kms_key=None,
2034-
max_run=24 * 60 * 60,
2035-
input_mode="File",
2036-
output_path=None,
2037-
output_kms_key=None,
2038-
base_job_name=None,
2039-
sagemaker_session=None,
2040-
hyperparameters=None,
2041-
tags=None,
2042-
subnets=None,
2043-
security_group_ids=None,
2044-
model_uri=None,
2045-
model_channel_name="model",
2046-
metric_definitions=None,
2047-
encrypt_inter_container_traffic=False,
2048-
use_spot_instances=False,
2049-
max_wait=None,
2050-
checkpoint_s3_uri=None,
2051-
checkpoint_local_path=None,
2052-
enable_network_isolation=False,
2053-
rules=None,
2054-
debugger_hook_config=None,
2055-
tensorboard_output_config=None,
2056-
enable_sagemaker_metrics=None,
2057-
profiler_config=None,
2058-
disable_profiler=False,
2059-
environment=None,
2060-
max_retry_attempts=None,
2061-
source_dir=None,
2062-
git_config=None,
2063-
container_log_level=logging.INFO,
2064-
code_location=None,
2065-
entry_point=None,
2066-
dependencies=None,
2043+
image_uri: Union[str, PipelineVariable],
2044+
role: str,
2045+
instance_count: Optional[Union[int, PipelineVariable]] = None,
2046+
instance_type: Optional[Union[str, PipelineVariable]] = None,
2047+
volume_size: Union[int, PipelineVariable] = 30,
2048+
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
2049+
max_run: Union[int, PipelineVariable] = 24 * 60 * 60,
2050+
input_mode: Union[str, PipelineVariable] = "File",
2051+
output_path: Optional[Union[str, PipelineVariable]] = None,
2052+
output_kms_key: Optional[Union[str, PipelineVariable]] = None,
2053+
base_job_name: Optional[str] = None,
2054+
sagemaker_session: Optional[Session] = None,
2055+
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
2056+
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
2057+
subnets: Optional[List[Union[str, PipelineVariable]]] = None,
2058+
security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
2059+
model_uri: Optional[str] = None,
2060+
model_channel_name: Union[str, PipelineVariable] = "model",
2061+
metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
2062+
encrypt_inter_container_traffic: Union[bool, PipelineVariable] = False,
2063+
use_spot_instances: Union[bool, PipelineVariable] = False,
2064+
max_wait: Optional[Union[int, PipelineVariable]] = None,
2065+
checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None,
2066+
checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None,
2067+
enable_network_isolation: Union[bool, PipelineVariable] = False,
2068+
rules: Optional[List[RuleBase]] = None,
2069+
debugger_hook_config: Optional[Union[DebuggerHookConfig, bool]] = None,
2070+
tensorboard_output_config: Optional[TensorBoardOutputConfig] = None,
2071+
enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None,
2072+
profiler_config: Optional[ProfilerConfig] = None,
2073+
disable_profiler: bool = False,
2074+
environment: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
2075+
max_retry_attempts: Optional[Union[int, PipelineVariable]] = None,
2076+
source_dir: Optional[str] = None,
2077+
git_config: Optional[Dict[str, str]] = None,
2078+
container_log_level: Union[int, PipelineVariable] = logging.INFO,
2079+
code_location: Optional[str] = None,
2080+
entry_point: Optional[str] = None,
2081+
dependencies: Optional[List[str]] = None,
20672082
**kwargs,
20682083
):
20692084
"""Initialize an ``Estimator`` instance.
@@ -2488,18 +2503,18 @@ class Framework(EstimatorBase):
24882503

24892504
def __init__(
24902505
self,
2491-
entry_point,
2492-
source_dir=None,
2493-
hyperparameters=None,
2494-
container_log_level=logging.INFO,
2495-
code_location=None,
2496-
image_uri=None,
2497-
dependencies=None,
2498-
enable_network_isolation=False,
2499-
git_config=None,
2500-
checkpoint_s3_uri=None,
2501-
checkpoint_local_path=None,
2502-
enable_sagemaker_metrics=None,
2506+
entry_point: str,
2507+
source_dir: Optional[str] = None,
2508+
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
2509+
container_log_level: Union[int, PipelineVariable] = logging.INFO,
2510+
code_location: Optional[str] = None,
2511+
image_uri: Optional[Union[str, PipelineVariable]] = None,
2512+
dependencies: Optional[List[str]] = None,
2513+
enable_network_isolation: Union[bool, PipelineVariable] = False,
2514+
git_config: Optional[Dict[str, str]] = None,
2515+
checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None,
2516+
checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None,
2517+
enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None,
25032518
**kwargs,
25042519
):
25052520
"""Base class initializer.

src/sagemaker/network.py

+8-4
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,10 @@
1616
"""
1717
from __future__ import absolute_import
1818

19+
from typing import Union, Optional, List
20+
21+
from sagemaker.workflow.entities import PipelineVariable
22+
1923

2024
class NetworkConfig(object):
2125
"""Accepts network configuration parameters for conversion to request dict.
@@ -25,10 +29,10 @@ class NetworkConfig(object):
2529

2630
def __init__(
2731
self,
28-
enable_network_isolation=False,
29-
security_group_ids=None,
30-
subnets=None,
31-
encrypt_inter_container_traffic=None,
32+
enable_network_isolation: Union[bool, PipelineVariable] = False,
33+
security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
34+
subnets: Optional[List[Union[str, PipelineVariable]]] = None,
35+
encrypt_inter_container_traffic: Optional[Union[bool, PipelineVariable]] = None,
3236
):
3337
"""Initialize a ``NetworkConfig`` instance.
3438

src/sagemaker/parameter.py

+8-1
Original file line number | Diff line number | Diff line change
@@ -14,8 +14,10 @@
1414
from __future__ import absolute_import
1515

1616
import json
17+
from typing import Union
1718

1819
from sagemaker.workflow import is_pipeline_variable
20+
from sagemaker.workflow.entities import PipelineVariable
1921

2022

2123
class ParameterRange(object):
@@ -27,7 +29,12 @@ class ParameterRange(object):
2729

2830
__all_types__ = ("Continuous", "Categorical", "Integer")
2931

30-
def __init__(self, min_value, max_value, scaling_type="Auto"):
32+
def __init__(
33+
self,
34+
min_value: Union[int, float, PipelineVariable],
35+
max_value: Union[int, float, PipelineVariable],
36+
scaling_type: Union[str, PipelineVariable] = "Auto",
37+
):
3138
"""Initialize a parameter range.
3239
3340
Args:

0 commit comments

Comments
 (0)