Skip to content

change: Add PipelineVariable annotation in estimator, processing, tuner, transformer base classes #3182

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 109 additions & 94 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import os
import uuid
from abc import ABCMeta, abstractmethod
from typing import Any, Dict
from typing import Any, Dict, Union, Optional, List

from six import string_types, with_metaclass
from six.moves.urllib.parse import urlparse
Expand All @@ -36,6 +36,7 @@
TensorBoardOutputConfig,
get_default_profiler_rule,
get_rule_container_image_uri,
RuleBase,
)
from sagemaker.deprecations import removed_function, removed_kwargs, renamed_kwargs
from sagemaker.fw_utils import (
Expand All @@ -46,7 +47,7 @@
tar_and_upload_dir,
validate_source_dir,
)
from sagemaker.inputs import TrainingInput
from sagemaker.inputs import TrainingInput, FileSystemInput
from sagemaker.job import _Job
from sagemaker.jumpstart.utils import (
add_jumpstart_tags,
Expand Down Expand Up @@ -75,6 +76,7 @@
name_from_base,
)
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.workflow.pipeline_context import (
PipelineSession,
runnable_by_pipeline,
Expand Down Expand Up @@ -105,44 +107,44 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): # pylint: disable=too-man

def __init__(
self,
role,
instance_count=None,
instance_type=None,
volume_size=30,
volume_kms_key=None,
max_run=24 * 60 * 60,
input_mode="File",
output_path=None,
output_kms_key=None,
base_job_name=None,
sagemaker_session=None,
tags=None,
subnets=None,
security_group_ids=None,
model_uri=None,
model_channel_name="model",
metric_definitions=None,
encrypt_inter_container_traffic=False,
use_spot_instances=False,
max_wait=None,
checkpoint_s3_uri=None,
checkpoint_local_path=None,
rules=None,
debugger_hook_config=None,
tensorboard_output_config=None,
enable_sagemaker_metrics=None,
enable_network_isolation=False,
profiler_config=None,
disable_profiler=False,
environment=None,
max_retry_attempts=None,
source_dir=None,
git_config=None,
hyperparameters=None,
container_log_level=logging.INFO,
code_location=None,
entry_point=None,
dependencies=None,
role: str,
instance_count: Optional[Union[int, PipelineVariable]] = None,
instance_type: Optional[Union[str, PipelineVariable]] = None,
volume_size: Union[int, PipelineVariable] = 30,
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
max_run: Union[int, PipelineVariable] = 24 * 60 * 60,
input_mode: Union[str, PipelineVariable] = "File",
output_path: Optional[Union[str, PipelineVariable]] = None,
output_kms_key: Optional[Union[str, PipelineVariable]] = None,
base_job_name: Optional[str] = None,
sagemaker_session: Optional[Session] = None,
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
subnets: Optional[List[Union[str, PipelineVariable]]] = None,
security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
model_uri: Optional[str] = None,
model_channel_name: Union[str, PipelineVariable] = "model",
metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
encrypt_inter_container_traffic: Union[bool, PipelineVariable] = False,
use_spot_instances: Union[bool, PipelineVariable] = False,
max_wait: Optional[Union[int, PipelineVariable]] = None,
checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None,
checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None,
rules: Optional[List[RuleBase]] = None,
debugger_hook_config: Optional[Union[bool, DebuggerHookConfig]] = None,
tensorboard_output_config: Optional[TensorBoardOutputConfig] = None,
enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None,
enable_network_isolation: Union[bool, PipelineVariable] = False,
profiler_config: Optional[ProfilerConfig] = None,
disable_profiler: bool = False,
environment: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
max_retry_attempts: Optional[Union[int, PipelineVariable]] = None,
source_dir: Optional[str] = None,
git_config: Optional[Dict[str, str]] = None,
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
container_log_level: Union[int, PipelineVariable] = logging.INFO,
code_location: Optional[str] = None,
entry_point: Optional[str] = None,
dependencies: Optional[List[str]] = None,
**kwargs,
):
"""Initialize an ``EstimatorBase`` instance.
Expand Down Expand Up @@ -922,7 +924,14 @@ def latest_job_profiler_artifacts_path(self):
return None

@runnable_by_pipeline
def fit(self, inputs=None, wait=True, logs="All", job_name=None, experiment_config=None):
def fit(
self,
inputs: Optional[Union[str, Dict, TrainingInput, FileSystemInput]] = None,
wait: bool = True,
logs: str = "All",
job_name: Optional[str] = None,
experiment_config: Optional[Dict[str, str]] = None,
):
"""Train a model using the input training dataset.

The API calls the Amazon SageMaker CreateTrainingJob API to start
Expand Down Expand Up @@ -1870,16 +1879,22 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
)
train_args["input_mode"] = inputs.config["InputMode"]

# enable_network_isolation may be a pipeline variable place holder object
# which is parsed in execution time
if estimator.enable_network_isolation():
train_args["enable_network_isolation"] = True
train_args["enable_network_isolation"] = estimator.enable_network_isolation()

if estimator.max_retry_attempts is not None:
train_args["retry_strategy"] = {"MaximumRetryAttempts": estimator.max_retry_attempts}
else:
train_args["retry_strategy"] = None

# encrypt_inter_container_traffic may be a pipeline variable place holder object
# which is parsed in execution time
if estimator.encrypt_inter_container_traffic:
train_args["encrypt_inter_container_traffic"] = True
train_args[
"encrypt_inter_container_traffic"
] = estimator.encrypt_inter_container_traffic

if isinstance(estimator, sagemaker.algorithm.AlgorithmEstimator):
train_args["algorithm_arn"] = estimator.algorithm_arn
Expand Down Expand Up @@ -2025,45 +2040,45 @@ class Estimator(EstimatorBase):

def __init__(
self,
image_uri,
role,
instance_count=None,
instance_type=None,
volume_size=30,
volume_kms_key=None,
max_run=24 * 60 * 60,
input_mode="File",
output_path=None,
output_kms_key=None,
base_job_name=None,
sagemaker_session=None,
hyperparameters=None,
tags=None,
subnets=None,
security_group_ids=None,
model_uri=None,
model_channel_name="model",
metric_definitions=None,
encrypt_inter_container_traffic=False,
use_spot_instances=False,
max_wait=None,
checkpoint_s3_uri=None,
checkpoint_local_path=None,
enable_network_isolation=False,
rules=None,
debugger_hook_config=None,
tensorboard_output_config=None,
enable_sagemaker_metrics=None,
profiler_config=None,
disable_profiler=False,
environment=None,
max_retry_attempts=None,
source_dir=None,
git_config=None,
container_log_level=logging.INFO,
code_location=None,
entry_point=None,
dependencies=None,
image_uri: Union[str, PipelineVariable],
role: str,
instance_count: Optional[Union[int, PipelineVariable]] = None,
instance_type: Optional[Union[str, PipelineVariable]] = None,
volume_size: Union[int, PipelineVariable] = 30,
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
max_run: Union[int, PipelineVariable] = 24 * 60 * 60,
input_mode: Union[str, PipelineVariable] = "File",
output_path: Optional[Union[str, PipelineVariable]] = None,
output_kms_key: Optional[Union[str, PipelineVariable]] = None,
base_job_name: Optional[str] = None,
sagemaker_session: Optional[Session] = None,
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
subnets: Optional[List[Union[str, PipelineVariable]]] = None,
security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
model_uri: Optional[str] = None,
model_channel_name: Union[str, PipelineVariable] = "model",
metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
encrypt_inter_container_traffic: Union[bool, PipelineVariable] = False,
use_spot_instances: Union[bool, PipelineVariable] = False,
max_wait: Optional[Union[int, PipelineVariable]] = None,
checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None,
checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None,
enable_network_isolation: Union[bool, PipelineVariable] = False,
rules: Optional[List[RuleBase]] = None,
debugger_hook_config: Optional[Union[DebuggerHookConfig, bool]] = None,
tensorboard_output_config: Optional[TensorBoardOutputConfig] = None,
enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None,
profiler_config: Optional[ProfilerConfig] = None,
disable_profiler: bool = False,
environment: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
max_retry_attempts: Optional[Union[int, PipelineVariable]] = None,
source_dir: Optional[str] = None,
git_config: Optional[Dict[str, str]] = None,
container_log_level: Union[int, PipelineVariable] = logging.INFO,
code_location: Optional[str] = None,
entry_point: Optional[str] = None,
dependencies: Optional[List[str]] = None,
**kwargs,
):
"""Initialize an ``Estimator`` instance.
Expand Down Expand Up @@ -2488,18 +2503,18 @@ class Framework(EstimatorBase):

def __init__(
self,
entry_point,
source_dir=None,
hyperparameters=None,
container_log_level=logging.INFO,
code_location=None,
image_uri=None,
dependencies=None,
enable_network_isolation=False,
git_config=None,
checkpoint_s3_uri=None,
checkpoint_local_path=None,
enable_sagemaker_metrics=None,
entry_point: str,
source_dir: Optional[str] = None,
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
container_log_level: Union[int, PipelineVariable] = logging.INFO,
code_location: Optional[str] = None,
image_uri: Optional[Union[str, PipelineVariable]] = None,
dependencies: Optional[List[str]] = None,
enable_network_isolation: Union[bool, PipelineVariable] = False,
git_config: Optional[Dict[str, str]] = None,
checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None,
checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None,
enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None,
**kwargs,
):
"""Base class initializer.
Expand Down
12 changes: 8 additions & 4 deletions src/sagemaker/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
"""
from __future__ import absolute_import

from typing import Union, Optional, List

from sagemaker.workflow.entities import PipelineVariable


class NetworkConfig(object):
"""Accepts network configuration parameters for conversion to request dict.
Expand All @@ -25,10 +29,10 @@ class NetworkConfig(object):

def __init__(
self,
enable_network_isolation=False,
security_group_ids=None,
subnets=None,
encrypt_inter_container_traffic=None,
enable_network_isolation: Union[bool, PipelineVariable] = False,
security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
subnets: Optional[List[Union[str, PipelineVariable]]] = None,
encrypt_inter_container_traffic: Optional[Union[bool, PipelineVariable]] = None,
):
"""Initialize a ``NetworkConfig`` instance.

Expand Down
9 changes: 8 additions & 1 deletion src/sagemaker/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
from __future__ import absolute_import

import json
from typing import Union

from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.entities import PipelineVariable


class ParameterRange(object):
Expand All @@ -27,7 +29,12 @@ class ParameterRange(object):

__all_types__ = ("Continuous", "Categorical", "Integer")

def __init__(self, min_value, max_value, scaling_type="Auto"):
def __init__(
self,
min_value: Union[int, float, PipelineVariable],
max_value: Union[int, float, PipelineVariable],
scaling_type: Union[str, PipelineVariable] = "Auto",
):
"""Initialize a parameter range.

Args:
Expand Down
Loading