Skip to content

change: Add PipelineVariable annotation in model base class and tensorflow estimator #3190

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/sagemaker/amazon/amazon_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import json
import logging
import tempfile
from typing import Union

from six.moves.urllib.parse import urlparse

Expand All @@ -27,6 +28,7 @@
from sagemaker.estimator import EstimatorBase, _TrainingJob
from sagemaker.inputs import FileSystemInput, TrainingInput
from sagemaker.utils import sagemaker_timestamp
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.workflow.pipeline_context import runnable_by_pipeline

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -304,7 +306,12 @@ class RecordSet(object):
"""Placeholder docstring"""

def __init__(
self, s3_data, num_records, feature_dim, s3_data_type="ManifestFile", channel="train"
self,
s3_data: Union[str, PipelineVariable],
num_records: int,
feature_dim: int,
s3_data_type: Union[str, PipelineVariable] = "ManifestFile",
channel: Union[str, PipelineVariable] = "train",
):
"""A collection of Amazon :class:`~Record` objects serialized and stored in S3.

Expand Down
45 changes: 28 additions & 17 deletions src/sagemaker/debugger/debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,15 @@

from abc import ABC

from typing import Union, Optional, List, Dict

import attr

import smdebug_rulesconfig as rule_configs

from sagemaker import image_uris
from sagemaker.utils import build_dict
from sagemaker.workflow.entities import PipelineVariable

framework_name = "debugger"
DEBUGGER_FLAG = "USE_SMDEBUG"
Expand Down Expand Up @@ -311,17 +314,17 @@ def sagemaker(
@classmethod
def custom(
cls,
name,
image_uri,
instance_type,
volume_size_in_gb,
source=None,
rule_to_invoke=None,
container_local_output_path=None,
s3_output_path=None,
other_trials_s3_input_paths=None,
rule_parameters=None,
collections_to_save=None,
name: str,
image_uri: Union[str, PipelineVariable],
instance_type: Union[str, PipelineVariable],
volume_size_in_gb: Union[int, PipelineVariable],
source: Optional[str] = None,
rule_to_invoke: Optional[Union[str, PipelineVariable]] = None,
container_local_output_path: Optional[Union[str, PipelineVariable]] = None,
s3_output_path: Optional[Union[str, PipelineVariable]] = None,
other_trials_s3_input_paths: Optional[List[Union[str, PipelineVariable]]] = None,
rule_parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
collections_to_save: Optional[List["CollectionConfig"]] = None,
actions=None,
):
"""Initialize a ``Rule`` object for a *custom* debugging rule.
Expand Down Expand Up @@ -610,10 +613,10 @@ class DebuggerHookConfig(object):

def __init__(
self,
s3_output_path=None,
container_local_output_path=None,
hook_parameters=None,
collection_configs=None,
s3_output_path: Optional[Union[str, PipelineVariable]] = None,
container_local_output_path: Optional[Union[str, PipelineVariable]] = None,
hook_parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
collection_configs: Optional[List["CollectionConfig"]] = None,
):
"""Initialize the DebuggerHookConfig instance.

Expand Down Expand Up @@ -679,7 +682,11 @@ def _to_request_dict(self):
class TensorBoardOutputConfig(object):
"""Create a tensor output configuration object for debugging visualizations on TensorBoard."""

def __init__(self, s3_output_path, container_local_output_path=None):
def __init__(
self,
s3_output_path: Union[str, PipelineVariable],
container_local_output_path: Optional[Union[str, PipelineVariable]] = None,
):
"""Initialize the TensorBoardOutputConfig instance.

Args:
Expand Down Expand Up @@ -708,7 +715,11 @@ def _to_request_dict(self):
class CollectionConfig(object):
"""Creates tensor collections for SageMaker Debugger."""

def __init__(self, name, parameters=None):
def __init__(
self,
name: Union[str, PipelineVariable],
parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
):
"""Constructor for collection configuration.

Args:
Expand Down
9 changes: 6 additions & 3 deletions src/sagemaker/debugger/profiler_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
"""Configuration for collecting system and framework metrics in SageMaker training jobs."""
from __future__ import absolute_import

from typing import Optional, Union

from sagemaker.debugger.framework_profile import FrameworkProfile
from sagemaker.workflow.entities import PipelineVariable


class ProfilerConfig(object):
Expand All @@ -26,9 +29,9 @@ class ProfilerConfig(object):

def __init__(
self,
s3_output_path=None,
system_monitor_interval_millis=None,
framework_profile_params=None,
s3_output_path: Optional[Union[str, PipelineVariable]] = None,
system_monitor_interval_millis: Optional[Union[int, PipelineVariable]] = None,
framework_profile_params: Optional[FrameworkProfile] = None,
):
"""Initialize a ``ProfilerConfig`` instance.

Expand Down
22 changes: 13 additions & 9 deletions src/sagemaker/drift_check_baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,25 @@
"""This file contains code related to drift check baselines"""
from __future__ import absolute_import

from typing import Optional

from sagemaker.model_metrics import MetricsSource, FileSource


class DriftCheckBaselines(object):
"""Accepts drift check baselines parameters for conversion to request dict."""

def __init__(
self,
model_statistics=None,
model_constraints=None,
model_data_statistics=None,
model_data_constraints=None,
bias_config_file=None,
bias_pre_training_constraints=None,
bias_post_training_constraints=None,
explainability_constraints=None,
explainability_config_file=None,
model_statistics: Optional[MetricsSource] = None,
model_constraints: Optional[MetricsSource] = None,
model_data_statistics: Optional[MetricsSource] = None,
model_data_constraints: Optional[MetricsSource] = None,
bias_config_file: Optional[FileSource] = None,
bias_pre_training_constraints: Optional[MetricsSource] = None,
bias_post_training_constraints: Optional[MetricsSource] = None,
explainability_constraints: Optional[MetricsSource] = None,
explainability_config_file: Optional[FileSource] = None,
):
"""Initialize a ``DriftCheckBaselines`` instance and turn parameters into dict.

Expand Down
17 changes: 14 additions & 3 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
validate_source_code_input_against_pipeline_variables,
)
from sagemaker.inputs import TrainingInput, FileSystemInput
from sagemaker.instance_group import InstanceGroup
from sagemaker.job import _Job
from sagemaker.jumpstart.utils import (
add_jumpstart_tags,
Expand Down Expand Up @@ -149,7 +150,7 @@ def __init__(
code_location: Optional[str] = None,
entry_point: Optional[Union[str, PipelineVariable]] = None,
dependencies: Optional[List[Union[str]]] = None,
instance_groups: Optional[Dict[str, Union[str, int]]] = None,
instance_groups: Optional[List[InstanceGroup]] = None,
**kwargs,
):
"""Initialize an ``EstimatorBase`` instance.
Expand Down Expand Up @@ -1580,6 +1581,8 @@ def _get_instance_type(self):

for instance_group in self.instance_groups:
instance_type = instance_group.instance_type
if is_pipeline_variable(instance_type):
continue
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)

if match:
Expand Down Expand Up @@ -2179,7 +2182,7 @@ def __init__(
code_location: Optional[str] = None,
entry_point: Optional[Union[str, PipelineVariable]] = None,
dependencies: Optional[List[str]] = None,
instance_groups: Optional[Dict[str, Union[str, int]]] = None,
instance_groups: Optional[List[InstanceGroup]] = None,
**kwargs,
):
"""Initialize an ``Estimator`` instance.
Expand Down Expand Up @@ -2874,7 +2877,15 @@ def _validate_and_set_debugger_configs(self):
# Disable debugger if checkpointing is enabled by the customer
if self.checkpoint_s3_uri and self.checkpoint_local_path and self.debugger_hook_config:
if self._framework_name in {"mxnet", "pytorch", "tensorflow"}:
if self.instance_count > 1 or (
if is_pipeline_variable(self.instance_count):
logger.warning(
"SMDebug does not currently support distributed training jobs "
"with checkpointing enabled. Therefore, to allow parameterized "
"instance_count and allow to change it to any values in execution time, "
"the debugger_hook_config is disabled."
)
self.debugger_hook_config = False
elif self.instance_count > 1 or (
hasattr(self, "distribution")
and self.distribution is not None # pylint: disable=no-member
):
Expand Down
6 changes: 4 additions & 2 deletions src/sagemaker/huggingface/training_compiler/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
"""Configuration for the SageMaker Training Compiler."""
from __future__ import absolute_import
import logging
from typing import Union

from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig
from sagemaker.workflow.entities import PipelineVariable

logger = logging.getLogger(__name__)

Expand All @@ -26,8 +28,8 @@ class TrainingCompilerConfig(BaseConfig):

def __init__(
self,
enabled=True,
debug=False,
enabled: Union[bool, PipelineVariable] = True,
debug: Union[bool, PipelineVariable] = False,
):
"""This class initializes a ``TrainingCompilerConfig`` instance.

Expand Down
25 changes: 14 additions & 11 deletions src/sagemaker/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
"""Amazon SageMaker channel configurations for S3 data sources and file system data sources"""
from __future__ import absolute_import, print_function

from typing import Union, Optional, List
import attr

from sagemaker.workflow.entities import PipelineVariable

FILE_SYSTEM_TYPES = ["FSxLustre", "EFS"]
FILE_SYSTEM_ACCESS_MODES = ["ro", "rw"]

Expand All @@ -29,17 +32,17 @@ class TrainingInput(object):

def __init__(
self,
s3_data,
distribution=None,
compression=None,
content_type=None,
record_wrapping=None,
s3_data_type="S3Prefix",
instance_groups=None,
input_mode=None,
attribute_names=None,
target_attribute_name=None,
shuffle_config=None,
s3_data: Union[str, PipelineVariable],
distribution: Optional[Union[str, PipelineVariable]] = None,
compression: Optional[Union[str, PipelineVariable]] = None,
content_type: Optional[Union[str, PipelineVariable]] = None,
record_wrapping: Optional[Union[str, PipelineVariable]] = None,
s3_data_type: Union[str, PipelineVariable] = "S3Prefix",
instance_groups: Optional[List[Union[str, PipelineVariable]]] = None,
input_mode: Optional[Union[str, PipelineVariable]] = None,
attribute_names: Optional[List[Union[str, PipelineVariable]]] = None,
target_attribute_name: Optional[Union[str, PipelineVariable]] = None,
shuffle_config: Optional["ShuffleConfig"] = None,
):
r"""Create a definition for input data used by an SageMaker training job.

Expand Down
12 changes: 8 additions & 4 deletions src/sagemaker/metadata_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,20 @@
"""This file contains code related to metadata properties."""
from __future__ import absolute_import

from typing import Optional, Union

from sagemaker.workflow.entities import PipelineVariable


class MetadataProperties(object):
"""Accepts metadata properties parameters for conversion to request dict."""

def __init__(
self,
commit_id=None,
repository=None,
generated_by=None,
project_id=None,
commit_id: Optional[Union[str, PipelineVariable]] = None,
repository: Optional[Union[str, PipelineVariable]] = None,
generated_by: Optional[Union[str, PipelineVariable]] = None,
project_id: Optional[Union[str, PipelineVariable]] = None,
):
"""Initialize a ``MetadataProperties`` instance and turn parameters into dict.

Expand Down
Loading