|
18 | 18 | import os
|
19 | 19 | import uuid
|
20 | 20 | from abc import ABCMeta, abstractmethod
|
21 |
| -from typing import Any, Dict |
| 21 | +from typing import Any, Dict, Union, Optional, List |
22 | 22 |
|
23 | 23 | from six import string_types, with_metaclass
|
24 | 24 | from six.moves.urllib.parse import urlparse
|
|
36 | 36 | TensorBoardOutputConfig,
|
37 | 37 | get_default_profiler_rule,
|
38 | 38 | get_rule_container_image_uri,
|
| 39 | + RuleBase, |
39 | 40 | )
|
40 | 41 | from sagemaker.deprecations import removed_function, removed_kwargs, renamed_kwargs
|
41 | 42 | from sagemaker.fw_utils import (
|
|
46 | 47 | tar_and_upload_dir,
|
47 | 48 | validate_source_dir,
|
48 | 49 | )
|
49 |
| -from sagemaker.inputs import TrainingInput |
| 50 | +from sagemaker.inputs import TrainingInput, FileSystemInput |
50 | 51 | from sagemaker.job import _Job
|
51 | 52 | from sagemaker.jumpstart.utils import (
|
52 | 53 | add_jumpstart_tags,
|
|
75 | 76 | name_from_base,
|
76 | 77 | )
|
77 | 78 | from sagemaker.workflow import is_pipeline_variable
|
| 79 | +from sagemaker.workflow.entities import PipelineVariable |
78 | 80 | from sagemaker.workflow.pipeline_context import (
|
79 | 81 | PipelineSession,
|
80 | 82 | runnable_by_pipeline,
|
@@ -105,44 +107,44 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): # pylint: disable=too-man
|
105 | 107 |
|
106 | 108 | def __init__(
|
107 | 109 | self,
|
108 |
| - role, |
109 |
| - instance_count=None, |
110 |
| - instance_type=None, |
111 |
| - volume_size=30, |
112 |
| - volume_kms_key=None, |
113 |
| - max_run=24 * 60 * 60, |
114 |
| - input_mode="File", |
115 |
| - output_path=None, |
116 |
| - output_kms_key=None, |
117 |
| - base_job_name=None, |
118 |
| - sagemaker_session=None, |
119 |
| - tags=None, |
120 |
| - subnets=None, |
121 |
| - security_group_ids=None, |
122 |
| - model_uri=None, |
123 |
| - model_channel_name="model", |
124 |
| - metric_definitions=None, |
125 |
| - encrypt_inter_container_traffic=False, |
126 |
| - use_spot_instances=False, |
127 |
| - max_wait=None, |
128 |
| - checkpoint_s3_uri=None, |
129 |
| - checkpoint_local_path=None, |
130 |
| - rules=None, |
131 |
| - debugger_hook_config=None, |
132 |
| - tensorboard_output_config=None, |
133 |
| - enable_sagemaker_metrics=None, |
134 |
| - enable_network_isolation=False, |
135 |
| - profiler_config=None, |
136 |
| - disable_profiler=False, |
137 |
| - environment=None, |
138 |
| - max_retry_attempts=None, |
139 |
| - source_dir=None, |
140 |
| - git_config=None, |
141 |
| - hyperparameters=None, |
142 |
| - container_log_level=logging.INFO, |
143 |
| - code_location=None, |
144 |
| - entry_point=None, |
145 |
| - dependencies=None, |
| 110 | + role: str, |
| 111 | + instance_count: Optional[int] = None, |
| 112 | + instance_type: Optional[Union[str, PipelineVariable]] = None, |
| 113 | + volume_size: Union[int, PipelineVariable] = 30, |
| 114 | + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, |
| 115 | + max_run: Union[int, PipelineVariable] = 24 * 60 * 60, |
| 116 | + input_mode: Union[str, PipelineVariable] = "File", |
| 117 | + output_path: Optional[Union[str, PipelineVariable]] = None, |
| 118 | + output_kms_key: Optional[Union[str, PipelineVariable]] = None, |
| 119 | + base_job_name: Optional[str] = None, |
| 120 | + sagemaker_session: Optional[Session] = None, |
| 121 | + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, |
| 122 | + subnets: Optional[List[Union[str, PipelineVariable]]] = None, |
| 123 | + security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None, |
| 124 | + model_uri: Optional[str] = None, |
| 125 | + model_channel_name: Union[str, PipelineVariable] = "model", |
| 126 | + metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, |
| 127 | + encrypt_inter_container_traffic: Union[bool, PipelineVariable] = False, |
| 128 | + use_spot_instances: Union[bool, PipelineVariable] = False, |
| 129 | + max_wait: Union[int, PipelineVariable] = None, |
| 130 | + checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None, |
| 131 | + checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None, |
| 132 | + rules: Optional[List[RuleBase]] = None, |
| 133 | + debugger_hook_config: Optional[Union[bool, DebuggerHookConfig]] = None, |
| 134 | + tensorboard_output_config: Optional[TensorBoardOutputConfig] = None, |
| 135 | + enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None, |
| 136 | + enable_network_isolation: Union[bool, PipelineVariable] = False, |
| 137 | + profiler_config: Optional[ProfilerConfig] = None, |
| 138 | + disable_profiler: bool = False, |
| 139 | + environment: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 140 | + max_retry_attempts: Optional[Union[int, PipelineVariable]] = None, |
| 141 | + source_dir: Optional[str] = None, |
| 142 | + git_config: Optional[Dict[str, str]] = None, |
| 143 | + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 144 | + container_log_level: Union[int, PipelineVariable] = logging.INFO, |
| 145 | + code_location: Optional[str] = None, |
| 146 | + entry_point: Optional[str] = None, |
| 147 | + dependencies: Optional[List[Union[str]]] = None, |
146 | 148 | **kwargs,
|
147 | 149 | ):
|
148 | 150 | """Initialize an ``EstimatorBase`` instance.
|
@@ -922,7 +924,14 @@ def latest_job_profiler_artifacts_path(self):
|
922 | 924 | return None
|
923 | 925 |
|
924 | 926 | @runnable_by_pipeline
|
925 |
| - def fit(self, inputs=None, wait=True, logs="All", job_name=None, experiment_config=None): |
| 927 | + def fit( |
| 928 | + self, |
| 929 | + inputs: Optional[Union[str, Dict, TrainingInput, FileSystemInput]] = None, |
| 930 | + wait: bool = True, |
| 931 | + logs: str = "All", |
| 932 | + job_name: Optional[str] = None, |
| 933 | + experiment_config: Optional[Dict[str, str]] = None, |
| 934 | + ): |
926 | 935 | """Train a model using the input training dataset.
|
927 | 936 |
|
928 | 937 | The API calls the Amazon SageMaker CreateTrainingJob API to start
|
@@ -1870,16 +1879,22 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
|
1870 | 1879 | )
|
1871 | 1880 | train_args["input_mode"] = inputs.config["InputMode"]
|
1872 | 1881 |
|
| 1882 | + # enable_network_isolation may be a pipeline variable place holder object |
| 1883 | + # which is parsed in execution time |
1873 | 1884 | if estimator.enable_network_isolation():
|
1874 |
| - train_args["enable_network_isolation"] = True |
| 1885 | + train_args["enable_network_isolation"] = estimator.enable_network_isolation() |
1875 | 1886 |
|
1876 | 1887 | if estimator.max_retry_attempts is not None:
|
1877 | 1888 | train_args["retry_strategy"] = {"MaximumRetryAttempts": estimator.max_retry_attempts}
|
1878 | 1889 | else:
|
1879 | 1890 | train_args["retry_strategy"] = None
|
1880 | 1891 |
|
| 1892 | + # encrypt_inter_container_traffic may be a pipeline variable place holder object |
| 1893 | + # which is parsed in execution time |
1881 | 1894 | if estimator.encrypt_inter_container_traffic:
|
1882 |
| - train_args["encrypt_inter_container_traffic"] = True |
| 1895 | + train_args[ |
| 1896 | + "encrypt_inter_container_traffic" |
| 1897 | + ] = estimator.encrypt_inter_container_traffic |
1883 | 1898 |
|
1884 | 1899 | if isinstance(estimator, sagemaker.algorithm.AlgorithmEstimator):
|
1885 | 1900 | train_args["algorithm_arn"] = estimator.algorithm_arn
|
@@ -2025,45 +2040,45 @@ class Estimator(EstimatorBase):
|
2025 | 2040 |
|
2026 | 2041 | def __init__(
|
2027 | 2042 | self,
|
2028 |
| - image_uri, |
2029 |
| - role, |
2030 |
| - instance_count=None, |
2031 |
| - instance_type=None, |
2032 |
| - volume_size=30, |
2033 |
| - volume_kms_key=None, |
2034 |
| - max_run=24 * 60 * 60, |
2035 |
| - input_mode="File", |
2036 |
| - output_path=None, |
2037 |
| - output_kms_key=None, |
2038 |
| - base_job_name=None, |
2039 |
| - sagemaker_session=None, |
2040 |
| - hyperparameters=None, |
2041 |
| - tags=None, |
2042 |
| - subnets=None, |
2043 |
| - security_group_ids=None, |
2044 |
| - model_uri=None, |
2045 |
| - model_channel_name="model", |
2046 |
| - metric_definitions=None, |
2047 |
| - encrypt_inter_container_traffic=False, |
2048 |
| - use_spot_instances=False, |
2049 |
| - max_wait=None, |
2050 |
| - checkpoint_s3_uri=None, |
2051 |
| - checkpoint_local_path=None, |
2052 |
| - enable_network_isolation=False, |
2053 |
| - rules=None, |
2054 |
| - debugger_hook_config=None, |
2055 |
| - tensorboard_output_config=None, |
2056 |
| - enable_sagemaker_metrics=None, |
2057 |
| - profiler_config=None, |
2058 |
| - disable_profiler=False, |
2059 |
| - environment=None, |
2060 |
| - max_retry_attempts=None, |
2061 |
| - source_dir=None, |
2062 |
| - git_config=None, |
2063 |
| - container_log_level=logging.INFO, |
2064 |
| - code_location=None, |
2065 |
| - entry_point=None, |
2066 |
| - dependencies=None, |
| 2043 | + image_uri: Union[str, PipelineVariable], |
| 2044 | + role: str, |
| 2045 | + instance_count: Optional[Union[int, PipelineVariable]] = None, |
| 2046 | + instance_type: Optional[Union[str, PipelineVariable]] = None, |
| 2047 | + volume_size: Union[int, PipelineVariable] = 30, |
| 2048 | + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, |
| 2049 | + max_run: Union[int, PipelineVariable] = 24 * 60 * 60, |
| 2050 | + input_mode: Union[str, PipelineVariable] = "File", |
| 2051 | + output_path: Optional[Union[str, PipelineVariable]] = None, |
| 2052 | + output_kms_key: Optional[Union[str, PipelineVariable]] = None, |
| 2053 | + base_job_name: Optional[str] = None, |
| 2054 | + sagemaker_session: Optional[Session] = None, |
| 2055 | + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 2056 | + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, |
| 2057 | + subnets: Optional[List[Union[str, PipelineVariable]]] = None, |
| 2058 | + security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None, |
| 2059 | + model_uri: Optional[str] = None, |
| 2060 | + model_channel_name: Union[str, PipelineVariable] = "model", |
| 2061 | + metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, |
| 2062 | + encrypt_inter_container_traffic: Union[bool, PipelineVariable] = False, |
| 2063 | + use_spot_instances: Union[bool, PipelineVariable] = False, |
| 2064 | + max_wait: Union[int, PipelineVariable] = None, |
| 2065 | + checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None, |
| 2066 | + checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None, |
| 2067 | + enable_network_isolation: Union[bool, PipelineVariable] = False, |
| 2068 | + rules: Optional[List[RuleBase]] = None, |
| 2069 | + debugger_hook_config: Optional[Union[DebuggerHookConfig, bool]] = None, |
| 2070 | + tensorboard_output_config: Optional[TensorBoardOutputConfig] = None, |
| 2071 | + enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None, |
| 2072 | + profiler_config: Optional[ProfilerConfig] = None, |
| 2073 | + disable_profiler: bool = False, |
| 2074 | + environment: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 2075 | + max_retry_attempts: Optional[Union[int, PipelineVariable]] = None, |
| 2076 | + source_dir: Optional[str] = None, |
| 2077 | + git_config: Optional[Dict[str, str]] = None, |
| 2078 | + container_log_level: Union[int, PipelineVariable] = logging.INFO, |
| 2079 | + code_location: Optional[str] = None, |
| 2080 | + entry_point: Optional[str] = None, |
| 2081 | + dependencies: Optional[List[str]] = None, |
2067 | 2082 | **kwargs,
|
2068 | 2083 | ):
|
2069 | 2084 | """Initialize an ``Estimator`` instance.
|
@@ -2488,18 +2503,18 @@ class Framework(EstimatorBase):
|
2488 | 2503 |
|
2489 | 2504 | def __init__(
|
2490 | 2505 | self,
|
2491 |
| - entry_point, |
2492 |
| - source_dir=None, |
2493 |
| - hyperparameters=None, |
2494 |
| - container_log_level=logging.INFO, |
2495 |
| - code_location=None, |
2496 |
| - image_uri=None, |
2497 |
| - dependencies=None, |
2498 |
| - enable_network_isolation=False, |
2499 |
| - git_config=None, |
2500 |
| - checkpoint_s3_uri=None, |
2501 |
| - checkpoint_local_path=None, |
2502 |
| - enable_sagemaker_metrics=None, |
| 2506 | + entry_point: str, |
| 2507 | + source_dir: Optional[str] = None, |
| 2508 | + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 2509 | + container_log_level: Union[int, PipelineVariable] = logging.INFO, |
| 2510 | + code_location: Optional[str] = None, |
| 2511 | + image_uri: Optional[Union[str, PipelineVariable]] = None, |
| 2512 | + dependencies: Optional[List[str]] = None, |
| 2513 | + enable_network_isolation: Union[bool, PipelineVariable] = False, |
| 2514 | + git_config: Optional[Dict[str, str]] = None, |
| 2515 | + checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None, |
| 2516 | + checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None, |
| 2517 | + enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None, |
2503 | 2518 | **kwargs,
|
2504 | 2519 | ):
|
2505 | 2520 | """Base class initializer.
|
|
0 commit comments