diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index f3caa2e8bb..c61a727779 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -82,6 +82,10 @@ def __init__( model_channel_name="model", metric_definitions=None, encrypt_inter_container_traffic=False, + train_use_spot_instances=False, + train_max_wait=None, + checkpoint_s3_uri=None, + checkpoint_local_path=None, ): """Initialize an ``EstimatorBase`` instance. @@ -157,6 +161,28 @@ def __init__( encrypt_inter_container_traffic (bool): Specifies whether traffic between training containers is encrypted for the training job (default: ``False``). + train_use_spot_instances (bool): Specifies whether to use SageMaker + Managed Spot instances for training. If enabled then the + `train_max_wait` arg should also be set. + + More information: + https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html + (default: ``False``). + train_max_wait (int): Timeout in seconds waiting for spot training + instances. After this amount of time Amazon + SageMaker will stop waiting for Spot instances to become + available (default: ``None``). + checkpoint_s3_uri (str): The S3 URI in which to persist checkpoints + that the algorithm persists (if any) during training. (default: + ``None``). + checkpoint_local_path (str): The local path that the algorithm + writes its checkpoints to. SageMaker will persist all files + under this path to `checkpoint_s3_uri` continually during + training. On job startup the reverse happens - data from the + S3 location is downloaded to this path before the algorithm is + started. If the path is unset then SageMaker assumes the + checkpoints will be provided under `/opt/ml/checkpoints/`. + (default: ``None``). 
""" self.role = role self.train_instance_count = train_instance_count @@ -199,6 +225,10 @@ def __init__( self.security_group_ids = security_group_ids self.encrypt_inter_container_traffic = encrypt_inter_container_traffic + self.train_use_spot_instances = train_use_spot_instances + self.train_max_wait = train_max_wait + self.checkpoint_s3_uri = checkpoint_s3_uri + self.checkpoint_local_path = checkpoint_local_path @abstractmethod def train_image(self): @@ -795,10 +825,35 @@ def start_new(cls, estimator, inputs): else: train_args["image"] = estimator.train_image() + cls._add_spot_checkpoint_args(local_mode, estimator, train_args) + estimator.sagemaker_session.train(**train_args) return cls(estimator.sagemaker_session, estimator._current_job_name) + @classmethod + def _add_spot_checkpoint_args(cls, local_mode, estimator, train_args): + """ + Args: + local_mode: + estimator: + train_args: + """ + if estimator.train_use_spot_instances: + if local_mode: + raise ValueError("Spot training is not supported in local mode.") + train_args["train_use_spot_instances"] = True + + if estimator.checkpoint_s3_uri: + if local_mode: + raise ValueError("Setting checkpoint_s3_uri is not supported in local mode.") + train_args["checkpoint_s3_uri"] = estimator.checkpoint_s3_uri + + if estimator.checkpoint_local_path: + if local_mode: + raise ValueError("Setting checkpoint_local_path is not supported in local mode.") + train_args["checkpoint_local_path"] = estimator.checkpoint_local_path + @classmethod def _is_local_channel(cls, input_uri): """ @@ -845,6 +900,10 @@ def __init__( model_channel_name="model", metric_definitions=None, encrypt_inter_container_traffic=False, + train_use_spot_instances=False, + train_max_wait=None, + checkpoint_s3_uri=None, + checkpoint_local_path=None, ): """Initialize an ``Estimator`` instance. 
@@ -926,6 +985,28 @@ def __init__( encrypt_inter_container_traffic (bool): Specifies whether traffic between training containers is encrypted for the training job (default: ``False``). + train_use_spot_instances (bool): Specifies whether to use SageMaker + Managed Spot instances for training. If enabled then the + `train_max_wait` arg should also be set. + + More information: + https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html + (default: ``False``). + train_max_wait (int): Timeout in seconds waiting for spot training + instances. After this amount of time Amazon + SageMaker will stop waiting for Spot instances to become + available (default: ``None``). + checkpoint_s3_uri (str): The S3 URI in which to persist checkpoints + that the algorithm persists (if any) during training. (default: + ``None``). + checkpoint_local_path (str): The local path that the algorithm + writes its checkpoints to. SageMaker will persist all files + under this path to `checkpoint_s3_uri` continually during + training. On job startup the reverse happens - data from the + S3 location is downloaded to this path before the algorithm is + started. If the path is unset then SageMaker assumes the + checkpoints will be provided under `/opt/ml/checkpoints/`. + (default: ``None``). 
""" self.image_name = image_name self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {} @@ -948,6 +1029,10 @@ def __init__( model_channel_name=model_channel_name, metric_definitions=metric_definitions, encrypt_inter_container_traffic=encrypt_inter_container_traffic, + train_use_spot_instances=train_use_spot_instances, + train_max_wait=train_max_wait, + checkpoint_s3_uri=checkpoint_s3_uri, + checkpoint_local_path=checkpoint_local_path, ) def train_image(self): diff --git a/src/sagemaker/job.py b/src/sagemaker/job.py index 0b7327757a..6f8a1af028 100644 --- a/src/sagemaker/job.py +++ b/src/sagemaker/job.py @@ -80,7 +80,9 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True): estimator.train_volume_size, estimator.train_volume_kms_key, ) - stop_condition = _Job._prepare_stop_condition(estimator.train_max_run) + stop_condition = _Job._prepare_stop_condition( + estimator.train_max_run, estimator.train_max_wait + ) vpc_config = estimator.get_vpc_config() model_channel = _Job._prepare_channel( @@ -312,11 +314,14 @@ def _prepare_resource_config(instance_count, instance_type, volume_size, train_v return resource_config @staticmethod - def _prepare_stop_condition(max_run): + def _prepare_stop_condition(max_run, max_wait): """ Args: max_run: + max_wait: """ + if max_wait: + return {"MaxRuntimeInSeconds": max_run, "MaxWaitTimeInSeconds": max_wait} return {"MaxRuntimeInSeconds": max_run} @property diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index fcbe6f3735..a3426e0ddf 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -257,6 +257,9 @@ def train( # noqa: C901 image=None, algorithm_arn=None, encrypt_inter_container_traffic=False, + train_use_spot_instances=False, + checkpoint_s3_uri=None, + checkpoint_local_path=None, ): """Create an Amazon SageMaker training job. @@ -307,6 +310,18 @@ def train( # noqa: C901 algorithm_arn (str): Algorithm Arn from Marketplace. 
encrypt_inter_container_traffic (bool): Specifies whether traffic between training containers is encrypted for the training job (default: ``False``). + train_use_spot_instances (bool): whether to use spot instances for training. + checkpoint_s3_uri (str): The S3 URI in which to persist checkpoints + that the algorithm persists (if any) during training. (default: + ``None``). + checkpoint_local_path (str): The local path that the algorithm + writes its checkpoints to. SageMaker will persist all files + under this path to `checkpoint_s3_uri` continually during + training. On job startup the reverse happens - data from the + s3 location is downloaded to this path before the algorithm is + started. If the path is unset then SageMaker assumes the + checkpoints will be provided under `/opt/ml/checkpoints/`. + (default: ``None``). Returns: str: ARN of the training job, if it is created. @@ -357,6 +372,15 @@ def train( # noqa: C901 if encrypt_inter_container_traffic: train_request["EnableInterContainerTrafficEncryption"] = encrypt_inter_container_traffic + if train_use_spot_instances: + train_request["EnableManagedSpotTraining"] = train_use_spot_instances + + if checkpoint_s3_uri: + checkpoint_config = {"S3Uri": checkpoint_s3_uri} + if checkpoint_local_path: + checkpoint_config["LocalPath"] = checkpoint_local_path + train_request["CheckpointConfig"] = checkpoint_config + LOGGER.info("Creating training-job with name: %s", job_name) LOGGER.debug("train request: %s", json.dumps(train_request, indent=4)) self.sagemaker_client.create_training_job(**train_request) @@ -1468,10 +1492,15 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method print() # Customers are not billed for hardware provisioning, so billable time is less than # total time - billable_time = ( - description["TrainingEndTime"] - description["TrainingStartTime"] - ) * instance_count - print("Billable seconds:", int(billable_time.total_seconds()) + 1) + training_time = 
description.get("TrainingTimeInSeconds") + billable_time = description.get("BillableTimeInSeconds") + if training_time is not None: + print("Training seconds:", training_time * instance_count) + if billable_time is not None: + print("Billable seconds:", billable_time * instance_count) + if description.get("EnableManagedSpotTraining"): + saving = (1 - float(billable_time) / training_time) * 100 + print("Managed Spot Training savings: {:.1f}%".format(saving)) def container_def(image, model_data_url=None, env=None): diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 766729c341..1b075f0e43 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -227,6 +227,69 @@ def test_framework_all_init_args(sagemaker_session): } +def test_framework_with_spot_and_checkpoints(sagemaker_session): + f = DummyFramework( + "my_script.py", + role="DummyRole", + train_instance_count=3, + train_instance_type="ml.m4.xlarge", + sagemaker_session=sagemaker_session, + train_volume_size=123, + train_volume_kms_key="volumekms", + train_max_run=456, + input_mode="inputmode", + output_path="outputpath", + output_kms_key="outputkms", + base_job_name="basejobname", + tags=[{"foo": "bar"}], + subnets=["123", "456"], + security_group_ids=["789", "012"], + metric_definitions=[{"Name": "validation-rmse", "Regex": "validation-rmse=(\\d+)"}], + encrypt_inter_container_traffic=True, + train_use_spot_instances=True, + train_max_wait=500, + checkpoint_s3_uri="s3://mybucket/checkpoints/", + checkpoint_local_path="/tmp/checkpoints", + ) + _TrainingJob.start_new(f, "s3://mydata") + sagemaker_session.train.assert_called_once() + _, args = sagemaker_session.train.call_args + assert args == { + "input_mode": "inputmode", + "tags": [{"foo": "bar"}], + "hyperparameters": {}, + "image": "fakeimage", + "input_config": [ + { + "ChannelName": "training", + "DataSource": { + "S3DataSource": { + "S3DataType": "S3Prefix", + "S3DataDistributionType": "FullyReplicated", + 
"S3Uri": "s3://mydata", + } + }, + } + ], + "output_config": {"KmsKeyId": "outputkms", "S3OutputPath": "outputpath"}, + "vpc_config": {"Subnets": ["123", "456"], "SecurityGroupIds": ["789", "012"]}, + "stop_condition": {"MaxRuntimeInSeconds": 456, "MaxWaitTimeInSeconds": 500}, + "role": sagemaker_session.expand_role(), + "job_name": None, + "resource_config": { + "VolumeSizeInGB": 123, + "InstanceCount": 3, + "VolumeKmsKeyId": "volumekms", + "InstanceType": "ml.m4.xlarge", + }, + "metric_definitions": [{"Name": "validation-rmse", "Regex": "validation-rmse=(\\d+)"}], + "encrypt_inter_container_traffic": True, + "train_use_spot_instances": True, + "checkpoint_s3_uri": "s3://mybucket/checkpoints/", + "checkpoint_local_path": "/tmp/checkpoints", + } + + def test_framework_init_s3_entry_point_invalid(sagemaker_session): with pytest.raises(ValueError) as error: DummyFramework( diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py index 954c9e39f8..0eacd160f0 100644 --- a/tests/unit/test_job.py +++ b/tests/unit/test_job.py @@ -563,10 +563,22 @@ def test_prepare_resource_config_with_volume_kms(): def test_prepare_stop_condition(): max_run = 1 + max_wait = 2 - stop_condition = _Job._prepare_stop_condition(max_run) + stop_condition = _Job._prepare_stop_condition(max_run, max_wait) assert stop_condition["MaxRuntimeInSeconds"] == max_run + assert stop_condition["MaxWaitTimeInSeconds"] == max_wait + + +def test_prepare_stop_condition_no_wait(): + max_run = 1 + max_wait = None + + stop_condition = _Job._prepare_stop_condition(max_run, max_wait) + + assert stop_condition["MaxRuntimeInSeconds"] == max_run + assert "MaxWaitTimeInSeconds" not in stop_condition def test_name(sagemaker_session): diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 2fe4c416fb..be9e1b4ea1 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -651,6 +651,9 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session): tags=TAGS, 
metric_definitions=METRIC_DEFINITONS, encrypt_inter_container_traffic=True, + train_use_spot_instances=True, + checkpoint_s3_uri="s3://mybucket/checkpoints/", + checkpoint_local_path="/tmp/checkpoints", ) _, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0] @@ -660,6 +663,9 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session): assert actual_train_args["Tags"] == TAGS assert actual_train_args["AlgorithmSpecification"]["MetricDefinitions"] == METRIC_DEFINITONS assert actual_train_args["EnableInterContainerTrafficEncryption"] is True + assert actual_train_args["EnableManagedSpotTraining"] is True + assert actual_train_args["CheckpointConfig"]["S3Uri"] == "s3://mybucket/checkpoints/" + assert actual_train_args["CheckpointConfig"]["LocalPath"] == "/tmp/checkpoints" def test_transform_pack_to_request(sagemaker_session):