Skip to content

Commit 8d7dd32

Browse files
qidewenwhenDewen Qi
and
Dewen Qi
authored
change: Add Pipeline annotation in model base class and tensorflow estimator (#3190)
Model annotate update change: Add PipelineVariable annotation to composite argument of training go with model base and tf Co-authored-by: Dewen Qi <[email protected]>
1 parent 02fdf1b commit 8d7dd32

File tree

13 files changed

+194
-133
lines changed

13 files changed

+194
-133
lines changed

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import json
1717
import logging
1818
import tempfile
19+
from typing import Union
1920

2021
from six.moves.urllib.parse import urlparse
2122

@@ -27,6 +28,7 @@
2728
from sagemaker.estimator import EstimatorBase, _TrainingJob
2829
from sagemaker.inputs import FileSystemInput, TrainingInput
2930
from sagemaker.utils import sagemaker_timestamp
31+
from sagemaker.workflow.entities import PipelineVariable
3032
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
3133

3234
logger = logging.getLogger(__name__)
@@ -304,7 +306,12 @@ class RecordSet(object):
304306
"""Placeholder docstring"""
305307

306308
def __init__(
307-
self, s3_data, num_records, feature_dim, s3_data_type="ManifestFile", channel="train"
309+
self,
310+
s3_data: Union[str, PipelineVariable],
311+
num_records: int,
312+
feature_dim: int,
313+
s3_data_type: Union[str, PipelineVariable] = "ManifestFile",
314+
channel: Union[str, PipelineVariable] = "train",
308315
):
309316
"""A collection of Amazon :class:~`Record` objects serialized and stored in S3.
310317

src/sagemaker/debugger/debugger.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,15 @@
2424

2525
from abc import ABC
2626

27+
from typing import Union, Optional, List, Dict
28+
2729
import attr
2830

2931
import smdebug_rulesconfig as rule_configs
3032

3133
from sagemaker import image_uris
3234
from sagemaker.utils import build_dict
35+
from sagemaker.workflow.entities import PipelineVariable
3336

3437
framework_name = "debugger"
3538
DEBUGGER_FLAG = "USE_SMDEBUG"
@@ -311,17 +314,17 @@ def sagemaker(
311314
@classmethod
312315
def custom(
313316
cls,
314-
name,
315-
image_uri,
316-
instance_type,
317-
volume_size_in_gb,
318-
source=None,
319-
rule_to_invoke=None,
320-
container_local_output_path=None,
321-
s3_output_path=None,
322-
other_trials_s3_input_paths=None,
323-
rule_parameters=None,
324-
collections_to_save=None,
317+
name: str,
318+
image_uri: Union[str, PipelineVariable],
319+
instance_type: Union[str, PipelineVariable],
320+
volume_size_in_gb: Union[int, PipelineVariable],
321+
source: Optional[str] = None,
322+
rule_to_invoke: Optional[Union[str, PipelineVariable]] = None,
323+
container_local_output_path: Optional[Union[str, PipelineVariable]] = None,
324+
s3_output_path: Optional[Union[str, PipelineVariable]] = None,
325+
other_trials_s3_input_paths: Optional[List[Union[str, PipelineVariable]]] = None,
326+
rule_parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
327+
collections_to_save: Optional[List["CollectionConfig"]] = None,
325328
actions=None,
326329
):
327330
"""Initialize a ``Rule`` object for a *custom* debugging rule.
@@ -610,10 +613,10 @@ class DebuggerHookConfig(object):
610613

611614
def __init__(
612615
self,
613-
s3_output_path=None,
614-
container_local_output_path=None,
615-
hook_parameters=None,
616-
collection_configs=None,
616+
s3_output_path: Optional[Union[str, PipelineVariable]] = None,
617+
container_local_output_path: Optional[Union[str, PipelineVariable]] = None,
618+
hook_parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
619+
collection_configs: Optional[List["CollectionConfig"]] = None,
617620
):
618621
"""Initialize the DebuggerHookConfig instance.
619622
@@ -679,7 +682,11 @@ def _to_request_dict(self):
679682
class TensorBoardOutputConfig(object):
680683
"""Create a tensor ouput configuration object for debugging visualizations on TensorBoard."""
681684

682-
def __init__(self, s3_output_path, container_local_output_path=None):
685+
def __init__(
686+
self,
687+
s3_output_path: Union[str, PipelineVariable],
688+
container_local_output_path: Optional[Union[str, PipelineVariable]] = None,
689+
):
683690
"""Initialize the TensorBoardOutputConfig instance.
684691
685692
Args:
@@ -708,7 +715,11 @@ def _to_request_dict(self):
708715
class CollectionConfig(object):
709716
"""Creates tensor collections for SageMaker Debugger."""
710717

711-
def __init__(self, name, parameters=None):
718+
def __init__(
719+
self,
720+
name: Union[str, PipelineVariable],
721+
parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
722+
):
712723
"""Constructor for collection configuration.
713724
714725
Args:

src/sagemaker/debugger/profiler_config.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
"""Configuration for collecting system and framework metrics in SageMaker training jobs."""
1414
from __future__ import absolute_import
1515

16+
from typing import Optional, Union
17+
1618
from sagemaker.debugger.framework_profile import FrameworkProfile
19+
from sagemaker.workflow.entities import PipelineVariable
1720

1821

1922
class ProfilerConfig(object):
@@ -26,9 +29,9 @@ class ProfilerConfig(object):
2629

2730
def __init__(
2831
self,
29-
s3_output_path=None,
30-
system_monitor_interval_millis=None,
31-
framework_profile_params=None,
32+
s3_output_path: Optional[Union[str, PipelineVariable]] = None,
33+
system_monitor_interval_millis: Optional[Union[int, PipelineVariable]] = None,
34+
framework_profile_params: Optional[FrameworkProfile] = None,
3235
):
3336
"""Initialize a ``ProfilerConfig`` instance.
3437

src/sagemaker/drift_check_baselines.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,25 @@
1313
"""This file contains code related to drift check baselines"""
1414
from __future__ import absolute_import
1515

16+
from typing import Optional
17+
18+
from sagemaker.model_metrics import MetricsSource, FileSource
19+
1620

1721
class DriftCheckBaselines(object):
1822
"""Accepts drift check baselines parameters for conversion to request dict."""
1923

2024
def __init__(
2125
self,
22-
model_statistics=None,
23-
model_constraints=None,
24-
model_data_statistics=None,
25-
model_data_constraints=None,
26-
bias_config_file=None,
27-
bias_pre_training_constraints=None,
28-
bias_post_training_constraints=None,
29-
explainability_constraints=None,
30-
explainability_config_file=None,
26+
model_statistics: Optional[MetricsSource] = None,
27+
model_constraints: Optional[MetricsSource] = None,
28+
model_data_statistics: Optional[MetricsSource] = None,
29+
model_data_constraints: Optional[MetricsSource] = None,
30+
bias_config_file: Optional[FileSource] = None,
31+
bias_pre_training_constraints: Optional[MetricsSource] = None,
32+
bias_post_training_constraints: Optional[MetricsSource] = None,
33+
explainability_constraints: Optional[MetricsSource] = None,
34+
explainability_config_file: Optional[FileSource] = None,
3135
):
3236
"""Initialize a ``DriftCheckBaselines`` instance and turn parameters into dict.
3337

src/sagemaker/estimator.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
validate_source_code_input_against_pipeline_variables,
5151
)
5252
from sagemaker.inputs import TrainingInput, FileSystemInput
53+
from sagemaker.instance_group import InstanceGroup
5354
from sagemaker.job import _Job
5455
from sagemaker.jumpstart.utils import (
5556
add_jumpstart_tags,
@@ -149,7 +150,7 @@ def __init__(
149150
code_location: Optional[str] = None,
150151
entry_point: Optional[Union[str, PipelineVariable]] = None,
151152
dependencies: Optional[List[Union[str]]] = None,
152-
instance_groups: Optional[Dict[str, Union[str, int]]] = None,
153+
instance_groups: Optional[List[InstanceGroup]] = None,
153154
**kwargs,
154155
):
155156
"""Initialize an ``EstimatorBase`` instance.
@@ -1586,6 +1587,8 @@ def _get_instance_type(self):
15861587

15871588
for instance_group in self.instance_groups:
15881589
instance_type = instance_group.instance_type
1590+
if is_pipeline_variable(instance_type):
1591+
continue
15891592
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
15901593

15911594
if match:
@@ -2185,7 +2188,7 @@ def __init__(
21852188
code_location: Optional[str] = None,
21862189
entry_point: Optional[Union[str, PipelineVariable]] = None,
21872190
dependencies: Optional[List[str]] = None,
2188-
instance_groups: Optional[Dict[str, Union[str, int]]] = None,
2191+
instance_groups: Optional[List[InstanceGroup]] = None,
21892192
**kwargs,
21902193
):
21912194
"""Initialize an ``Estimator`` instance.
@@ -2880,7 +2883,15 @@ def _validate_and_set_debugger_configs(self):
28802883
# Disable debugger if checkpointing is enabled by the customer
28812884
if self.checkpoint_s3_uri and self.checkpoint_local_path and self.debugger_hook_config:
28822885
if self._framework_name in {"mxnet", "pytorch", "tensorflow"}:
2883-
if self.instance_count > 1 or (
2886+
if is_pipeline_variable(self.instance_count):
2887+
logger.warning(
2888+
"SMDebug does not currently support distributed training jobs "
2889+
"with checkpointing enabled. Therefore, to allow parameterized "
2890+
"instance_count and allow to change it to any values in execution time, "
2891+
"the debugger_hook_config is disabled."
2892+
)
2893+
self.debugger_hook_config = False
2894+
elif self.instance_count > 1 or (
28842895
hasattr(self, "distribution")
28852896
and self.distribution is not None # pylint: disable=no-member
28862897
):

src/sagemaker/huggingface/training_compiler/config.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
"""Configuration for the SageMaker Training Compiler."""
1414
from __future__ import absolute_import
1515
import logging
16+
from typing import Union
1617

1718
from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig
19+
from sagemaker.workflow.entities import PipelineVariable
1820

1921
logger = logging.getLogger(__name__)
2022

@@ -26,8 +28,8 @@ class TrainingCompilerConfig(BaseConfig):
2628

2729
def __init__(
2830
self,
29-
enabled=True,
30-
debug=False,
31+
enabled: Union[bool, PipelineVariable] = True,
32+
debug: Union[bool, PipelineVariable] = False,
3133
):
3234
"""This class initializes a ``TrainingCompilerConfig`` instance.
3335

src/sagemaker/inputs.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313
"""Amazon SageMaker channel configurations for S3 data sources and file system data sources"""
1414
from __future__ import absolute_import, print_function
1515

16+
from typing import Union, Optional, List
1617
import attr
1718

19+
from sagemaker.workflow.entities import PipelineVariable
20+
1821
FILE_SYSTEM_TYPES = ["FSxLustre", "EFS"]
1922
FILE_SYSTEM_ACCESS_MODES = ["ro", "rw"]
2023

@@ -29,17 +32,17 @@ class TrainingInput(object):
2932

3033
def __init__(
3134
self,
32-
s3_data,
33-
distribution=None,
34-
compression=None,
35-
content_type=None,
36-
record_wrapping=None,
37-
s3_data_type="S3Prefix",
38-
instance_groups=None,
39-
input_mode=None,
40-
attribute_names=None,
41-
target_attribute_name=None,
42-
shuffle_config=None,
35+
s3_data: Union[str, PipelineVariable],
36+
distribution: Optional[Union[str, PipelineVariable]] = None,
37+
compression: Optional[Union[str, PipelineVariable]] = None,
38+
content_type: Optional[Union[str, PipelineVariable]] = None,
39+
record_wrapping: Optional[Union[str, PipelineVariable]] = None,
40+
s3_data_type: Union[str, PipelineVariable] = "S3Prefix",
41+
instance_groups: Optional[List[Union[str, PipelineVariable]]] = None,
42+
input_mode: Optional[Union[str, PipelineVariable]] = None,
43+
attribute_names: Optional[List[Union[str, PipelineVariable]]] = None,
44+
target_attribute_name: Optional[Union[str, PipelineVariable]] = None,
45+
shuffle_config: Optional["ShuffleConfig"] = None,
4346
):
4447
r"""Create a definition for input data used by an SageMaker training job.
4548

src/sagemaker/metadata_properties.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,20 @@
1313
"""This file contains code related to metadata properties."""
1414
from __future__ import absolute_import
1515

16+
from typing import Optional, Union
17+
18+
from sagemaker.workflow.entities import PipelineVariable
19+
1620

1721
class MetadataProperties(object):
1822
"""Accepts metadata properties parameters for conversion to request dict."""
1923

2024
def __init__(
2125
self,
22-
commit_id=None,
23-
repository=None,
24-
generated_by=None,
25-
project_id=None,
26+
commit_id: Optional[Union[str, PipelineVariable]] = None,
27+
repository: Optional[Union[str, PipelineVariable]] = None,
28+
generated_by: Optional[Union[str, PipelineVariable]] = None,
29+
project_id: Optional[Union[str, PipelineVariable]] = None,
2630
):
2731
"""Initialize a ``MetadataProperties`` instance and turn parameters into dict.
2832

0 commit comments

Comments
 (0)