Skip to content

Commit c226952

Browse files
committed
Support KeepAlivePeriodInSeconds for Training APIs
1 parent f616344 commit c226952

File tree

5 files changed

+82
-6
lines changed

5 files changed

+82
-6
lines changed

src/sagemaker/estimator.py

+27-3
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def __init__(
116116
role: str,
117117
instance_count: Optional[Union[int, PipelineVariable]] = None,
118118
instance_type: Optional[Union[str, PipelineVariable]] = None,
119+
keep_alive_period_in_seconds: Optional[Union[int, PipelineVariable]] = None,
119120
volume_size: Union[int, PipelineVariable] = 30,
120121
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
121122
max_run: Union[int, PipelineVariable] = 24 * 60 * 60,
@@ -167,6 +168,10 @@ def __init__(
167168
instance_type (str or PipelineVariable): Type of EC2 instance to use for training,
168169
for example, ``'ml.c4.xlarge'``. Required if instance_groups is
169170
not set.
171+
keep_alive_period_in_seconds (int): How long in seconds (default: None)
172+
will the resource including instances, volumes, ecr imges etc. used
173+
by this training job be kept alive for reuse for the next follow-up
174+
training job.
170175
volume_size (int or PipelineVariable): Size in GB of the storage volume to use for
171176
storing input and output data during training (default: 30).
172177
@@ -510,6 +515,7 @@ def __init__(
510515
self.role = role
511516
self.instance_count = instance_count
512517
self.instance_type = instance_type
518+
self.keep_alive_period_in_seconds = keep_alive_period_in_seconds
513519
self.instance_groups = instance_groups
514520
self.volume_size = volume_size
515521
self.volume_kms_key = volume_kms_key
@@ -1578,6 +1584,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
15781584
if "EnableNetworkIsolation" in job_details:
15791585
init_params["enable_network_isolation"] = job_details["EnableNetworkIsolation"]
15801586

1587+
if "KeepAlivePeriodInSeconds" in job_details["ResourceConfig"]:
1588+
init_params["keep_alive_period_in_seconds"] = job_details["ResourceConfig"][
1589+
"keepAlivePeriodInSeconds"
1590+
]
1591+
15811592
has_hps = "HyperParameters" in job_details
15821593
init_params["hyperparameters"] = job_details["HyperParameters"] if has_hps else {}
15831594

@@ -2126,7 +2137,9 @@ def _is_local_channel(cls, input_uri):
21262137
return isinstance(input_uri, string_types) and input_uri.startswith("file://")
21272138

21282139
@classmethod
2129-
def update(cls, estimator, profiler_rule_configs=None, profiler_config=None):
2140+
def update(
2141+
cls, estimator, profiler_rule_configs=None, profiler_config=None, resource_config=None
2142+
):
21302143
"""Update a running Amazon SageMaker training job.
21312144
21322145
Args:
@@ -2135,18 +2148,21 @@ def update(cls, estimator, profiler_rule_configs=None, profiler_config=None):
21352148
updated in the training job. (default: None).
21362149
profiler_config (dict): Configuration for how profiling information is emitted with
21372150
SageMaker Debugger. (default: None).
2151+
resource_config (dict): Configuration for resource of the training job. (default: None).
21382152
21392153
Returns:
21402154
sagemaker.estimator._TrainingJob: Constructed object that captures
21412155
all information about the updated training job.
21422156
"""
2143-
update_args = cls._get_update_args(estimator, profiler_rule_configs, profiler_config)
2157+
update_args = cls._get_update_args(
2158+
estimator, profiler_rule_configs, profiler_config, resource_config
2159+
)
21442160
estimator.sagemaker_session.update_training_job(**update_args)
21452161

21462162
return estimator.latest_training_job
21472163

21482164
@classmethod
2149-
def _get_update_args(cls, estimator, profiler_rule_configs, profiler_config):
2165+
def _get_update_args(cls, estimator, profiler_rule_configs, profiler_config, resource_config):
21502166
"""Constructs a dict of arguments for updating an Amazon SageMaker training job.
21512167
21522168
Args:
@@ -2156,13 +2172,15 @@ def _get_update_args(cls, estimator, profiler_rule_configs, profiler_config):
21562172
updated in the training job. (default: None).
21572173
profiler_config (dict): Configuration for how profiling information is emitted with
21582174
SageMaker Debugger. (default: None).
2175+
resource_config (dict): Configuration for resource of the training job. (default: None).
21592176
21602177
Returns:
21612178
Dict: dict for `sagemaker.session.Session.update_training_job` method
21622179
"""
21632180
update_args = {"job_name": estimator.latest_training_job.name}
21642181
update_args.update(build_dict("profiler_rule_configs", profiler_rule_configs))
21652182
update_args.update(build_dict("profiler_config", profiler_config))
2183+
update_args.update(build_dict("resource_config", resource_config))
21662184

21672185
return update_args
21682186

@@ -2218,6 +2236,7 @@ def __init__(
22182236
role: str,
22192237
instance_count: Optional[Union[int, PipelineVariable]] = None,
22202238
instance_type: Optional[Union[str, PipelineVariable]] = None,
2239+
keep_alive_period_in_seconds: Optional[Union[int, PipelineVariable]] = None,
22212240
volume_size: Union[int, PipelineVariable] = 30,
22222241
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
22232242
max_run: Union[int, PipelineVariable] = 24 * 60 * 60,
@@ -2270,6 +2289,10 @@ def __init__(
22702289
instance_type (str or PipelineVariable): Type of EC2 instance to use for training,
22712290
for example, ``'ml.c4.xlarge'``. Required if instance_groups is
22722291
not set.
2292+
keep_alive_period_in_seconds (int): How long in seconds (default: None)
2293+
will the resource including instances, volumes, ecr imges etc. used
2294+
by this training job be kept alive for reuse for the next follow-up
2295+
training job.
22732296
volume_size (int or PipelineVariable): Size in GB of the storage volume to use for
22742297
storing input and output data during training (default: 30).
22752298
@@ -2591,6 +2614,7 @@ def __init__(
25912614
role,
25922615
instance_count,
25932616
instance_type,
2617+
keep_alive_period_in_seconds,
25942618
volume_size,
25952619
volume_kms_key,
25962620
max_run,

src/sagemaker/job.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
7878
estimator.instance_groups,
7979
estimator.volume_size,
8080
estimator.volume_kms_key,
81+
estimator.keep_alive_period_in_seconds,
8182
)
8283
stop_condition = _Job._prepare_stop_condition(estimator.max_run, estimator.max_wait)
8384
vpc_config = estimator.get_vpc_config()
@@ -281,14 +282,21 @@ def _prepare_output_config(s3_path, kms_key_id):
281282

282283
@staticmethod
283284
def _prepare_resource_config(
284-
instance_count, instance_type, instance_groups, volume_size, volume_kms_key
285+
instance_count,
286+
instance_type,
287+
instance_groups,
288+
volume_size,
289+
volume_kms_key,
290+
keep_alive_period_in_seconds,
285291
):
286292
"""Placeholder docstring"""
287293
resource_config = {
288294
"VolumeSizeInGB": volume_size,
289295
}
290296
if volume_kms_key is not None:
291297
resource_config["VolumeKmsKeyId"] = volume_kms_key
298+
if keep_alive_period_in_seconds is not None:
299+
resource_config["KeepAlivePeriodInSeconds"] = keep_alive_period_in_seconds
292300
if instance_groups is not None:
293301
if instance_count is not None or instance_type is not None:
294302
raise ValueError(

src/sagemaker/session.py

+8
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,7 @@ def update_training_job(
821821
job_name,
822822
profiler_rule_configs=None,
823823
profiler_config=None,
824+
resource_config=None,
824825
):
825826
"""Calls the UpdateTrainingJob API for the given job name and returns the response.
826827
@@ -829,11 +830,13 @@ def update_training_job(
829830
profiler_rule_configs (list): List of profiler rule configurations. (default: ``None``).
830831
profiler_config(dict): Configuration for how profiling information is emitted with
831832
SageMaker Profiler. (default: ``None``).
833+
resource_config (dict): Configuration for resource of the training job. (default: None).
832834
"""
833835
update_training_job_request = self._get_update_training_job_request(
834836
job_name=job_name,
835837
profiler_rule_configs=profiler_rule_configs,
836838
profiler_config=profiler_config,
839+
resource_config=resource_config,
837840
)
838841
LOGGER.info("Updating training job with name %s", job_name)
839842
LOGGER.debug("Update request: %s", json.dumps(update_training_job_request, indent=4))
@@ -844,6 +847,7 @@ def _get_update_training_job_request(
844847
job_name,
845848
profiler_rule_configs=None,
846849
profiler_config=None,
850+
resource_config=None,
847851
):
848852
"""Constructs a request compatible for updateing an Amazon SageMaker training job.
849853
@@ -852,6 +856,7 @@ def _get_update_training_job_request(
852856
profiler_rule_configs (list): List of profiler rule configurations. (default: ``None``).
853857
profiler_config(dict): Configuration for how profiling information is emitted with
854858
SageMaker Profiler. (default: ``None``).
859+
resource_config (dict): Configuration for resource of the training job. (default: None).
855860
856861
Returns:
857862
Dict: an update training request dict
@@ -866,6 +871,9 @@ def _get_update_training_job_request(
866871
if profiler_config is not None:
867872
update_training_job_request["ProfilerConfig"] = profiler_config
868873

874+
if resource_config is not None:
875+
update_training_job_request["ResourceConfig"] = resource_config
876+
869877
return update_training_job_request
870878

871879
def process(

tests/unit/test_estimator.py

+18
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
BUCKET_NAME = "mybucket"
7070
INSTANCE_COUNT = 1
7171
INSTANCE_TYPE = "c4.4xlarge"
72+
KEEP_ALIVE_PERIOD_IN_SECONDS = 1800
7273
ACCELERATOR_TYPE = "ml.eia.medium"
7374
ROLE = "DummyRole"
7475
IMAGE_URI = "fakeimage"
@@ -351,6 +352,23 @@ def test_framework_with_heterogeneous_cluster(sagemaker_session):
351352
}
352353

353354

355+
def test_framework_with_keep_alive_period(sagemaker_session):
356+
f = DummyFramework(
357+
entry_point=SCRIPT_PATH,
358+
role=ROLE,
359+
sagemaker_session=sagemaker_session,
360+
instance_groups=[
361+
InstanceGroup("group1", "ml.c4.xlarge", 1),
362+
InstanceGroup("group2", "ml.m4.xlarge", 2),
363+
],
364+
keep_alive_period_in_seconds=KEEP_ALIVE_PERIOD_IN_SECONDS,
365+
)
366+
f.fit("s3://mydata")
367+
sagemaker_session.train.assert_called_once()
368+
_, args = sagemaker_session.train.call_args
369+
assert args["resource_config"]["KeepAlivePeriodInSeconds"] == KEEP_ALIVE_PERIOD_IN_SECONDS
370+
371+
354372
def test_framework_with_debugger_and_built_in_rule(sagemaker_session):
355373
debugger_built_in_rule_with_custom_args = Rule.sagemaker(
356374
base_config=rule_configs.stalled_training_rule(),

tests/unit/test_job.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
LOCAL_FILE_NAME = "file://local/file"
3030
INSTANCE_COUNT = 1
3131
INSTANCE_TYPE = "c4.4xlarge"
32+
KEEP_ALIVE_PERIOD = 1800
3233
INSTANCE_GROUP = InstanceGroup("group", "ml.c4.xlarge", 1)
3334
VOLUME_SIZE = 1
3435
MAX_RUNTIME = 1
@@ -599,7 +600,7 @@ def test_prepare_output_config_kms_key_none():
599600

600601
def test_prepare_resource_config():
601602
resource_config = _Job._prepare_resource_config(
602-
INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, None
603+
INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, None, None
603604
)
604605

605606
assert resource_config == {
@@ -609,9 +610,23 @@ def test_prepare_resource_config():
609610
}
610611

611612

613+
def test_prepare_resource_config_with_keep_alive_period():
614+
resource_config = _Job._prepare_resource_config(
615+
INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, VOLUME_KMS_KEY, KEEP_ALIVE_PERIOD
616+
)
617+
618+
assert resource_config == {
619+
"InstanceCount": INSTANCE_COUNT,
620+
"InstanceType": INSTANCE_TYPE,
621+
"VolumeSizeInGB": VOLUME_SIZE,
622+
"VolumeKmsKeyId": VOLUME_KMS_KEY,
623+
"KeepAlivePeriodInSeconds": KEEP_ALIVE_PERIOD,
624+
}
625+
626+
612627
def test_prepare_resource_config_with_volume_kms():
613628
resource_config = _Job._prepare_resource_config(
614-
INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, VOLUME_KMS_KEY
629+
INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, VOLUME_KMS_KEY, None
615630
)
616631

617632
assert resource_config == {
@@ -629,6 +644,7 @@ def test_prepare_resource_config_with_heterogeneous_cluster():
629644
[InstanceGroup("group1", "ml.c4.xlarge", 1), InstanceGroup("group2", "ml.m4.xlarge", 2)],
630645
VOLUME_SIZE,
631646
None,
647+
None,
632648
)
633649

634650
assert resource_config == {
@@ -648,6 +664,7 @@ def test_prepare_resource_config_with_instance_groups_instance_type_instance_cou
648664
[INSTANCE_GROUP],
649665
VOLUME_SIZE,
650666
None,
667+
None,
651668
)
652669
assert "instance_count and instance_type cannot be set when instance_groups is set" in str(
653670
error
@@ -662,6 +679,7 @@ def test_prepare_resource_config_with_instance_groups_instance_type_instance_cou
662679
None,
663680
VOLUME_SIZE,
664681
None,
682+
None,
665683
)
666684
assert "instance_count and instance_type must be set if instance_groups is not set" in str(
667685
error

0 commit comments

Comments
 (0)