Skip to content

Commit c4d3b9e

Browse files
marckarpmarckarp
authored andcommitted
feature: support checkpoint to be passed from estimator (#2849)
Co-authored-by: marckarp <[email protected]>
1 parent 01962bc commit c4d3b9e

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

src/sagemaker/workflow/airflow.py

+5
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,11 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=
195195
if s3_operations:
196196
train_config["S3Operations"] = s3_operations
197197

198+
if (estimator.checkpoint_local_path is not None) & (estimator.checkpoint_s3_uri is not None):
199+
train_config["CheckpointConfig"] = {
200+
"LocalPath": estimator.checkpoint_local_path,
201+
"S3Uri": estimator.checkpoint_s3_uri,
202+
}
198203
return train_config
199204

200205

tests/unit/test_airflow.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def test_framework_training_config_all_args(retrieve_image_uri, sagemaker_sessio
262262
py_version="py3",
263263
framework_version="1.15.2",
264264
role="{{ role }}",
265-
instance_count="{{ instance_count }}",
265+
instance_count=1,
266266
instance_type="ml.c4.2xlarge",
267267
volume_size="{{ volume_size }}",
268268
volume_kms_key="{{ volume_kms_key }}",
@@ -276,6 +276,8 @@ def test_framework_training_config_all_args(retrieve_image_uri, sagemaker_sessio
276276
security_group_ids=["{{ security_group_ids }}"],
277277
metric_definitions=[{"Name": "{{ name }}", "Regex": "{{ regex }}"}],
278278
sagemaker_session=sagemaker_session,
279+
checkpoint_local_path="{{ checkpoint_local_path }}",
280+
checkpoint_s3_uri="{{ checkpoint_s3_uri }}",
279281
)
280282

281283
data = "{{ training_data }}"
@@ -294,7 +296,7 @@ def test_framework_training_config_all_args(retrieve_image_uri, sagemaker_sessio
294296
"TrainingJobName": "{{ base_job_name }}-%s" % TIME_STAMP,
295297
"StoppingCondition": {"MaxRuntimeInSeconds": "{{ max_run }}"},
296298
"ResourceConfig": {
297-
"InstanceCount": "{{ instance_count }}",
299+
"InstanceCount": 1,
298300
"InstanceType": "ml.c4.2xlarge",
299301
"VolumeSizeInGB": "{{ volume_size }}",
300302
"VolumeKmsKeyId": "{{ volume_kms_key }}",
@@ -338,6 +340,10 @@ def test_framework_training_config_all_args(retrieve_image_uri, sagemaker_sessio
338340
}
339341
]
340342
},
343+
"CheckpointConfig": {
344+
"LocalPath": "{{ checkpoint_local_path }}",
345+
"S3Uri": "{{ checkpoint_s3_uri }}",
346+
},
341347
}
342348
assert config == expected_config
343349

0 commit comments

Comments
 (0)