diff --git a/src/sagemaker/modules/configs.py b/src/sagemaker/modules/configs.py index ec0df519f5..b337e17a4f 100644 --- a/src/sagemaker/modules/configs.py +++ b/src/sagemaker/modules/configs.py @@ -45,6 +45,7 @@ InstanceGroup, TensorBoardOutputConfig, CheckpointConfig, + MetricDefinition, ) from sagemaker.modules.utils import convert_unassigned_to_none @@ -71,6 +72,7 @@ "Compute", "Networking", "InputData", + "MetricDefinition", ] diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index 31decfaca9..6593c59630 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -67,6 +67,7 @@ TensorBoardOutputConfig, CheckpointConfig, InputData, + MetricDefinition, ) from sagemaker.modules.local_core.local_container import _LocalContainer @@ -237,6 +238,7 @@ class ModelTrainer(BaseModel): _infra_check_config: Optional[InfraCheckConfig] = PrivateAttr(default=None) _session_chaining_config: Optional[SessionChainingConfig] = PrivateAttr(default=None) _remote_debug_config: Optional[RemoteDebugConfig] = PrivateAttr(default=None) + _metric_definitions: Optional[List[MetricDefinition]] = PrivateAttr(default=None) _temp_recipe_train_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None) @@ -587,6 +589,7 @@ def train( training_image_config=self.training_image_config, container_entrypoint=container_entrypoint, container_arguments=container_arguments, + metric_definitions=self._metric_definitions, ) resource_config = self.compute._to_resource_config() @@ -976,9 +979,25 @@ def from_recipe( def with_tensorboard_output_config( self, tensorboard_output_config: TensorBoardOutputConfig - ) -> "ModelTrainer": + ) -> "ModelTrainer": # noqa: D412 """Set the TensorBoard output configuration. + Example: + + .. 
code:: python + + from sagemaker.modules.train import ModelTrainer + from sagemaker.modules.configs import TensorBoardOutputConfig + + tensorboard_output_config = TensorBoardOutputConfig( + s3_output_path="s3://bucket-name/tensorboard", + local_path="/opt/ml/output/tensorboard" + ) + + model_trainer = ModelTrainer( + ... + ).with_tensorboard_output_config(tensorboard_output_config) + Args: tensorboard_output_config (sagemaker.modules.configs.TensorBoardOutputConfig): The TensorBoard output configuration. @@ -986,9 +1005,24 @@ def with_tensorboard_output_config( self._tensorboard_output_config = tensorboard_output_config return self - def with_retry_strategy(self, retry_strategy: RetryStrategy) -> "ModelTrainer": + def with_retry_strategy(self, retry_strategy: RetryStrategy) -> "ModelTrainer": # noqa: D412 """Set the retry strategy for the training job. + Example: + + .. code:: python + + from sagemaker.modules.train import ModelTrainer + from sagemaker.modules.configs import RetryStrategy + + retry_strategy = RetryStrategy( + maximum_retry_attempts=3, + ) + + model_trainer = ModelTrainer( + ... + ).with_retry_strategy(retry_strategy) + Args: retry_strategy (RetryStrategy): The retry strategy for the training job. @@ -996,9 +1030,26 @@ def with_retry_strategy(self, retry_strategy: RetryStrategy) -> "ModelTrainer": self._retry_strategy = retry_strategy return self - def with_infra_check_config(self, infra_check_config: InfraCheckConfig) -> "ModelTrainer": + def with_infra_check_config( + self, infra_check_config: InfraCheckConfig + ) -> "ModelTrainer": # noqa: D412 """Set the infra check configuration for the training job. + Example: + + .. code:: python + + from sagemaker.modules.train import ModelTrainer + from sagemaker.modules.configs import InfraCheckConfig + + infra_check_config = InfraCheckConfig( + enable_infra_check=True, + ) + + model_trainer = ModelTrainer( + ... 
+ ).with_infra_check_config(infra_check_config) + Args: infra_check_config (InfraCheckConfig): The infra check configuration for the training job. @@ -1008,9 +1059,24 @@ def with_infra_check_config(self, infra_check_config: InfraCheckConfig) -> "Mode def with_session_chaining_config( self, session_chaining_config: SessionChainingConfig - ) -> "ModelTrainer": + ) -> "ModelTrainer": # noqa: D412 """Set the session chaining configuration for the training job. + Example: + + .. code:: python + + from sagemaker.modules.train import ModelTrainer + from sagemaker.modules.configs import SessionChainingConfig + + session_chaining_config = SessionChainingConfig( + enable_session_tag_chaining=True, + ) + + model_trainer = ModelTrainer( + ... + ).with_session_chaining_config(session_chaining_config) + Args: session_chaining_config (SessionChainingConfig): The session chaining configuration for the training job. @@ -1018,12 +1084,58 @@ def with_session_chaining_config( self._session_chaining_config = session_chaining_config return self - def with_remote_debug_config(self, remote_debug_config: RemoteDebugConfig) -> "ModelTrainer": + def with_remote_debug_config( + self, remote_debug_config: RemoteDebugConfig + ) -> "ModelTrainer": # noqa: D412 """Set the remote debug configuration for the training job. + Example: + + .. code:: python + + from sagemaker.modules.train import ModelTrainer + from sagemaker.modules.configs import RemoteDebugConfig + + remote_debug_config = RemoteDebugConfig( + enable_remote_debug=True, + ) + model_trainer = ModelTrainer( + ... + ).with_remote_debug_config(remote_debug_config) + Args: remote_debug_config (RemoteDebugConfig): The remote debug configuration for the training job. """ self._remote_debug_config = remote_debug_config return self + + def with_metric_definitions( + self, metric_definitions: List[MetricDefinition] + ) -> "ModelTrainer": # noqa: D412 + """Set the metric definitions for the training job. + + Example: + + .. 
code:: python + + from sagemaker.modules.train import ModelTrainer + from sagemaker.modules.configs import MetricDefinition + + metric_definitions = [ + MetricDefinition( + name="loss", + regex="Loss: (.*?);", + ) + ] + + model_trainer = ModelTrainer( + ... + ).with_metric_definitions(metric_definitions) + + Args: + metric_definitions (List[MetricDefinition]): + The metric definitions for the training job. + """ + self._metric_definitions = metric_definitions + return self diff --git a/tests/unit/sagemaker/modules/train/test_model_trainer.py b/tests/unit/sagemaker/modules/train/test_model_trainer.py index 093da20ab8..a10ca0958e 100644 --- a/tests/unit/sagemaker/modules/train/test_model_trainer.py +++ b/tests/unit/sagemaker/modules/train/test_model_trainer.py @@ -62,6 +62,7 @@ FileSystemDataSource, Channel, DataSource, + MetricDefinition, ) from sagemaker.modules.distributed import Torchrun, SMP, MPI from sagemaker.modules.train.sm_recipes.utils import _load_recipes_cfg @@ -654,6 +655,32 @@ def test_remote_debug_config(mock_training_job, modules_session): ) +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_metric_definitions(mock_training_job, modules_session): + image_uri = DEFAULT_IMAGE + role = DEFAULT_ROLE + metric_definitions = [ + MetricDefinition( + name="loss", + regex="Loss: (.*?);", + ) + ] + + model_trainer = ModelTrainer( + training_image=image_uri, sagemaker_session=modules_session, role=role + ).with_metric_definitions(metric_definitions) + + with patch("sagemaker.modules.train.model_trainer.Session.upload_data") as mock_upload_data: + mock_upload_data.return_value = "s3://dummy-bucket/dummy-prefix" + model_trainer.train() + + mock_training_job.create.assert_called_once() + assert ( + mock_training_job.create.call_args.kwargs["algorithm_specification"].metric_definitions + == metric_definitions + ) + + @patch("sagemaker.modules.train.model_trainer._get_unique_name") @patch("sagemaker.modules.train.model_trainer.TrainingJob") def 
test_model_trainer_full_init(mock_training_job, mock_unique_name, modules_session): @@ -771,6 +798,7 @@ def mock_upload_data(path, bucket, key_prefix): training_input_mode=training_input_mode, training_image=training_image, algorithm_name=None, + metric_definitions=None, container_entrypoint=DEFAULT_ENTRYPOINT, container_arguments=DEFAULT_ARGUMENTS, training_image_config=training_image_config,