Skip to content

Commit 30d0d3b

Browse files
Merge branch 'master' into master
2 parents 4914e2f + 0493c85 commit 30d0d3b

File tree

6 files changed

+129
-8
lines changed

6 files changed

+129
-8
lines changed

src/sagemaker/estimator.py

+29-3
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def __init__(
116116
role: str,
117117
instance_count: Optional[Union[int, PipelineVariable]] = None,
118118
instance_type: Optional[Union[str, PipelineVariable]] = None,
119+
keep_alive_period_in_seconds: Optional[Union[int, PipelineVariable]] = None,
119120
volume_size: Union[int, PipelineVariable] = 30,
120121
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
121122
max_run: Union[int, PipelineVariable] = 24 * 60 * 60,
@@ -167,6 +168,9 @@ def __init__(
167168
instance_type (str or PipelineVariable): Type of EC2 instance to use for training,
168169
for example, ``'ml.c4.xlarge'``. Required if instance_groups is
169170
not set.
171+
keep_alive_period_in_seconds (int): The duration of time in seconds
172+
to retain configured resources in a warm pool for subsequent
173+
training jobs (default: None).
170174
volume_size (int or PipelineVariable): Size in GB of the storage volume to use for
171175
storing input and output data during training (default: 30).
172176
@@ -510,6 +514,7 @@ def __init__(
510514
self.role = role
511515
self.instance_count = instance_count
512516
self.instance_type = instance_type
517+
self.keep_alive_period_in_seconds = keep_alive_period_in_seconds
513518
self.instance_groups = instance_groups
514519
self.volume_size = volume_size
515520
self.volume_kms_key = volume_kms_key
@@ -1578,6 +1583,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
15781583
if "EnableNetworkIsolation" in job_details:
15791584
init_params["enable_network_isolation"] = job_details["EnableNetworkIsolation"]
15801585

1586+
if "KeepAlivePeriodInSeconds" in job_details["ResourceConfig"]:
1587+
init_params["keep_alive_period_in_seconds"] = job_details["ResourceConfig"][
1588+
"KeepAlivePeriodInSeconds"
1589+
]
1590+
15811591
has_hps = "HyperParameters" in job_details
15821592
init_params["hyperparameters"] = job_details["HyperParameters"] if has_hps else {}
15831593

@@ -2126,7 +2136,9 @@ def _is_local_channel(cls, input_uri):
21262136
return isinstance(input_uri, string_types) and input_uri.startswith("file://")
21272137

21282138
@classmethod
2129-
def update(cls, estimator, profiler_rule_configs=None, profiler_config=None):
2139+
def update(
2140+
cls, estimator, profiler_rule_configs=None, profiler_config=None, resource_config=None
2141+
):
21302142
"""Update a running Amazon SageMaker training job.
21312143
21322144
Args:
@@ -2135,18 +2147,23 @@ def update(cls, estimator, profiler_rule_configs=None, profiler_config=None):
21352147
updated in the training job. (default: None).
21362148
profiler_config (dict): Configuration for how profiling information is emitted with
21372149
SageMaker Debugger. (default: None).
2150+
resource_config (dict): Configuration of the resources for the training job. You can
2151+
update the keep-alive period if the warm pool status is `Available`. No other fields
2152+
can be updated. (default: None).
21382153
21392154
Returns:
21402155
sagemaker.estimator._TrainingJob: Constructed object that captures
21412156
all information about the updated training job.
21422157
"""
2143-
update_args = cls._get_update_args(estimator, profiler_rule_configs, profiler_config)
2158+
update_args = cls._get_update_args(
2159+
estimator, profiler_rule_configs, profiler_config, resource_config
2160+
)
21442161
estimator.sagemaker_session.update_training_job(**update_args)
21452162

21462163
return estimator.latest_training_job
21472164

21482165
@classmethod
2149-
def _get_update_args(cls, estimator, profiler_rule_configs, profiler_config):
2166+
def _get_update_args(cls, estimator, profiler_rule_configs, profiler_config, resource_config):
21502167
"""Constructs a dict of arguments for updating an Amazon SageMaker training job.
21512168
21522169
Args:
@@ -2156,13 +2173,17 @@ def _get_update_args(cls, estimator, profiler_rule_configs, profiler_config):
21562173
updated in the training job. (default: None).
21572174
profiler_config (dict): Configuration for how profiling information is emitted with
21582175
SageMaker Debugger. (default: None).
2176+
resource_config (dict): Configuration of the resources for the training job. You can
2177+
update the keep-alive period if the warm pool status is `Available`. No other fields
2178+
can be updated. (default: None).
21592179
21602180
Returns:
21612181
Dict: dict for `sagemaker.session.Session.update_training_job` method
21622182
"""
21632183
update_args = {"job_name": estimator.latest_training_job.name}
21642184
update_args.update(build_dict("profiler_rule_configs", profiler_rule_configs))
21652185
update_args.update(build_dict("profiler_config", profiler_config))
2186+
update_args.update(build_dict("resource_config", resource_config))
21662187

21672188
return update_args
21682189

@@ -2218,6 +2239,7 @@ def __init__(
22182239
role: str,
22192240
instance_count: Optional[Union[int, PipelineVariable]] = None,
22202241
instance_type: Optional[Union[str, PipelineVariable]] = None,
2242+
keep_alive_period_in_seconds: Optional[Union[int, PipelineVariable]] = None,
22212243
volume_size: Union[int, PipelineVariable] = 30,
22222244
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
22232245
max_run: Union[int, PipelineVariable] = 24 * 60 * 60,
@@ -2270,6 +2292,9 @@ def __init__(
22702292
instance_type (str or PipelineVariable): Type of EC2 instance to use for training,
22712293
for example, ``'ml.c4.xlarge'``. Required if instance_groups is
22722294
not set.
2295+
keep_alive_period_in_seconds (int): The duration of time in seconds
2296+
to retain configured resources in a warm pool for subsequent
2297+
training jobs (default: None).
22732298
volume_size (int or PipelineVariable): Size in GB of the storage volume to use for
22742299
storing input and output data during training (default: 30).
22752300
@@ -2591,6 +2616,7 @@ def __init__(
25912616
role,
25922617
instance_count,
25932618
instance_type,
2619+
keep_alive_period_in_seconds,
25942620
volume_size,
25952621
volume_kms_key,
25962622
max_run,

src/sagemaker/job.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
7878
estimator.instance_groups,
7979
estimator.volume_size,
8080
estimator.volume_kms_key,
81+
estimator.keep_alive_period_in_seconds,
8182
)
8283
stop_condition = _Job._prepare_stop_condition(estimator.max_run, estimator.max_wait)
8384
vpc_config = estimator.get_vpc_config()
@@ -281,14 +282,21 @@ def _prepare_output_config(s3_path, kms_key_id):
281282

282283
@staticmethod
283284
def _prepare_resource_config(
284-
instance_count, instance_type, instance_groups, volume_size, volume_kms_key
285+
instance_count,
286+
instance_type,
287+
instance_groups,
288+
volume_size,
289+
volume_kms_key,
290+
keep_alive_period_in_seconds,
285291
):
286292
"""Placeholder docstring"""
287293
resource_config = {
288294
"VolumeSizeInGB": volume_size,
289295
}
290296
if volume_kms_key is not None:
291297
resource_config["VolumeKmsKeyId"] = volume_kms_key
298+
if keep_alive_period_in_seconds is not None:
299+
resource_config["KeepAlivePeriodInSeconds"] = keep_alive_period_in_seconds
292300
if instance_groups is not None:
293301
if instance_count is not None or instance_type is not None:
294302
raise ValueError(

src/sagemaker/session.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,7 @@ def update_training_job(
821821
job_name,
822822
profiler_rule_configs=None,
823823
profiler_config=None,
824+
resource_config=None,
824825
):
825826
"""Calls the UpdateTrainingJob API for the given job name and returns the response.
826827
@@ -829,11 +830,15 @@ def update_training_job(
829830
profiler_rule_configs (list): List of profiler rule configurations. (default: ``None``).
830831
profiler_config(dict): Configuration for how profiling information is emitted with
831832
SageMaker Profiler. (default: ``None``).
833+
resource_config (dict): Configuration of the resources for the training job. You can
834+
update the keep-alive period if the warm pool status is `Available`. No other fields
835+
can be updated. (default: ``None``).
832836
"""
833837
update_training_job_request = self._get_update_training_job_request(
834838
job_name=job_name,
835839
profiler_rule_configs=profiler_rule_configs,
836840
profiler_config=profiler_config,
841+
resource_config=resource_config,
837842
)
838843
LOGGER.info("Updating training job with name %s", job_name)
839844
LOGGER.debug("Update request: %s", json.dumps(update_training_job_request, indent=4))
@@ -844,14 +849,18 @@ def _get_update_training_job_request(
844849
job_name,
845850
profiler_rule_configs=None,
846851
profiler_config=None,
852+
resource_config=None,
847853
):
848-
"""Constructs a request compatible for updateing an Amazon SageMaker training job.
854+
"""Constructs a request compatible for updating an Amazon SageMaker training job.
849855
850856
Args:
851857
job_name (str): Name of the training job being updated.
852858
profiler_rule_configs (list): List of profiler rule configurations. (default: ``None``).
853859
profiler_config(dict): Configuration for how profiling information is emitted with
854860
SageMaker Profiler. (default: ``None``).
861+
resource_config (dict): Configuration of the resources for the training job. You can
862+
update the keep-alive period if the warm pool status is `Available`. No other fields
863+
can be updated. (default: ``None``).
855864
856865
Returns:
857866
Dict: an update training request dict
@@ -866,6 +875,9 @@ def _get_update_training_job_request(
866875
if profiler_config is not None:
867876
update_training_job_request["ProfilerConfig"] = profiler_config
868877

878+
if resource_config is not None:
879+
update_training_job_request["ResourceConfig"] = resource_config
880+
869881
return update_training_job_request
870882

871883
def process(

tests/integ/test_training_compiler.py

+40-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ def skip_if_incompatible(gpu_instance_type, request):
8181
pytest.skip("no ml.p3 instances in this region")
8282

8383

84-
@pytest.mark.release
8584
@pytest.mark.parametrize(
8685
"gpu_instance_type,instance_count",
8786
[
@@ -130,6 +129,46 @@ def test_huggingface_pytorch(
130129
hf.fit(huggingface_dummy_dataset)
131130

132131

132+
@pytest.mark.release
133+
def test_huggingface_pytorch_release(
134+
sagemaker_session,
135+
gpu_instance_type,
136+
huggingface_training_compiler_latest_version,
137+
huggingface_training_compiler_pytorch_latest_version,
138+
huggingface_dummy_dataset,
139+
):
140+
"""
141+
Test the HuggingFace estimator with PyTorch
142+
"""
143+
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
144+
data_path = os.path.join(DATA_DIR, "huggingface")
145+
146+
hf = HuggingFace(
147+
py_version="py38",
148+
entry_point=os.path.join(data_path, "run_glue.py"),
149+
role="SageMakerRole",
150+
transformers_version=huggingface_training_compiler_latest_version,
151+
pytorch_version=huggingface_training_compiler_pytorch_latest_version,
152+
instance_count=1,
153+
instance_type=gpu_instance_type,
154+
hyperparameters={
155+
"model_name_or_path": "distilbert-base-cased",
156+
"task_name": "wnli",
157+
"do_train": True,
158+
"do_eval": True,
159+
"max_seq_length": 128,
160+
"fp16": True,
161+
"per_device_train_batch_size": 128,
162+
"output_dir": "/opt/ml/model",
163+
},
164+
sagemaker_session=sagemaker_session,
165+
disable_profiler=True,
166+
compiler_config=HFTrainingCompilerConfig(),
167+
)
168+
169+
hf.fit(huggingface_dummy_dataset)
170+
171+
133172
@pytest.mark.release
134173
def test_huggingface_tensorflow(
135174
sagemaker_session,

tests/unit/test_estimator.py

+18
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
BUCKET_NAME = "mybucket"
7070
INSTANCE_COUNT = 1
7171
INSTANCE_TYPE = "c4.4xlarge"
72+
KEEP_ALIVE_PERIOD_IN_SECONDS = 1800
7273
ACCELERATOR_TYPE = "ml.eia.medium"
7374
ROLE = "DummyRole"
7475
IMAGE_URI = "fakeimage"
@@ -351,6 +352,23 @@ def test_framework_with_heterogeneous_cluster(sagemaker_session):
351352
}
352353

353354

355+
def test_framework_with_keep_alive_period(sagemaker_session):
356+
f = DummyFramework(
357+
entry_point=SCRIPT_PATH,
358+
role=ROLE,
359+
sagemaker_session=sagemaker_session,
360+
instance_groups=[
361+
InstanceGroup("group1", "ml.c4.xlarge", 1),
362+
InstanceGroup("group2", "ml.m4.xlarge", 2),
363+
],
364+
keep_alive_period_in_seconds=KEEP_ALIVE_PERIOD_IN_SECONDS,
365+
)
366+
f.fit("s3://mydata")
367+
sagemaker_session.train.assert_called_once()
368+
_, args = sagemaker_session.train.call_args
369+
assert args["resource_config"]["KeepAlivePeriodInSeconds"] == KEEP_ALIVE_PERIOD_IN_SECONDS
370+
371+
354372
def test_framework_with_debugger_and_built_in_rule(sagemaker_session):
355373
debugger_built_in_rule_with_custom_args = Rule.sagemaker(
356374
base_config=rule_configs.stalled_training_rule(),

tests/unit/test_job.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
LOCAL_FILE_NAME = "file://local/file"
3030
INSTANCE_COUNT = 1
3131
INSTANCE_TYPE = "c4.4xlarge"
32+
KEEP_ALIVE_PERIOD = 1800
3233
INSTANCE_GROUP = InstanceGroup("group", "ml.c4.xlarge", 1)
3334
VOLUME_SIZE = 1
3435
MAX_RUNTIME = 1
@@ -599,7 +600,7 @@ def test_prepare_output_config_kms_key_none():
599600

600601
def test_prepare_resource_config():
601602
resource_config = _Job._prepare_resource_config(
602-
INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, None
603+
INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, None, None
603604
)
604605

605606
assert resource_config == {
@@ -609,9 +610,23 @@ def test_prepare_resource_config():
609610
}
610611

611612

613+
def test_prepare_resource_config_with_keep_alive_period():
614+
resource_config = _Job._prepare_resource_config(
615+
INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, VOLUME_KMS_KEY, KEEP_ALIVE_PERIOD
616+
)
617+
618+
assert resource_config == {
619+
"InstanceCount": INSTANCE_COUNT,
620+
"InstanceType": INSTANCE_TYPE,
621+
"VolumeSizeInGB": VOLUME_SIZE,
622+
"VolumeKmsKeyId": VOLUME_KMS_KEY,
623+
"KeepAlivePeriodInSeconds": KEEP_ALIVE_PERIOD,
624+
}
625+
626+
612627
def test_prepare_resource_config_with_volume_kms():
613628
resource_config = _Job._prepare_resource_config(
614-
INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, VOLUME_KMS_KEY
629+
INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, VOLUME_KMS_KEY, None
615630
)
616631

617632
assert resource_config == {
@@ -629,6 +644,7 @@ def test_prepare_resource_config_with_heterogeneous_cluster():
629644
[InstanceGroup("group1", "ml.c4.xlarge", 1), InstanceGroup("group2", "ml.m4.xlarge", 2)],
630645
VOLUME_SIZE,
631646
None,
647+
None,
632648
)
633649

634650
assert resource_config == {
@@ -648,6 +664,7 @@ def test_prepare_resource_config_with_instance_groups_instance_type_instance_cou
648664
[INSTANCE_GROUP],
649665
VOLUME_SIZE,
650666
None,
667+
None,
651668
)
652669
assert "instance_count and instance_type cannot be set when instance_groups is set" in str(
653670
error
@@ -662,6 +679,7 @@ def test_prepare_resource_config_with_instance_groups_instance_type_instance_cou
662679
None,
663680
VOLUME_SIZE,
664681
None,
682+
None,
665683
)
666684
assert "instance_count and instance_type must be set if instance_groups is not set" in str(
667685
error

0 commit comments

Comments
 (0)