diff --git a/src/sagemaker/amazon/amazon_estimator.py b/src/sagemaker/amazon/amazon_estimator.py index 09e77d612a..eaf4644da6 100644 --- a/src/sagemaker/amazon/amazon_estimator.py +++ b/src/sagemaker/amazon/amazon_estimator.py @@ -16,6 +16,7 @@ import json import logging import tempfile +from typing import Union from six.moves.urllib.parse import urlparse @@ -27,6 +28,7 @@ from sagemaker.estimator import EstimatorBase, _TrainingJob from sagemaker.inputs import FileSystemInput, TrainingInput from sagemaker.utils import sagemaker_timestamp +from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.pipeline_context import runnable_by_pipeline logger = logging.getLogger(__name__) @@ -304,7 +306,12 @@ class RecordSet(object): """Placeholder docstring""" def __init__( - self, s3_data, num_records, feature_dim, s3_data_type="ManifestFile", channel="train" + self, + s3_data: Union[str, PipelineVariable], + num_records: int, + feature_dim: int, + s3_data_type: Union[str, PipelineVariable] = "ManifestFile", + channel: Union[str, PipelineVariable] = "train", ): """A collection of Amazon :class:~`Record` objects serialized and stored in S3. diff --git a/src/sagemaker/debugger/debugger.py b/src/sagemaker/debugger/debugger.py index d2d53547f1..23f7b651a3 100644 --- a/src/sagemaker/debugger/debugger.py +++ b/src/sagemaker/debugger/debugger.py @@ -24,12 +24,15 @@ from abc import ABC +from typing import Union, Optional, List, Dict + import attr import smdebug_rulesconfig as rule_configs from sagemaker import image_uris from sagemaker.utils import build_dict +from sagemaker.workflow.entities import PipelineVariable framework_name = "debugger" DEBUGGER_FLAG = "USE_SMDEBUG" @@ -311,17 +314,17 @@ def sagemaker( @classmethod def custom( cls, - name, - image_uri, - instance_type, - volume_size_in_gb, - source=None, - rule_to_invoke=None, - container_local_output_path=None, - s3_output_path=None, - other_trials_s3_input_paths=None, - rule_parameters=None, - collections_to_save=None, + name: str, + image_uri: Union[str, PipelineVariable], + instance_type: Union[str, PipelineVariable], + volume_size_in_gb: Union[int, PipelineVariable], + source: Optional[str] = None, + rule_to_invoke: Optional[Union[str, PipelineVariable]] = None, + container_local_output_path: Optional[Union[str, PipelineVariable]] = None, + s3_output_path: Optional[Union[str, PipelineVariable]] = None, + other_trials_s3_input_paths: Optional[List[Union[str, PipelineVariable]]] = None, + rule_parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + collections_to_save: Optional[List["CollectionConfig"]] = None, actions=None, ): """Initialize a ``Rule`` object for a *custom* debugging rule. @@ -610,10 +613,10 @@ class DebuggerHookConfig(object): def __init__( self, - s3_output_path=None, - container_local_output_path=None, - hook_parameters=None, - collection_configs=None, + s3_output_path: Optional[Union[str, PipelineVariable]] = None, + container_local_output_path: Optional[Union[str, PipelineVariable]] = None, + hook_parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + collection_configs: Optional[List["CollectionConfig"]] = None, ): """Initialize the DebuggerHookConfig instance. @@ -679,7 +682,11 @@ def _to_request_dict(self): class TensorBoardOutputConfig(object): """Create a tensor ouput configuration object for debugging visualizations on TensorBoard.""" - def __init__(self, s3_output_path, container_local_output_path=None): + def __init__( + self, + s3_output_path: Union[str, PipelineVariable], + container_local_output_path: Optional[Union[str, PipelineVariable]] = None, + ): """Initialize the TensorBoardOutputConfig instance. Args: @@ -708,7 +715,11 @@ def _to_request_dict(self): class CollectionConfig(object): """Creates tensor collections for SageMaker Debugger.""" - def __init__(self, name, parameters=None): + def __init__( + self, + name: Union[str, PipelineVariable], + parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + ): """Constructor for collection configuration. Args: diff --git a/src/sagemaker/debugger/profiler_config.py b/src/sagemaker/debugger/profiler_config.py index 371d161bbe..807ba91e79 100644 --- a/src/sagemaker/debugger/profiler_config.py +++ b/src/sagemaker/debugger/profiler_config.py @@ -13,7 +13,10 @@ """Configuration for collecting system and framework metrics in SageMaker training jobs.""" from __future__ import absolute_import +from typing import Optional, Union + from sagemaker.debugger.framework_profile import FrameworkProfile +from sagemaker.workflow.entities import PipelineVariable class ProfilerConfig(object): @@ -26,9 +29,9 @@ class ProfilerConfig(object): def __init__( self, - s3_output_path=None, - system_monitor_interval_millis=None, - framework_profile_params=None, + s3_output_path: Optional[Union[str, PipelineVariable]] = None, + system_monitor_interval_millis: Optional[Union[int, PipelineVariable]] = None, + framework_profile_params: Optional[FrameworkProfile] = None, ): """Initialize a ``ProfilerConfig`` instance. diff --git a/src/sagemaker/drift_check_baselines.py b/src/sagemaker/drift_check_baselines.py index 24aa4787d0..9c3b8dbd57 100644 --- a/src/sagemaker/drift_check_baselines.py +++ b/src/sagemaker/drift_check_baselines.py @@ -13,21 +13,25 @@ """This file contains code related to drift check baselines""" from __future__ import absolute_import +from typing import Optional + +from sagemaker.model_metrics import MetricsSource, FileSource + class DriftCheckBaselines(object): """Accepts drift check baselines parameters for conversion to request dict.""" def __init__( self, - model_statistics=None, - model_constraints=None, - model_data_statistics=None, - model_data_constraints=None, - bias_config_file=None, - bias_pre_training_constraints=None, - bias_post_training_constraints=None, - explainability_constraints=None, - explainability_config_file=None, + model_statistics: Optional[MetricsSource] = None, + model_constraints: Optional[MetricsSource] = None, + model_data_statistics: Optional[MetricsSource] = None, + model_data_constraints: Optional[MetricsSource] = None, + bias_config_file: Optional[FileSource] = None, + bias_pre_training_constraints: Optional[MetricsSource] = None, + bias_post_training_constraints: Optional[MetricsSource] = None, + explainability_constraints: Optional[MetricsSource] = None, + explainability_config_file: Optional[FileSource] = None, ): """Initialize a ``DriftCheckBaselines`` instance and turn parameters into dict. diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 9d0c30ff27..a4b769a306 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -50,6 +50,7 @@ validate_source_code_input_against_pipeline_variables, ) from sagemaker.inputs import TrainingInput, FileSystemInput +from sagemaker.instance_group import InstanceGroup from sagemaker.job import _Job from sagemaker.jumpstart.utils import ( add_jumpstart_tags, @@ -149,7 +150,7 @@ def __init__( code_location: Optional[str] = None, entry_point: Optional[Union[str, PipelineVariable]] = None, dependencies: Optional[List[Union[str]]] = None, - instance_groups: Optional[Dict[str, Union[str, int]]] = None, + instance_groups: Optional[List[InstanceGroup]] = None, **kwargs, ): """Initialize an ``EstimatorBase`` instance. @@ -1580,6 +1581,8 @@ def _get_instance_type(self): for instance_group in self.instance_groups: instance_type = instance_group.instance_type + if is_pipeline_variable(instance_type): + continue match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type) if match: @@ -2179,7 +2182,7 @@ def __init__( code_location: Optional[str] = None, entry_point: Optional[Union[str, PipelineVariable]] = None, dependencies: Optional[List[str]] = None, - instance_groups: Optional[Dict[str, Union[str, int]]] = None, + instance_groups: Optional[List[InstanceGroup]] = None, **kwargs, ): """Initialize an ``Estimator`` instance. @@ -2874,7 +2877,15 @@ def _validate_and_set_debugger_configs(self): # Disable debugger if checkpointing is enabled by the customer if self.checkpoint_s3_uri and self.checkpoint_local_path and self.debugger_hook_config: if self._framework_name in {"mxnet", "pytorch", "tensorflow"}: - if self.instance_count > 1 or ( + if is_pipeline_variable(self.instance_count): + logger.warning( + "SMDebug does not currently support distributed training jobs " + "with checkpointing enabled. Therefore, to allow parameterized " + "instance_count and allow to change it to any values in execution time, " + "the debugger_hook_config is disabled." + ) + self.debugger_hook_config = False + elif self.instance_count > 1 or ( hasattr(self, "distribution") and self.distribution is not None # pylint: disable=no-member ): diff --git a/src/sagemaker/huggingface/training_compiler/config.py b/src/sagemaker/huggingface/training_compiler/config.py index 07a3bcf9b7..b19fb2be2b 100644 --- a/src/sagemaker/huggingface/training_compiler/config.py +++ b/src/sagemaker/huggingface/training_compiler/config.py @@ -13,8 +13,10 @@ """Configuration for the SageMaker Training Compiler.""" from __future__ import absolute_import import logging +from typing import Union from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig +from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger(__name__) @@ -26,8 +28,8 @@ class TrainingCompilerConfig(BaseConfig): def __init__( self, - enabled=True, - debug=False, + enabled: Union[bool, PipelineVariable] = True, + debug: Union[bool, PipelineVariable] = False, ): """This class initializes a ``TrainingCompilerConfig`` instance. diff --git a/src/sagemaker/inputs.py b/src/sagemaker/inputs.py index 3481c138bd..0fca307a97 100644 --- a/src/sagemaker/inputs.py +++ b/src/sagemaker/inputs.py @@ -13,8 +13,11 @@ """Amazon SageMaker channel configurations for S3 data sources and file system data sources""" from __future__ import absolute_import, print_function +from typing import Union, Optional, List import attr +from sagemaker.workflow.entities import PipelineVariable + FILE_SYSTEM_TYPES = ["FSxLustre", "EFS"] FILE_SYSTEM_ACCESS_MODES = ["ro", "rw"] @@ -29,17 +32,17 @@ class TrainingInput(object): def __init__( self, - s3_data, - distribution=None, - compression=None, - content_type=None, - record_wrapping=None, - s3_data_type="S3Prefix", - instance_groups=None, - input_mode=None, - attribute_names=None, - target_attribute_name=None, - shuffle_config=None, + s3_data: Union[str, PipelineVariable], + distribution: Optional[Union[str, PipelineVariable]] = None, + compression: Optional[Union[str, PipelineVariable]] = None, + content_type: Optional[Union[str, PipelineVariable]] = None, + record_wrapping: Optional[Union[str, PipelineVariable]] = None, + s3_data_type: Union[str, PipelineVariable] = "S3Prefix", + instance_groups: Optional[List[Union[str, PipelineVariable]]] = None, + input_mode: Optional[Union[str, PipelineVariable]] = None, + attribute_names: Optional[List[Union[str, PipelineVariable]]] = None, + target_attribute_name: Optional[Union[str, PipelineVariable]] = None, + shuffle_config: Optional["ShuffleConfig"] = None, ): r"""Create a definition for input data used by an SageMaker training job. diff --git a/src/sagemaker/metadata_properties.py b/src/sagemaker/metadata_properties.py index 4bc77ed0ee..b25aff9168 100644 --- a/src/sagemaker/metadata_properties.py +++ b/src/sagemaker/metadata_properties.py @@ -13,16 +13,20 @@ """This file contains code related to metadata properties.""" from __future__ import absolute_import +from typing import Optional, Union + +from sagemaker.workflow.entities import PipelineVariable + class MetadataProperties(object): """Accepts metadata properties parameters for conversion to request dict.""" def __init__( self, - commit_id=None, - repository=None, - generated_by=None, - project_id=None, + commit_id: Optional[Union[str, PipelineVariable]] = None, + repository: Optional[Union[str, PipelineVariable]] = None, + generated_by: Optional[Union[str, PipelineVariable]] = None, + project_id: Optional[Union[str, PipelineVariable]] = None, ): """Initialize a ``MetadataProperties`` instance and turn parameters into dict. diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index a2c6da4bb7..8772fa724f 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -18,7 +18,7 @@ import logging import os import copy -from typing import List, Dict +from typing import List, Dict, Optional, Union import sagemaker from sagemaker import ( @@ -29,7 +29,11 @@ utils, git_utils, ) +from sagemaker.session import Session +from sagemaker.model_metrics import ModelMetrics from sagemaker.deprecations import removed_kwargs +from sagemaker.drift_check_baselines import DriftCheckBaselines +from sagemaker.metadata_properties import MetadataProperties from sagemaker.predictor import PredictorBase from sagemaker.serverless import ServerlessInferenceConfig from sagemaker.transformer import Transformer @@ -37,10 +41,12 @@ from sagemaker.utils import ( unique_name_from_base, update_container_with_inference_params, + to_string, ) from sagemaker.async_inference import AsyncInferenceConfig from sagemaker.predictor_async import AsyncPredictor from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession LOGGER = logging.getLogger("sagemaker") @@ -82,23 +88,23 @@ class Model(ModelBase): def __init__( self, - image_uri, - model_data=None, - role=None, - predictor_cls=None, - env=None, - name=None, - vpc_config=None, - sagemaker_session=None, - enable_network_isolation=False, - model_kms_key=None, - image_config=None, - source_dir=None, - code_location=None, - entry_point=None, - container_log_level=logging.INFO, - dependencies=None, - git_config=None, + image_uri: Union[str, PipelineVariable], + model_data: Optional[Union[str, PipelineVariable]] = None, + role: Optional[str] = None, + predictor_cls: Optional[callable] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + name: Optional[str] = None, + vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, + sagemaker_session: Optional[Session] = None, + enable_network_isolation: Union[bool, PipelineVariable] = False, + model_kms_key: Optional[str] = None, + image_config: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + source_dir: Optional[str] = None, + code_location: Optional[str] = None, + entry_point: Optional[str] = None, + container_log_level: Union[int, PipelineVariable] = logging.INFO, + dependencies: Optional[List[str]] = None, + git_config: Optional[Dict[str, str]] = None, ): """Initialize an SageMaker ``Model``. @@ -298,28 +304,28 @@ def __init__( @runnable_by_pipeline def register( self, - content_types, - response_types, - inference_instances=None, - transform_instances=None, - model_package_name=None, - model_package_group_name=None, - image_uri=None, - model_metrics=None, - metadata_properties=None, - marketplace_cert=False, - approval_status=None, - description=None, - drift_check_baselines=None, - customer_metadata_properties=None, - validation_specification=None, - domain=None, - task=None, - sample_payload_url=None, - framework=None, - framework_version=None, - nearest_model_name=None, - data_input_configuration=None, + content_types: List[Union[str, PipelineVariable]], + response_types: List[Union[str, PipelineVariable]], + inference_instances: Optional[List[Union[str, PipelineVariable]]] = None, + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, + model_package_name: Optional[Union[str, PipelineVariable]] = None, + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + model_metrics: Optional[ModelMetrics] = None, + metadata_properties: Optional[MetadataProperties] = None, + marketplace_cert: bool = False, + approval_status: Optional[Union[str, PipelineVariable]] = None, + description: Optional[str] = None, + drift_check_baselines: Optional[DriftCheckBaselines] = None, + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + validation_specification: Optional[Union[str, PipelineVariable]] = None, + domain: Optional[Union[str, PipelineVariable]] = None, + task: Optional[Union[str, PipelineVariable]] = None, + sample_payload_url: Optional[Union[str, PipelineVariable]] = None, + framework: Optional[Union[str, PipelineVariable]] = None, + framework_version: Optional[Union[str, PipelineVariable]] = None, + nearest_model_name: Optional[Union[str, PipelineVariable]] = None, + data_input_configuration: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -349,11 +355,11 @@ def register( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). - sample_payload_url (str): The S3 path where the sample payload is stored - (default: None). task (str): Task values which are supported by Inference Recommender are "FILL_MASK", "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). framework (str): Machine learning framework of the model package container image (default: None). framework_version (str): Framework version of the Model Package Container Image @@ -421,10 +427,10 @@ def register( @runnable_by_pipeline def create( self, - instance_type: str = None, - accelerator_type: str = None, - serverless_inference_config: ServerlessInferenceConfig = None, - tags: List[Dict[str, str]] = None, + instance_type: Optional[str] = None, + accelerator_type: Optional[str] = None, + serverless_inference_config: Optional[ServerlessInferenceConfig] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, ): """Create a SageMaker Model Entity @@ -608,7 +614,7 @@ def _script_mode_env_vars(self): return { SCRIPT_PARAM_NAME.upper(): script_name or str(), DIR_PARAM_NAME.upper(): dir_name or str(), - CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): str(self.container_log_level), + CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): to_string(self.container_log_level), SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_region_name, } @@ -1286,19 +1292,19 @@ class FrameworkModel(Model): def __init__( self, - model_data, - image_uri, - role, - entry_point, - source_dir=None, - predictor_cls=None, - env=None, - name=None, - container_log_level=logging.INFO, - code_location=None, - sagemaker_session=None, - dependencies=None, - git_config=None, + model_data: Union[str, PipelineVariable], + image_uri: Union[str, PipelineVariable], + role: str, + entry_point: str, + source_dir: Optional[str] = None, + predictor_cls: Optional[callable] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + name: Optional[str] = None, + container_log_level: Union[int, PipelineVariable] = logging.INFO, + code_location: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + dependencies: Optional[List[str]] = None, + git_config: Optional[Dict[str, str]] = None, **kwargs, ): """Initialize a ``FrameworkModel``. diff --git a/src/sagemaker/model_metrics.py b/src/sagemaker/model_metrics.py index acce4e13c9..83a43d3f18 100644 --- a/src/sagemaker/model_metrics.py +++ b/src/sagemaker/model_metrics.py @@ -13,20 +13,24 @@ """This file contains code related to model metrics, including metric source and file source.""" from __future__ import absolute_import +from typing import Optional, Union + +from sagemaker.workflow.entities import PipelineVariable + class ModelMetrics(object): """Accepts model metrics parameters for conversion to request dict.""" def __init__( self, - model_statistics=None, - model_constraints=None, - model_data_statistics=None, - model_data_constraints=None, - bias=None, - explainability=None, - bias_pre_training=None, - bias_post_training=None, + model_statistics: Optional["MetricsSource"] = None, + model_constraints: Optional["MetricsSource"] = None, + model_data_statistics: Optional["MetricsSource"] = None, + model_data_constraints: Optional["MetricsSource"] = None, + bias: Optional["MetricsSource"] = None, + explainability: Optional["MetricsSource"] = None, + bias_pre_training: Optional["MetricsSource"] = None, + bias_post_training: Optional["MetricsSource"] = None, ): """Initialize a ``ModelMetrics`` instance and turn parameters into dict. @@ -99,9 +103,9 @@ class MetricsSource(object): def __init__( self, - content_type, - s3_uri, - content_digest=None, + content_type: Union[str, PipelineVariable], + s3_uri: Union[str, PipelineVariable], + content_digest: Optional[Union[str, PipelineVariable]] = None, ): """Initialize a ``MetricsSource`` instance and turn parameters into dict. @@ -127,9 +131,9 @@ class FileSource(object): def __init__( self, - s3_uri, - content_digest=None, - content_type=None, + s3_uri: Union[str, PipelineVariable], + content_digest: Optional[Union[str, PipelineVariable]] = None, + content_type: Optional[Union[str, PipelineVariable]] = None, ): """Initialize a ``FileSource`` instance and turn parameters into dict. diff --git a/src/sagemaker/serverless/serverless_inference_config.py b/src/sagemaker/serverless/serverless_inference_config.py index 39950f4f84..adc98a319a 100644 --- a/src/sagemaker/serverless/serverless_inference_config.py +++ b/src/sagemaker/serverless/serverless_inference_config.py @@ -27,8 +27,8 @@ class ServerlessInferenceConfig(object): def __init__( self, - memory_size_in_mb=2048, - max_concurrency=5, + memory_size_in_mb: int = 2048, + max_concurrency: int = 5, ): """Initialize a ServerlessInferenceConfig object for serverless inference configuration. diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 145bf41cbe..221434d7db 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2633,7 +2633,9 @@ def _create_model_request( request["VpcConfig"] = vpc_config if enable_network_isolation: - request["EnableNetworkIsolation"] = True + # enable_network_isolation may be a pipeline variable which is + # parsed in execution time + request["EnableNetworkIsolation"] = enable_network_isolation return request diff --git a/src/sagemaker/tensorflow/estimator.py b/src/sagemaker/tensorflow/estimator.py index 4db647e140..9533f475a1 100644 --- a/src/sagemaker/tensorflow/estimator.py +++ b/src/sagemaker/tensorflow/estimator.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import logging +from typing import Optional, Union, Dict from packaging import version @@ -27,6 +28,7 @@ from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT from sagemaker.workflow import is_pipeline_variable from sagemaker.tensorflow.training_compiler.config import TrainingCompilerConfig +from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger("sagemaker") @@ -41,12 +43,12 @@ class TensorFlow(Framework): def __init__( self, - py_version=None, - framework_version=None, - model_dir=None, - image_uri=None, - distribution=None, - compiler_config=None, + py_version: Optional[str] = None, + framework_version: Optional[str] = None, + model_dir: Optional[Union[str, PipelineVariable]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + distribution: Optional[Dict[str, str]] = None, + compiler_config: Optional[TrainingCompilerConfig] = None, **kwargs, ): """Initialize a ``TensorFlow`` estimator. @@ -251,6 +253,8 @@ def _only_legacy_mode_supported(self): def _only_python_3_supported(self): """Placeholder docstring""" + if not self.framework_version: + return False return version.Version(self.framework_version) > self._HIGHEST_PYTHON_2_VERSION @classmethod