
fix: hyperparameter tuning with spot instances and checkpoints #1015


Merged: 5 commits, Sep 4, 2019
26 changes: 25 additions & 1 deletion src/sagemaker/session.py
@@ -425,7 +425,7 @@ def compile_model(
LOGGER.info("Creating compilation-job with name: %s", job_name)
self.sagemaker_client.create_compilation_job(**compilation_job_request)

def tune(
def tune( # noqa: C901
self,
job_name,
strategy,
@@ -450,6 +450,9 @@ def tune(
early_stopping_type="Off",
encrypt_inter_container_traffic=False,
vpc_config=None,
train_use_spot_instances=False,
Contributor:
Should we allow the user to provide a train_max_wait similar to the Estimator call?

Contributor Author:
We don't need to pass train_max_wait to HyperparameterTuner; it is already carried in the stop_condition built from the estimator attached to the HyperparameterTuner instance:

estimator.train_max_run, estimator.train_max_wait

This is called from here:

config = _Job._load_config(inputs, tuner.estimator)
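
For context, a minimal sketch of the stopping condition that ends up in the request, assuming the estimator's train_max_run and train_max_wait are folded in the way described above; the concrete values are placeholders, not taken from this PR:

```python
# Hypothetical illustration: with managed spot training, the stop condition derived
# from the estimator carries both limits, and tune() forwards it under
# TrainingJobDefinition["StoppingCondition"].
stop_condition = {
    "MaxRuntimeInSeconds": 3600,   # from estimator.train_max_run
    "MaxWaitTimeInSeconds": 7200,  # from estimator.train_max_wait; must not be less than MaxRuntimeInSeconds
}
```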

checkpoint_s3_uri=None,
checkpoint_local_path=None,
):
"""Create an Amazon SageMaker hyperparameter tuning job

@@ -512,6 +515,18 @@ def tune(
The key in vpc_config is 'Subnets'.
* security_group_ids (list[str]): List of security group ids.
The key in vpc_config is 'SecurityGroupIds'.
train_use_spot_instances (bool): whether to use spot instances for training.
checkpoint_s3_uri (str): The S3 URI in which to persist checkpoints
that the algorithm persists (if any) during training. (default:
``None``).
checkpoint_local_path (str): The local path that the algorithm
writes its checkpoints to. SageMaker will persist all files
under this path to `checkpoint_s3_uri` continually during
training. On job startup the reverse happens - data from the
s3 location is downloaded to this path before the algorithm is
started. If the path is unset then SageMaker assumes the
checkpoints will be provided under `/opt/ml/checkpoints/`.
(default: ``None``).

"""
tune_request = {
@@ -569,6 +584,15 @@ def tune(
if encrypt_inter_container_traffic:
tune_request["TrainingJobDefinition"]["EnableInterContainerTrafficEncryption"] = True

if train_use_spot_instances:
tune_request["TrainingJobDefinition"]["EnableManagedSpotTraining"] = True

if checkpoint_s3_uri:
checkpoint_config = {"S3Uri": checkpoint_s3_uri}
if checkpoint_local_path:
checkpoint_config["LocalPath"] = checkpoint_local_path
tune_request["TrainingJobDefinition"]["CheckpointConfig"] = checkpoint_config

LOGGER.info("Creating hyperparameter tuning job with name: %s", job_name)
LOGGER.debug("tune request: %s", json.dumps(tune_request, indent=4))
self.sagemaker_client.create_hyper_parameter_tuning_job(**tune_request)
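As a reading aid, a hedged sketch of the TrainingJobDefinition fragment the new branches produce; keys other than the two added here are elided, and the bucket and local path are the placeholder values used in the unit test below:

```python
# With train_use_spot_instances=True, checkpoint_s3_uri="s3://mybucket/checkpoints/",
# and checkpoint_local_path="/tmp/checkpoints", the request grows roughly like this:
training_job_definition = {
    # ... AlgorithmSpecification, RoleArn, InputDataConfig, StoppingCondition, etc. ...
    "EnableManagedSpotTraining": True,
    "CheckpointConfig": {
        "S3Uri": "s3://mybucket/checkpoints/",
        "LocalPath": "/tmp/checkpoints",  # key omitted when checkpoint_local_path is None
    },
}
```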
4 changes: 4 additions & 0 deletions src/sagemaker/tuner.py
@@ -890,6 +890,10 @@ def start_new(cls, tuner, inputs):
"encrypt_inter_container_traffic"
] = tuner.estimator.encrypt_inter_container_traffic

tuner_args["train_use_spot_instances"] = tuner.estimator.train_use_spot_instances
tuner_args["checkpoint_s3_uri"] = tuner.estimator.checkpoint_s3_uri
tuner_args["checkpoint_local_path"] = tuner.estimator.checkpoint_local_path

tuner.estimator.sagemaker_session.tune(**tuner_args)

return cls(tuner.sagemaker_session, tuner._current_job_name)
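Putting the two files together, a hedged end-to-end sketch of the usage this change enables; the image URI, role ARN, bucket, metric regex, and hyperparameter range are illustrative assumptions only:

```python
from sagemaker.estimator import Estimator
from sagemaker.tuner import ContinuousParameter, HyperparameterTuner

# Estimator configured for managed spot training with checkpointing (v1.x-style arguments).
estimator = Estimator(
    image_name="dummy-image-1",
    role="arn:aws:iam::111111111111:role/SageMakerRole",
    train_instance_count=1,
    train_instance_type="ml.m5.xlarge",
    train_use_spot_instances=True,                    # becomes EnableManagedSpotTraining
    train_max_run=3600,
    train_max_wait=7200,                              # carried via the stop condition
    checkpoint_s3_uri="s3://mybucket/checkpoints/",   # becomes CheckpointConfig.S3Uri
    checkpoint_local_path="/tmp/checkpoints",         # becomes CheckpointConfig.LocalPath
)

tuner = HyperparameterTuner(
    estimator=estimator,
    objective_metric_name="val-score",
    hyperparameter_ranges={"learning_rate": ContinuousParameter(0.01, 0.2)},
    metric_definitions=[{"Name": "val-score", "Regex": "val-score=([0-9\\.]+)"}],
    max_jobs=100,
    max_parallel_jobs=5,
)

# start_new() now copies the spot/checkpoint settings from the estimator into the tune request.
tuner.fit("s3://mybucket/training-data")
```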
45 changes: 45 additions & 0 deletions tests/unit/test_session.py
@@ -565,6 +565,51 @@ def assert_create_tuning_job_request(**kwrags):
)


def test_tune_with_spot_and_checkpoints(sagemaker_session):
def assert_create_tuning_job_request(**kwrags):
assert (
kwrags["HyperParameterTuningJobConfig"]
== SAMPLE_TUNING_JOB_REQUEST["HyperParameterTuningJobConfig"]
)
assert kwrags["HyperParameterTuningJobName"] == "dummy-tuning-1"
assert kwrags["TrainingJobDefinition"]["EnableManagedSpotTraining"] is True
assert (
kwrags["TrainingJobDefinition"]["CheckpointConfig"]["S3Uri"]
== "s3://mybucket/checkpoints/"
)
assert (
kwrags["TrainingJobDefinition"]["CheckpointConfig"]["LocalPath"] == "/tmp/checkpoints"
)
assert kwrags.get("WarmStartConfig", None) is None

sagemaker_session.sagemaker_client.create_hyper_parameter_tuning_job.side_effect = (
assert_create_tuning_job_request
)
sagemaker_session.tune(
job_name="dummy-tuning-1",
strategy="Bayesian",
objective_type="Maximize",
objective_metric_name="val-score",
max_jobs=100,
max_parallel_jobs=5,
parameter_ranges=SAMPLE_PARAM_RANGES,
static_hyperparameters=STATIC_HPs,
image="dummy-image-1",
input_mode="File",
metric_definitions=SAMPLE_METRIC_DEF,
role=EXPANDED_ROLE,
input_config=SAMPLE_INPUT,
output_config=SAMPLE_OUTPUT,
resource_config=RESOURCE_CONFIG,
stop_condition=SAMPLE_STOPPING_CONDITION,
tags=None,
warm_start_config=None,
train_use_spot_instances=True,
checkpoint_s3_uri="s3://mybucket/checkpoints/",
checkpoint_local_path="/tmp/checkpoints",
)


def test_stop_tuning_job(sagemaker_session):
sms = sagemaker_session
sms.sagemaker_client.stop_hyper_parameter_tuning_job = Mock(
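The new test above relies on a standard unittest.mock idiom: the assertion function is attached as the mocked client method's side_effect, so it runs against the exact keyword arguments tune() sends. A minimal standalone sketch of that pattern, independent of the SageMaker fixtures:

```python
from unittest.mock import Mock

def assert_request(**kwargs):
    # Invoked when the mocked client method is called; a failed assert fails the test.
    assert kwargs["TrainingJobDefinition"]["EnableManagedSpotTraining"] is True

client = Mock()
client.create_hyper_parameter_tuning_job.side_effect = assert_request
client.create_hyper_parameter_tuning_job(
    TrainingJobDefinition={"EnableManagedSpotTraining": True}
)  # passes; a request missing the flag would raise KeyError or AssertionError
```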