Skip to content

Commit 2a810b2

Browse files
author
Dewen Qi
committed
change: Add Pipeline annotation in model base class and tensorflow estimator
Model annotation update. change: Add PipelineVariable annotations to the composite training arguments used with the model base class and the TensorFlow estimator
1 parent bbb715d commit 2a810b2

14 files changed

+202
-133
lines changed

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import json
1717
import logging
1818
import tempfile
19+
from typing import Union
1920

2021
from six.moves.urllib.parse import urlparse
2122

@@ -27,6 +28,7 @@
2728
from sagemaker.estimator import EstimatorBase, _TrainingJob
2829
from sagemaker.inputs import FileSystemInput, TrainingInput
2930
from sagemaker.utils import sagemaker_timestamp
31+
from sagemaker.workflow.entities import PipelineVariable
3032
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
3133

3234
logger = logging.getLogger(__name__)
@@ -304,7 +306,12 @@ class RecordSet(object):
304306
"""Placeholder docstring"""
305307

306308
def __init__(
307-
self, s3_data, num_records, feature_dim, s3_data_type="ManifestFile", channel="train"
309+
self,
310+
s3_data: Union[str, PipelineVariable],
311+
num_records: int,
312+
feature_dim: int,
313+
s3_data_type: Union[str, PipelineVariable] = "ManifestFile",
314+
channel: Union[str, PipelineVariable] = "train",
308315
):
309316
"""A collection of Amazon :class:`~Record` objects serialized and stored in S3.
310317

src/sagemaker/debugger/debugger.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,15 @@
2424

2525
from abc import ABC
2626

27+
from typing import Union, Optional, List, Dict
28+
2729
import attr
2830

2931
import smdebug_rulesconfig as rule_configs
3032

3133
from sagemaker import image_uris
3234
from sagemaker.utils import build_dict
35+
from sagemaker.workflow.entities import PipelineVariable
3336

3437
framework_name = "debugger"
3538
DEBUGGER_FLAG = "USE_SMDEBUG"
@@ -311,17 +314,17 @@ def sagemaker(
311314
@classmethod
312315
def custom(
313316
cls,
314-
name,
315-
image_uri,
316-
instance_type,
317-
volume_size_in_gb,
318-
source=None,
319-
rule_to_invoke=None,
320-
container_local_output_path=None,
321-
s3_output_path=None,
322-
other_trials_s3_input_paths=None,
323-
rule_parameters=None,
324-
collections_to_save=None,
317+
name: str,
318+
image_uri: Union[str, PipelineVariable],
319+
instance_type: Union[str, PipelineVariable],
320+
volume_size_in_gb: Union[int, PipelineVariable],
321+
source: Optional[str] = None,
322+
rule_to_invoke: Optional[Union[str, PipelineVariable]] = None,
323+
container_local_output_path: Optional[Union[str, PipelineVariable]] = None,
324+
s3_output_path: Optional[Union[str, PipelineVariable]] = None,
325+
other_trials_s3_input_paths: Optional[List[Union[str, PipelineVariable]]] = None,
326+
rule_parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
327+
collections_to_save: Optional[List["CollectionConfig"]] = None,
325328
actions=None,
326329
):
327330
"""Initialize a ``Rule`` object for a *custom* debugging rule.
@@ -610,10 +613,10 @@ class DebuggerHookConfig(object):
610613

611614
def __init__(
612615
self,
613-
s3_output_path=None,
614-
container_local_output_path=None,
615-
hook_parameters=None,
616-
collection_configs=None,
616+
s3_output_path: Optional[Union[str, PipelineVariable]] = None,
617+
container_local_output_path: Optional[Union[str, PipelineVariable]] = None,
618+
hook_parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
619+
collection_configs: Optional[List["CollectionConfig"]] = None,
617620
):
618621
"""Initialize the DebuggerHookConfig instance.
619622
@@ -679,7 +682,11 @@ def _to_request_dict(self):
679682
class TensorBoardOutputConfig(object):
680683
"""Create a tensor output configuration object for debugging visualizations on TensorBoard."""
681684

682-
def __init__(self, s3_output_path, container_local_output_path=None):
685+
def __init__(
686+
self,
687+
s3_output_path: Union[str, PipelineVariable],
688+
container_local_output_path: Optional[Union[str, PipelineVariable]] = None,
689+
):
683690
"""Initialize the TensorBoardOutputConfig instance.
684691
685692
Args:
@@ -708,7 +715,11 @@ def _to_request_dict(self):
708715
class CollectionConfig(object):
709716
"""Creates tensor collections for SageMaker Debugger."""
710717

711-
def __init__(self, name, parameters=None):
718+
def __init__(
719+
self,
720+
name: Union[str, PipelineVariable],
721+
parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
722+
):
712723
"""Constructor for collection configuration.
713724
714725
Args:

src/sagemaker/debugger/profiler_config.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
"""Configuration for collecting system and framework metrics in SageMaker training jobs."""
1414
from __future__ import absolute_import
1515

16+
from typing import Optional, Union
17+
1618
from sagemaker.debugger.framework_profile import FrameworkProfile
19+
from sagemaker.workflow.entities import PipelineVariable
1720

1821

1922
class ProfilerConfig(object):
@@ -26,9 +29,9 @@ class ProfilerConfig(object):
2629

2730
def __init__(
2831
self,
29-
s3_output_path=None,
30-
system_monitor_interval_millis=None,
31-
framework_profile_params=None,
32+
s3_output_path: Optional[Union[str, PipelineVariable]] = None,
33+
system_monitor_interval_millis: Optional[Union[int, PipelineVariable]] = None,
34+
framework_profile_params: Optional[FrameworkProfile] = None,
3235
):
3336
"""Initialize a ``ProfilerConfig`` instance.
3437

src/sagemaker/drift_check_baselines.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,25 @@
1313
"""This file contains code related to drift check baselines"""
1414
from __future__ import absolute_import
1515

16+
from typing import Optional
17+
18+
from sagemaker.model_metrics import MetricsSource, FileSource
19+
1620

1721
class DriftCheckBaselines(object):
1822
"""Accepts drift check baselines parameters for conversion to request dict."""
1923

2024
def __init__(
2125
self,
22-
model_statistics=None,
23-
model_constraints=None,
24-
model_data_statistics=None,
25-
model_data_constraints=None,
26-
bias_config_file=None,
27-
bias_pre_training_constraints=None,
28-
bias_post_training_constraints=None,
29-
explainability_constraints=None,
30-
explainability_config_file=None,
26+
model_statistics: Optional[MetricsSource] = None,
27+
model_constraints: Optional[MetricsSource] = None,
28+
model_data_statistics: Optional[MetricsSource] = None,
29+
model_data_constraints: Optional[MetricsSource] = None,
30+
bias_config_file: Optional[FileSource] = None,
31+
bias_pre_training_constraints: Optional[MetricsSource] = None,
32+
bias_post_training_constraints: Optional[MetricsSource] = None,
33+
explainability_constraints: Optional[MetricsSource] = None,
34+
explainability_config_file: Optional[FileSource] = None,
3135
):
3236
"""Initialize a ``DriftCheckBaselines`` instance and turn parameters into dict.
3337

src/sagemaker/estimator.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
validate_source_code_input_against_pipeline_variables,
5151
)
5252
from sagemaker.inputs import TrainingInput, FileSystemInput
53+
from sagemaker.instance_group import InstanceGroup
5354
from sagemaker.job import _Job
5455
from sagemaker.jumpstart.utils import (
5556
add_jumpstart_tags,
@@ -149,7 +150,7 @@ def __init__(
149150
code_location: Optional[str] = None,
150151
entry_point: Optional[Union[str, PipelineVariable]] = None,
151152
dependencies: Optional[List[Union[str]]] = None,
152-
instance_groups: Optional[Dict[str, Union[str, int]]] = None,
153+
instance_groups: Optional[List[InstanceGroup]] = None,
153154
**kwargs,
154155
):
155156
"""Initialize an ``EstimatorBase`` instance.
@@ -1580,6 +1581,8 @@ def _get_instance_type(self):
15801581

15811582
for instance_group in self.instance_groups:
15821583
instance_type = instance_group.instance_type
1584+
if is_pipeline_variable(instance_type):
1585+
continue
15831586
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
15841587

15851588
if match:
@@ -2179,7 +2182,7 @@ def __init__(
21792182
code_location: Optional[str] = None,
21802183
entry_point: Optional[Union[str, PipelineVariable]] = None,
21812184
dependencies: Optional[List[str]] = None,
2182-
instance_groups: Optional[Dict[str, Union[str, int]]] = None,
2185+
instance_groups: Optional[List[InstanceGroup]] = None,
21832186
**kwargs,
21842187
):
21852188
"""Initialize an ``Estimator`` instance.
@@ -2874,7 +2877,15 @@ def _validate_and_set_debugger_configs(self):
28742877
# Disable debugger if checkpointing is enabled by the customer
28752878
if self.checkpoint_s3_uri and self.checkpoint_local_path and self.debugger_hook_config:
28762879
if self._framework_name in {"mxnet", "pytorch", "tensorflow"}:
2877-
if self.instance_count > 1 or (
2880+
if is_pipeline_variable(self.instance_count):
2881+
logger.warning(
2882+
"SMDebug does not currently support distributed training jobs "
2883+
"with checkpointing enabled. Therefore, to allow parameterized "
2884+
"instance_count and allow to change it to any values in execution time, "
2885+
"the debugger_hook_config is disabled."
2886+
)
2887+
self.debugger_hook_config = False
2888+
elif self.instance_count > 1 or (
28782889
hasattr(self, "distribution")
28792890
and self.distribution is not None # pylint: disable=no-member
28802891
):

src/sagemaker/fw_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,14 @@ def validate_distribution_instance(sagemaker_session, distribution, instance_typ
871871
# Strategy modelparallel is not enabled
872872
return
873873

874+
if is_pipeline_variable(instance_type):
875+
logger.warning(
876+
"instance_type is a pipeline variable, which is only interpreted in "
877+
"pipeline execution time. As modelparallel only runs on GPU-enabled "
878+
"instances, in execution time, the specified instance type has to support GPU."
879+
)
880+
return
881+
874882
instance_desc = sagemaker_session.boto_session.client("ec2").describe_instance_types(
875883
InstanceTypes=[f"{instance_type}"]
876884
)

src/sagemaker/huggingface/training_compiler/config.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
"""Configuration for the SageMaker Training Compiler."""
1414
from __future__ import absolute_import
1515
import logging
16+
from typing import Union
1617

1718
from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig
19+
from sagemaker.workflow.entities import PipelineVariable
1820

1921
logger = logging.getLogger(__name__)
2022

@@ -26,8 +28,8 @@ class TrainingCompilerConfig(BaseConfig):
2628

2729
def __init__(
2830
self,
29-
enabled=True,
30-
debug=False,
31+
enabled: Union[bool, PipelineVariable] = True,
32+
debug: Union[bool, PipelineVariable] = False,
3133
):
3234
"""This class initializes a ``TrainingCompilerConfig`` instance.
3335

src/sagemaker/inputs.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313
"""Amazon SageMaker channel configurations for S3 data sources and file system data sources"""
1414
from __future__ import absolute_import, print_function
1515

16+
from typing import Union, Optional, List
1617
import attr
1718

19+
from sagemaker.workflow.entities import PipelineVariable
20+
1821
FILE_SYSTEM_TYPES = ["FSxLustre", "EFS"]
1922
FILE_SYSTEM_ACCESS_MODES = ["ro", "rw"]
2023

@@ -29,17 +32,17 @@ class TrainingInput(object):
2932

3033
def __init__(
3134
self,
32-
s3_data,
33-
distribution=None,
34-
compression=None,
35-
content_type=None,
36-
record_wrapping=None,
37-
s3_data_type="S3Prefix",
38-
instance_groups=None,
39-
input_mode=None,
40-
attribute_names=None,
41-
target_attribute_name=None,
42-
shuffle_config=None,
35+
s3_data: Union[str, PipelineVariable],
36+
distribution: Optional[Union[str, PipelineVariable]] = None,
37+
compression: Optional[Union[str, PipelineVariable]] = None,
38+
content_type: Optional[Union[str, PipelineVariable]] = None,
39+
record_wrapping: Optional[Union[str, PipelineVariable]] = None,
40+
s3_data_type: Union[str, PipelineVariable] = "S3Prefix",
41+
instance_groups: Optional[List[Union[str, PipelineVariable]]] = None,
42+
input_mode: Optional[Union[str, PipelineVariable]] = None,
43+
attribute_names: Optional[List[Union[str, PipelineVariable]]] = None,
44+
target_attribute_name: Optional[Union[str, PipelineVariable]] = None,
45+
shuffle_config: Optional["ShuffleConfig"] = None,
4346
):
4447
r"""Create a definition for input data used by an SageMaker training job.
4548

src/sagemaker/metadata_properties.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,20 @@
1313
"""This file contains code related to metadata properties."""
1414
from __future__ import absolute_import
1515

16+
from typing import Optional, Union
17+
18+
from sagemaker.workflow.entities import PipelineVariable
19+
1620

1721
class MetadataProperties(object):
1822
"""Accepts metadata properties parameters for conversion to request dict."""
1923

2024
def __init__(
2125
self,
22-
commit_id=None,
23-
repository=None,
24-
generated_by=None,
25-
project_id=None,
26+
commit_id: Optional[Union[str, PipelineVariable]] = None,
27+
repository: Optional[Union[str, PipelineVariable]] = None,
28+
generated_by: Optional[Union[str, PipelineVariable]] = None,
29+
project_id: Optional[Union[str, PipelineVariable]] = None,
2630
):
2731
"""Initialize a ``MetadataProperties`` instance and turn parameters into dict.
2832

0 commit comments

Comments
 (0)