Skip to content

Spot training #990

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def __init__(
model_channel_name="model",
metric_definitions=None,
encrypt_inter_container_traffic=False,
train_use_spot_instances=False,
train_max_wait=None,
checkpoint_s3_uri=None,
checkpoint_local_path=None,
):
"""Initialize an ``EstimatorBase`` instance.

Expand Down Expand Up @@ -157,6 +161,28 @@ def __init__(
encrypt_inter_container_traffic (bool): Specifies whether traffic
between training containers is encrypted for the training job
(default: ``False``).
train_use_spot_instances (bool): Specifies whether to use SageMaker
Managed Spot instances for training. If enabled then the
`train_max_wait` arg should also be set.

More information:
https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html
(default: ``False``).
train_max_wait (int): Timeout in seconds waiting for spot training
instances (default: None). After this amount of time Amazon
SageMaker will stop waiting for Spot instances to become
available (default: ``None``).
checkpoint_s3_uri (str): The S3 URI in which to persist checkpoints
that the algorithm persists (if any) during training. (default:
``None``).
checkpoint_local_path (str): The local path that the algorithm
writes its checkpoints to. SageMaker will persist all files
under this path to `checkpoint_s3_uri` continually during
training. On job startup the reverse happens - data from the
s3 location is downloaded to this path before the algorithm is
started. If the path is unset then SageMaker assumes the
checkpoints will be provided under `/opt/ml/checkpoints/`.
(default: ``None``).
"""
self.role = role
self.train_instance_count = train_instance_count
Expand Down Expand Up @@ -199,6 +225,10 @@ def __init__(
self.security_group_ids = security_group_ids

self.encrypt_inter_container_traffic = encrypt_inter_container_traffic
self.train_use_spot_instances = train_use_spot_instances
self.train_max_wait = train_max_wait
self.checkpoint_s3_uri = checkpoint_s3_uri
self.checkpoint_local_path = checkpoint_local_path

@abstractmethod
def train_image(self):
Expand Down Expand Up @@ -795,10 +825,35 @@ def start_new(cls, estimator, inputs):
else:
train_args["image"] = estimator.train_image()

cls._add_spot_checkpoint_args(local_mode, estimator, train_args)

estimator.sagemaker_session.train(**train_args)

return cls(estimator.sagemaker_session, estimator._current_job_name)

@classmethod
def _add_spot_checkpoint_args(cls, local_mode, estimator, train_args):
"""
Args:
local_mode:
estimator:
train_args:
"""
if estimator.train_use_spot_instances:
if local_mode:
raise ValueError("Spot training is not supported in local mode.")
train_args["train_use_spot_instances"] = True

if estimator.checkpoint_s3_uri:
if local_mode:
raise ValueError("Setting checkpoint_s3_uri is not supported in local mode.")
train_args["checkpoint_s3_uri"] = estimator.checkpoint_s3_uri

if estimator.checkpoint_local_path:
if local_mode:
raise ValueError("Setting checkpoint_local_path is not supported in local mode.")
train_args["checkpoint_local_path"] = estimator.checkpoint_local_path

@classmethod
def _is_local_channel(cls, input_uri):
"""
Expand Down Expand Up @@ -845,6 +900,10 @@ def __init__(
model_channel_name="model",
metric_definitions=None,
encrypt_inter_container_traffic=False,
train_use_spot_instances=False,
train_max_wait=None,
checkpoint_s3_uri=None,
checkpoint_local_path=None,
):
"""Initialize an ``Estimator`` instance.

Expand Down Expand Up @@ -926,6 +985,28 @@ def __init__(
encrypt_inter_container_traffic (bool): Specifies whether traffic
between training containers is encrypted for the training job
(default: ``False``).
train_use_spot_instances (bool): Specifies whether to use SageMaker
Managed Spot instances for training. If enabled then the
`train_max_wait` arg should also be set.

More information:
https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html
(default: ``False``).
train_max_wait (int): Timeout in seconds waiting for spot training
instances (default: None). After this amount of time Amazon
SageMaker will stop waiting for Spot instances to become
available (default: ``None``).
checkpoint_s3_uri (str): The S3 URI in which to persist checkpoints
that the algorithm persists (if any) during training. (default:
``None``).
checkpoint_local_path (str): The local path that the algorithm
writes its checkpoints to. SageMaker will persist all files
under this path to `checkpoint_s3_uri` continually during
training. On job startup the reverse happens - data from the
s3 location is downloaded to this path before the algorithm is
started. If the path is unset then SageMaker assumes the
checkpoints will be provided under `/opt/ml/checkpoints/`.
(default: ``None``).
"""
self.image_name = image_name
self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {}
Expand All @@ -948,6 +1029,10 @@ def __init__(
model_channel_name=model_channel_name,
metric_definitions=metric_definitions,
encrypt_inter_container_traffic=encrypt_inter_container_traffic,
train_use_spot_instances=train_use_spot_instances,
train_max_wait=train_max_wait,
checkpoint_s3_uri=checkpoint_s3_uri,
checkpoint_local_path=checkpoint_local_path,
)

def train_image(self):
Expand Down
9 changes: 7 additions & 2 deletions src/sagemaker/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
estimator.train_volume_size,
estimator.train_volume_kms_key,
)
stop_condition = _Job._prepare_stop_condition(estimator.train_max_run)
stop_condition = _Job._prepare_stop_condition(
estimator.train_max_run, estimator.train_max_wait
)
vpc_config = estimator.get_vpc_config()

model_channel = _Job._prepare_channel(
Expand Down Expand Up @@ -312,11 +314,14 @@ def _prepare_resource_config(instance_count, instance_type, volume_size, train_v
return resource_config

@staticmethod
def _prepare_stop_condition(max_run):
def _prepare_stop_condition(max_run, max_wait):
"""
Args:
max_run:
max_wait:
"""
if max_wait:
return {"MaxRuntimeInSeconds": max_run, "MaxWaitTimeInSeconds": max_wait}
return {"MaxRuntimeInSeconds": max_run}

@property
Expand Down
37 changes: 33 additions & 4 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,9 @@ def train( # noqa: C901
image=None,
algorithm_arn=None,
encrypt_inter_container_traffic=False,
train_use_spot_instances=False,
checkpoint_s3_uri=None,
checkpoint_local_path=None,
):
"""Create an Amazon SageMaker training job.

Expand Down Expand Up @@ -307,6 +310,18 @@ def train( # noqa: C901
algorithm_arn (str): Algorithm Arn from Marketplace.
encrypt_inter_container_traffic (bool): Specifies whether traffic between training
containers is encrypted for the training job (default: ``False``).
train_use_spot_instances (bool): whether to use spot instances for training.
checkpoint_s3_uri (str): The S3 URI in which to persist checkpoints
that the algorithm persists (if any) during training. (default:
``None``).
checkpoint_local_path (str): The local path that the algorithm
writes its checkpoints to. SageMaker will persist all files
under this path to `checkpoint_s3_uri` continually during
training. On job startup the reverse happens - data from the
s3 location is downloaded to this path before the algorithm is
started. If the path is unset then SageMaker assumes the
checkpoints will be provided under `/opt/ml/checkpoints/`.
(default: ``None``).

Returns:
str: ARN of the training job, if it is created.
Expand Down Expand Up @@ -357,6 +372,15 @@ def train( # noqa: C901
if encrypt_inter_container_traffic:
train_request["EnableInterContainerTrafficEncryption"] = encrypt_inter_container_traffic

if train_use_spot_instances:
train_request["EnableManagedSpotTraining"] = train_use_spot_instances

if checkpoint_s3_uri:
checkpoint_config = {"S3Uri": checkpoint_s3_uri}
if checkpoint_local_path:
checkpoint_config["LocalPath"] = checkpoint_local_path
train_request["CheckpointConfig"] = checkpoint_config

LOGGER.info("Creating training-job with name: %s", job_name)
LOGGER.debug("train request: %s", json.dumps(train_request, indent=4))
self.sagemaker_client.create_training_job(**train_request)
Expand Down Expand Up @@ -1468,10 +1492,15 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
print()
# Customers are not billed for hardware provisioning, so billable time is less than
# total time
billable_time = (
description["TrainingEndTime"] - description["TrainingStartTime"]
) * instance_count
print("Billable seconds:", int(billable_time.total_seconds()) + 1)
training_time = description.get("TrainingTimeInSeconds")
billable_time = description.get("BillableTimeInSeconds")
if training_time is not None:
print("Training seconds:", training_time * instance_count)
if billable_time is not None:
print("Billable seconds:", billable_time * instance_count)
if description.get("EnableManagedSpotTraining"):
saving = (1 - float(billable_time) / training_time) * 100
print("Managed Spot Training savings: {:.1f}%".format(saving))


def container_def(image, model_data_url=None, env=None):
Expand Down
63 changes: 63 additions & 0 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,69 @@ def test_framework_all_init_args(sagemaker_session):
}


def test_framework_with_spot_and_checkpoints(sagemaker_session):
    """Spot and checkpoint settings on the estimator must flow through to train()."""
    metric_defs = [{"Name": "validation-rmse", "Regex": "validation-rmse=(\\d+)"}]
    framework = DummyFramework(
        "my_script.py",
        role="DummyRole",
        train_instance_count=3,
        train_instance_type="ml.m4.xlarge",
        sagemaker_session=sagemaker_session,
        train_volume_size=123,
        train_volume_kms_key="volumekms",
        train_max_run=456,
        input_mode="inputmode",
        output_path="outputpath",
        output_kms_key="outputkms",
        base_job_name="basejobname",
        tags=[{"foo": "bar"}],
        subnets=["123", "456"],
        security_group_ids=["789", "012"],
        metric_definitions=metric_defs,
        encrypt_inter_container_traffic=True,
        train_use_spot_instances=True,
        train_max_wait=500,
        checkpoint_s3_uri="s3://mybucket/checkpoints/",
        checkpoint_local_path="/tmp/checkpoints",
    )

    _TrainingJob.start_new(framework, "s3://mydata")

    sagemaker_session.train.assert_called_once()
    _, actual_args = sagemaker_session.train.call_args

    expected_channel = {
        "ChannelName": "training",
        "DataSource": {
            "S3DataSource": {
                "S3DataType": "S3Prefix",
                "S3DataDistributionType": "FullyReplicated",
                "S3Uri": "s3://mydata",
            }
        },
    }
    assert actual_args == {
        "input_mode": "inputmode",
        "tags": [{"foo": "bar"}],
        "hyperparameters": {},
        "image": "fakeimage",
        "input_config": [expected_channel],
        "output_config": {"KmsKeyId": "outputkms", "S3OutputPath": "outputpath"},
        "vpc_config": {"Subnets": ["123", "456"], "SecurityGroupIds": ["789", "012"]},
        "stop_condition": {"MaxRuntimeInSeconds": 456, "MaxWaitTimeInSeconds": 500},
        "role": sagemaker_session.expand_role(),
        "job_name": None,
        "resource_config": {
            "VolumeSizeInGB": 123,
            "InstanceCount": 3,
            "VolumeKmsKeyId": "volumekms",
            "InstanceType": "ml.m4.xlarge",
        },
        "metric_definitions": metric_defs,
        "encrypt_inter_container_traffic": True,
        "train_use_spot_instances": True,
        "checkpoint_s3_uri": "s3://mybucket/checkpoints/",
        "checkpoint_local_path": "/tmp/checkpoints",
    }


def test_framework_init_s3_entry_point_invalid(sagemaker_session):
with pytest.raises(ValueError) as error:
DummyFramework(
Expand Down
14 changes: 13 additions & 1 deletion tests/unit/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,10 +563,22 @@ def test_prepare_resource_config_with_volume_kms():

def test_prepare_stop_condition():
    """With a spot max-wait both stopping-condition keys are present."""
    stop_condition = _Job._prepare_stop_condition(1, 2)

    assert stop_condition["MaxRuntimeInSeconds"] == 1
    assert stop_condition["MaxWaitTimeInSeconds"] == 2


def test_prepare_stop_condition_no_wait():
    """Without a max-wait the spot key must be omitted from the condition."""
    stop_condition = _Job._prepare_stop_condition(1, None)

    assert stop_condition["MaxRuntimeInSeconds"] == 1
    assert "MaxWaitTimeInSeconds" not in stop_condition


def test_name(sagemaker_session):
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,9 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
tags=TAGS,
metric_definitions=METRIC_DEFINITONS,
encrypt_inter_container_traffic=True,
train_use_spot_instances=True,
checkpoint_s3_uri="s3://mybucket/checkpoints/",
checkpoint_local_path="/tmp/checkpoints",
)

_, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0]
Expand All @@ -660,6 +663,9 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
assert actual_train_args["Tags"] == TAGS
assert actual_train_args["AlgorithmSpecification"]["MetricDefinitions"] == METRIC_DEFINITONS
assert actual_train_args["EnableInterContainerTrafficEncryption"] is True
assert actual_train_args["EnableManagedSpotTraining"] is True
assert actual_train_args["CheckpointConfig"]["S3Uri"] == "s3://mybucket/checkpoints/"
assert actual_train_args["CheckpointConfig"]["LocalPath"] == "/tmp/checkpoints"


def test_transform_pack_to_request(sagemaker_session):
Expand Down