Skip to content

feat: Support KeepAlivePeriodInSeconds for Training APIs #3371

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Sep 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def __init__(
role: str,
instance_count: Optional[Union[int, PipelineVariable]] = None,
instance_type: Optional[Union[str, PipelineVariable]] = None,
keep_alive_period_in_seconds: Optional[Union[int, PipelineVariable]] = None,
volume_size: Union[int, PipelineVariable] = 30,
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
max_run: Union[int, PipelineVariable] = 24 * 60 * 60,
Expand Down Expand Up @@ -167,6 +168,9 @@ def __init__(
instance_type (str or PipelineVariable): Type of EC2 instance to use for training,
for example, ``'ml.c4.xlarge'``. Required if instance_groups is
not set.
keep_alive_period_in_seconds (int): The duration of time in seconds
to retain configured resources in a warm pool for subsequent
training jobs (default: None).
volume_size (int or PipelineVariable): Size in GB of the storage volume to use for
storing input and output data during training (default: 30).

Expand Down Expand Up @@ -510,6 +514,7 @@ def __init__(
self.role = role
self.instance_count = instance_count
self.instance_type = instance_type
self.keep_alive_period_in_seconds = keep_alive_period_in_seconds
self.instance_groups = instance_groups
self.volume_size = volume_size
self.volume_kms_key = volume_kms_key
Expand Down Expand Up @@ -1578,6 +1583,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
if "EnableNetworkIsolation" in job_details:
init_params["enable_network_isolation"] = job_details["EnableNetworkIsolation"]

if "KeepAlivePeriodInSeconds" in job_details["ResourceConfig"]:
init_params["keep_alive_period_in_seconds"] = job_details["ResourceConfig"][
"keepAlivePeriodInSeconds"
]

has_hps = "HyperParameters" in job_details
init_params["hyperparameters"] = job_details["HyperParameters"] if has_hps else {}

Expand Down Expand Up @@ -2126,7 +2136,9 @@ def _is_local_channel(cls, input_uri):
return isinstance(input_uri, string_types) and input_uri.startswith("file://")

@classmethod
def update(cls, estimator, profiler_rule_configs=None, profiler_config=None):
def update(
cls, estimator, profiler_rule_configs=None, profiler_config=None, resource_config=None
):
"""Update a running Amazon SageMaker training job.

Args:
Expand All @@ -2135,18 +2147,23 @@ def update(cls, estimator, profiler_rule_configs=None, profiler_config=None):
updated in the training job. (default: None).
profiler_config (dict): Configuration for how profiling information is emitted with
SageMaker Debugger. (default: None).
resource_config (dict): Configuration of the resources for the training job. You can
update the keep-alive period if the warm pool status is `Available`. No other fields
can be updated. (default: None).

Returns:
sagemaker.estimator._TrainingJob: Constructed object that captures
all information about the updated training job.
"""
update_args = cls._get_update_args(estimator, profiler_rule_configs, profiler_config)
update_args = cls._get_update_args(
estimator, profiler_rule_configs, profiler_config, resource_config
)
estimator.sagemaker_session.update_training_job(**update_args)

return estimator.latest_training_job

@classmethod
def _get_update_args(cls, estimator, profiler_rule_configs, profiler_config):
def _get_update_args(cls, estimator, profiler_rule_configs, profiler_config, resource_config):
"""Constructs a dict of arguments for updating an Amazon SageMaker training job.

Args:
Expand All @@ -2156,13 +2173,17 @@ def _get_update_args(cls, estimator, profiler_rule_configs, profiler_config):
updated in the training job. (default: None).
profiler_config (dict): Configuration for how profiling information is emitted with
SageMaker Debugger. (default: None).
resource_config (dict): Configuration of the resources for the training job. You can
update the keep-alive period if the warm pool status is `Available`. No other fields
can be updated. (default: None).

Returns:
Dict: dict for `sagemaker.session.Session.update_training_job` method
"""
update_args = {"job_name": estimator.latest_training_job.name}
update_args.update(build_dict("profiler_rule_configs", profiler_rule_configs))
update_args.update(build_dict("profiler_config", profiler_config))
update_args.update(build_dict("resource_config", resource_config))

return update_args

Expand Down Expand Up @@ -2218,6 +2239,7 @@ def __init__(
role: str,
instance_count: Optional[Union[int, PipelineVariable]] = None,
instance_type: Optional[Union[str, PipelineVariable]] = None,
keep_alive_period_in_seconds: Optional[Union[int, PipelineVariable]] = None,
volume_size: Union[int, PipelineVariable] = 30,
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
max_run: Union[int, PipelineVariable] = 24 * 60 * 60,
Expand Down Expand Up @@ -2270,6 +2292,9 @@ def __init__(
instance_type (str or PipelineVariable): Type of EC2 instance to use for training,
for example, ``'ml.c4.xlarge'``. Required if instance_groups is
not set.
keep_alive_period_in_seconds (int): The duration of time in seconds
to retain configured resources in a warm pool for subsequent
training jobs (default: None).
volume_size (int or PipelineVariable): Size in GB of the storage volume to use for
storing input and output data during training (default: 30).

Expand Down Expand Up @@ -2591,6 +2616,7 @@ def __init__(
role,
instance_count,
instance_type,
keep_alive_period_in_seconds,
volume_size,
volume_kms_key,
max_run,
Expand Down
10 changes: 9 additions & 1 deletion src/sagemaker/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
estimator.instance_groups,
estimator.volume_size,
estimator.volume_kms_key,
estimator.keep_alive_period_in_seconds,
)
stop_condition = _Job._prepare_stop_condition(estimator.max_run, estimator.max_wait)
vpc_config = estimator.get_vpc_config()
Expand Down Expand Up @@ -281,14 +282,21 @@ def _prepare_output_config(s3_path, kms_key_id):

@staticmethod
def _prepare_resource_config(
instance_count, instance_type, instance_groups, volume_size, volume_kms_key
instance_count,
instance_type,
instance_groups,
volume_size,
volume_kms_key,
keep_alive_period_in_seconds,
):
"""Placeholder docstring"""
resource_config = {
"VolumeSizeInGB": volume_size,
}
if volume_kms_key is not None:
resource_config["VolumeKmsKeyId"] = volume_kms_key
if keep_alive_period_in_seconds is not None:
resource_config["KeepAlivePeriodInSeconds"] = keep_alive_period_in_seconds
if instance_groups is not None:
if instance_count is not None or instance_type is not None:
raise ValueError(
Expand Down
14 changes: 13 additions & 1 deletion src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,7 @@ def update_training_job(
job_name,
profiler_rule_configs=None,
profiler_config=None,
resource_config=None,
):
"""Calls the UpdateTrainingJob API for the given job name and returns the response.

Expand All @@ -829,11 +830,15 @@ def update_training_job(
profiler_rule_configs (list): List of profiler rule configurations. (default: ``None``).
profiler_config(dict): Configuration for how profiling information is emitted with
SageMaker Profiler. (default: ``None``).
resource_config (dict): Configuration of the resources for the training job. You can
update the keep-alive period if the warm pool status is `Available`. No other fields
can be updated. (default: ``None``).
"""
update_training_job_request = self._get_update_training_job_request(
job_name=job_name,
profiler_rule_configs=profiler_rule_configs,
profiler_config=profiler_config,
resource_config=resource_config,
)
LOGGER.info("Updating training job with name %s", job_name)
LOGGER.debug("Update request: %s", json.dumps(update_training_job_request, indent=4))
Expand All @@ -844,14 +849,18 @@ def _get_update_training_job_request(
job_name,
profiler_rule_configs=None,
profiler_config=None,
resource_config=None,
):
"""Constructs a request compatible for updateing an Amazon SageMaker training job.
"""Constructs a request compatible for updating an Amazon SageMaker training job.

Args:
job_name (str): Name of the training job being updated.
profiler_rule_configs (list): List of profiler rule configurations. (default: ``None``).
profiler_config(dict): Configuration for how profiling information is emitted with
SageMaker Profiler. (default: ``None``).
resource_config (dict): Configuration of the resources for the training job. You can
update the keep-alive period if the warm pool status is `Available`. No other fields
can be updated. (default: ``None``).

Returns:
Dict: an update training request dict
Expand All @@ -866,6 +875,9 @@ def _get_update_training_job_request(
if profiler_config is not None:
update_training_job_request["ProfilerConfig"] = profiler_config

if resource_config is not None:
update_training_job_request["ResourceConfig"] = resource_config

return update_training_job_request

def process(
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
BUCKET_NAME = "mybucket"
INSTANCE_COUNT = 1
INSTANCE_TYPE = "c4.4xlarge"
KEEP_ALIVE_PERIOD_IN_SECONDS = 1800
ACCELERATOR_TYPE = "ml.eia.medium"
ROLE = "DummyRole"
IMAGE_URI = "fakeimage"
Expand Down Expand Up @@ -351,6 +352,23 @@ def test_framework_with_heterogeneous_cluster(sagemaker_session):
}


def test_framework_with_keep_alive_period(sagemaker_session):
f = DummyFramework(
entry_point=SCRIPT_PATH,
role=ROLE,
sagemaker_session=sagemaker_session,
instance_groups=[
InstanceGroup("group1", "ml.c4.xlarge", 1),
InstanceGroup("group2", "ml.m4.xlarge", 2),
],
keep_alive_period_in_seconds=KEEP_ALIVE_PERIOD_IN_SECONDS,
)
f.fit("s3://mydata")
sagemaker_session.train.assert_called_once()
_, args = sagemaker_session.train.call_args
assert args["resource_config"]["KeepAlivePeriodInSeconds"] == KEEP_ALIVE_PERIOD_IN_SECONDS


def test_framework_with_debugger_and_built_in_rule(sagemaker_session):
debugger_built_in_rule_with_custom_args = Rule.sagemaker(
base_config=rule_configs.stalled_training_rule(),
Expand Down
22 changes: 20 additions & 2 deletions tests/unit/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
LOCAL_FILE_NAME = "file://local/file"
INSTANCE_COUNT = 1
INSTANCE_TYPE = "c4.4xlarge"
KEEP_ALIVE_PERIOD = 1800
INSTANCE_GROUP = InstanceGroup("group", "ml.c4.xlarge", 1)
VOLUME_SIZE = 1
MAX_RUNTIME = 1
Expand Down Expand Up @@ -599,7 +600,7 @@ def test_prepare_output_config_kms_key_none():

def test_prepare_resource_config():
resource_config = _Job._prepare_resource_config(
INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, None
INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, None, None
)

assert resource_config == {
Expand All @@ -609,9 +610,23 @@ def test_prepare_resource_config():
}


def test_prepare_resource_config_with_keep_alive_period():
resource_config = _Job._prepare_resource_config(
INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, VOLUME_KMS_KEY, KEEP_ALIVE_PERIOD
)

assert resource_config == {
"InstanceCount": INSTANCE_COUNT,
"InstanceType": INSTANCE_TYPE,
"VolumeSizeInGB": VOLUME_SIZE,
"VolumeKmsKeyId": VOLUME_KMS_KEY,
"KeepAlivePeriodInSeconds": KEEP_ALIVE_PERIOD,
}


def test_prepare_resource_config_with_volume_kms():
resource_config = _Job._prepare_resource_config(
INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, VOLUME_KMS_KEY
INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, VOLUME_KMS_KEY, None
)

assert resource_config == {
Expand All @@ -629,6 +644,7 @@ def test_prepare_resource_config_with_heterogeneous_cluster():
[InstanceGroup("group1", "ml.c4.xlarge", 1), InstanceGroup("group2", "ml.m4.xlarge", 2)],
VOLUME_SIZE,
None,
None,
)

assert resource_config == {
Expand All @@ -648,6 +664,7 @@ def test_prepare_resource_config_with_instance_groups_instance_type_instance_cou
[INSTANCE_GROUP],
VOLUME_SIZE,
None,
None,
)
assert "instance_count and instance_type cannot be set when instance_groups is set" in str(
error
Expand All @@ -662,6 +679,7 @@ def test_prepare_resource_config_with_instance_groups_instance_type_instance_cou
None,
VOLUME_SIZE,
None,
None,
)
assert "instance_count and instance_type must be set if instance_groups is not set" in str(
error
Expand Down