@@ -262,7 +262,7 @@ def test_framework_training_config_all_args(retrieve_image_uri, sagemaker_sessio
262
262
py_version = "py3" ,
263
263
framework_version = "1.15.2" ,
264
264
role = "{{ role }}" ,
265
- instance_count = "{{ instance_count }}" ,
265
+ instance_count = 1 ,
266
266
instance_type = "ml.c4.2xlarge" ,
267
267
volume_size = "{{ volume_size }}" ,
268
268
volume_kms_key = "{{ volume_kms_key }}" ,
@@ -276,6 +276,8 @@ def test_framework_training_config_all_args(retrieve_image_uri, sagemaker_sessio
276
276
security_group_ids = ["{{ security_group_ids }}" ],
277
277
metric_definitions = [{"Name" : "{{ name }}" , "Regex" : "{{ regex }}" }],
278
278
sagemaker_session = sagemaker_session ,
279
+ checkpoint_local_path = "{{ checkpoint_local_path }}" ,
280
+ checkpoint_s3_uri = "{{ checkpoint_s3_uri }}" ,
279
281
)
280
282
281
283
data = "{{ training_data }}"
@@ -294,7 +296,7 @@ def test_framework_training_config_all_args(retrieve_image_uri, sagemaker_sessio
294
296
"TrainingJobName" : "{{ base_job_name }}-%s" % TIME_STAMP ,
295
297
"StoppingCondition" : {"MaxRuntimeInSeconds" : "{{ max_run }}" },
296
298
"ResourceConfig" : {
297
- "InstanceCount" : "{{ instance_count }}" ,
299
+ "InstanceCount" : 1 ,
298
300
"InstanceType" : "ml.c4.2xlarge" ,
299
301
"VolumeSizeInGB" : "{{ volume_size }}" ,
300
302
"VolumeKmsKeyId" : "{{ volume_kms_key }}" ,
@@ -338,6 +340,10 @@ def test_framework_training_config_all_args(retrieve_image_uri, sagemaker_sessio
338
340
}
339
341
]
340
342
},
343
+ "CheckpointConfig" : {
344
+ "LocalPath" : "{{ checkpoint_local_path }}" ,
345
+ "S3Uri" : "{{ checkpoint_s3_uri }}" ,
346
+ },
341
347
}
342
348
assert config == expected_config
343
349
0 commit comments