Commit 4898bc1

Author: BruceZhang@eitug (committed)
black reformat file that has changed
1 parent fedcb99 · commit 4898bc1
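
The rewrites in this commit are the kind of output black produces when run over the two test files. The snippet below is a minimal sketch, not part of the commit, showing how the same wrapping can be reproduced with black's Python API (black.format_str and black.Mode); it assumes black's default 88-character line length, which may differ from whatever the repository configures in pyproject.toml.

import black

# One of the lines this commit rewraps (from tests/integ/test_training_compiler.py),
# reduced to a top-level statement so the example is self-contained. black only needs
# syntactically valid code; the undefined names are never executed.
src = (
    'if gpu_instance_type == "ml.p3.16xlarge" and region not in integ.DATA_PARALLEL_TESTING_REGIONS:\n'
    '    pytest.skip("Data parallel testing is not allowed in this region")\n'
)

# format_str parses the source and re-emits it in black's canonical style.
# line_length=88 is black's default; the repo's actual setting is an assumption here.
formatted = black.format_str(src, mode=black.Mode(line_length=88))
print(formatted)
# The over-long condition comes back wrapped in parentheses across several lines,
# matching the "+" side of the first hunk below.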

2 files changed: 109 additions, 50 deletions

tests/integ/test_training_compiler.py

17 additions, 5 deletions
@@ -76,7 +76,10 @@ def skip_if_incompatible(gpu_instance_type, request):
     region = integ.test_region()
     if region not in integ.TRAINING_COMPILER_SUPPORTED_REGIONS:
         pytest.skip("SageMaker Training Compiler is not supported in this region")
-    if gpu_instance_type == "ml.p3.16xlarge" and region not in integ.DATA_PARALLEL_TESTING_REGIONS:
+    if (
+        gpu_instance_type == "ml.p3.16xlarge"
+        and region not in integ.DATA_PARALLEL_TESTING_REGIONS
+    ):
         pytest.skip("Data parallel testing is not allowed in this region")
     if gpu_instance_type == "ml.p3.2xlarge" and region in integ.TRAINING_NO_P3_REGIONS:
         pytest.skip("no ml.p3 instances in this region")
@@ -124,7 +127,9 @@ def test_huggingface_pytorch(
             sagemaker_session=sagemaker_session,
             disable_profiler=True,
             compiler_config=HFTrainingCompilerConfig(),
-            distribution={"pytorchxla": {"enabled": True}} if instance_count > 1 else None,
+            distribution={"pytorchxla": {"enabled": True}}
+            if instance_count > 1
+            else None,
         )

         hf.fit(huggingface_dummy_dataset)
@@ -170,7 +175,9 @@ def test_pytorch(
             sagemaker_session=sagemaker_session,
             disable_profiler=True,
             compiler_config=PTTrainingCompilerConfig(),
-            distribution={"pytorchxla": {"enabled": True}} if instance_count > 1 else None,
+            distribution={"pytorchxla": {"enabled": True}}
+            if instance_count > 1
+            else None,
         )

         hf.fit(huggingface_dummy_dataset)
@@ -216,7 +223,10 @@ def test_huggingface_tensorflow(

 @pytest.mark.release
 def test_tensorflow(
-    sagemaker_session, gpu_instance_type, tensorflow_training_latest_version, imagenet_val_set
+    sagemaker_session,
+    gpu_instance_type,
+    tensorflow_training_latest_version,
+    imagenet_val_set,
 ):
     """
     Test the TensorFlow estimator
@@ -247,7 +257,9 @@ def test_tensorflow(
         py_version="py39",
         git_config={
             "repo": "https://github.com/tensorflow/models.git",
-            "branch": "v" + ".".join(tensorflow_training_latest_version.split(".")[:2]) + ".0",
+            "branch": "v"
+            + ".".join(tensorflow_training_latest_version.split(".")[:2])
+            + ".0",
         },
         source_dir=".",
         entry_point="official/vision/train.py",

tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py

92 additions, 45 deletions
@@ -45,7 +45,9 @@
 REGION = "us-east-1"
 GPU = "ml.p3.2xlarge"
 SUPPORTED_GPU_INSTANCE_CLASSES = {"p3", "p3dn", "g4dn", "p4d", "g5"}
-UNSUPPORTED_GPU_INSTANCE_CLASSES = EC2_GPU_INSTANCE_CLASSES - SUPPORTED_GPU_INSTANCE_CLASSES
+UNSUPPORTED_GPU_INSTANCE_CLASSES = (
+    EC2_GPU_INSTANCE_CLASSES - SUPPORTED_GPU_INSTANCE_CLASSES
+)

 LIST_TAGS_RESULT = {"Tags": [{"Key": "TagtestKey", "Value": "TagtestValue"}]}

@@ -96,9 +98,13 @@ def _get_full_gpu_image_uri(version, instance_type, training_compiler_config):
     )


-def _create_train_job(version, instance_type, training_compiler_config, instance_count=1):
+def _create_train_job(
+    version, instance_type, training_compiler_config, instance_count=1
+):
     return {
-        "image_uri": _get_full_gpu_image_uri(version, instance_type, training_compiler_config),
+        "image_uri": _get_full_gpu_image_uri(
+            version, instance_type, training_compiler_config
+        ),
         "input_mode": "File",
         "input_config": [
             {
@@ -183,7 +189,9 @@ def test_unsupported_cpu_instance(cpu_instance_type, pytorch_training_compiler_v
     ).fit()


-@pytest.mark.parametrize("unsupported_gpu_instance_class", UNSUPPORTED_GPU_INSTANCE_CLASSES)
+@pytest.mark.parametrize(
+    "unsupported_gpu_instance_class", UNSUPPORTED_GPU_INSTANCE_CLASSES
+)
 def test_unsupported_gpu_instance(
     unsupported_gpu_instance_class, pytorch_training_compiler_version
 ):
@@ -303,7 +311,12 @@ def test_unsupported_distribution(
 @patch("time.time", return_value=TIME)
 @pytest.mark.parametrize("instance_class", SUPPORTED_GPU_INSTANCE_CLASSES)
 def test_pytorchxla_distribution(
-    time, name_from_base, sagemaker_session, pytorch_training_compiler_version, instance_class, pytorch_training_py_version
+    time,
+    name_from_base,
+    sagemaker_session,
+    pytorch_training_compiler_version,
+    instance_class,
+    pytorch_training_py_version,
 ):
     if Version(pytorch_training_compiler_version) < Version("1.12"):
         pytest.skip("This test is intended for PyTorch 1.12 and above")
@@ -333,17 +346,24 @@ def test_pytorchxla_distribution(
     assert boto_call_names == ["resource"]

     expected_train_args = _create_train_job(
-        pytorch_training_compiler_version, instance_type, compiler_config, instance_count=2
+        pytorch_training_compiler_version,
+        instance_type,
+        compiler_config,
+        instance_count=2,
     )
-    expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
+    expected_train_args["input_config"][0]["DataSource"]["S3DataSource"][
+        "S3Uri"
+    ] = inputs
     expected_train_args["enable_sagemaker_metrics"] = False
-    expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = json.dumps(
+    expected_train_args["hyperparameters"][
+        TrainingCompilerConfig.HP_ENABLE_COMPILER
+    ] = json.dumps(True)
+    expected_train_args["hyperparameters"][PyTorch.LAUNCH_PT_XLA_ENV_NAME] = json.dumps(
         True
     )
-    expected_train_args["hyperparameters"][PyTorch.LAUNCH_PT_XLA_ENV_NAME] = json.dumps(True)
-    expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps(
-        False
-    )
+    expected_train_args["hyperparameters"][
+        TrainingCompilerConfig.HP_ENABLE_DEBUG
+    ] = json.dumps(False)

     actual_train_args = sagemaker_session.method_calls[0][2]
     assert (
@@ -357,7 +377,12 @@ def test_pytorchxla_distribution(
 @patch("time.time", return_value=TIME)
 @pytest.mark.parametrize("instance_class", SUPPORTED_GPU_INSTANCE_CLASSES)
 def test_default_compiler_config(
-    time, name_from_base, sagemaker_session, pytorch_training_compiler_version, instance_class, pytorch_training_py_version
+    time,
+    name_from_base,
+    sagemaker_session,
+    pytorch_training_compiler_version,
+    instance_class,
+    pytorch_training_py_version,
 ):
     compiler_config = TrainingCompilerConfig()
     instance_type = f"ml.{instance_class}.xlarge"
@@ -386,14 +411,16 @@ def test_default_compiler_config(
     expected_train_args = _create_train_job(
         pytorch_training_compiler_version, instance_type, compiler_config
     )
-    expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
+    expected_train_args["input_config"][0]["DataSource"]["S3DataSource"][
+        "S3Uri"
+    ] = inputs
     expected_train_args["enable_sagemaker_metrics"] = False
-    expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = json.dumps(
-        True
-    )
-    expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps(
-        False
-    )
+    expected_train_args["hyperparameters"][
+        TrainingCompilerConfig.HP_ENABLE_COMPILER
+    ] = json.dumps(True)
+    expected_train_args["hyperparameters"][
+        TrainingCompilerConfig.HP_ENABLE_DEBUG
+    ] = json.dumps(False)

     actual_train_args = sagemaker_session.method_calls[0][2]
     assert (
@@ -406,7 +433,11 @@ def test_default_compiler_config(
 @patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME)
 @patch("time.time", return_value=TIME)
 def test_debug_compiler_config(
-    time, name_from_base, sagemaker_session, pytorch_training_compiler_version, pytorch_training_py_version
+    time,
+    name_from_base,
+    sagemaker_session,
+    pytorch_training_compiler_version,
+    pytorch_training_py_version,
 ):
     compiler_config = TrainingCompilerConfig(debug=True)

@@ -434,14 +465,16 @@ def test_debug_compiler_config(
     expected_train_args = _create_train_job(
         pytorch_training_compiler_version, INSTANCE_TYPE, compiler_config
     )
-    expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
+    expected_train_args["input_config"][0]["DataSource"]["S3DataSource"][
+        "S3Uri"
+    ] = inputs
     expected_train_args["enable_sagemaker_metrics"] = False
-    expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = json.dumps(
-        True
-    )
-    expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps(
-        True
-    )
+    expected_train_args["hyperparameters"][
+        TrainingCompilerConfig.HP_ENABLE_COMPILER
+    ] = json.dumps(True)
+    expected_train_args["hyperparameters"][
+        TrainingCompilerConfig.HP_ENABLE_DEBUG
+    ] = json.dumps(True)

     actual_train_args = sagemaker_session.method_calls[0][2]
     assert (
@@ -454,7 +487,11 @@ def test_debug_compiler_config(
 @patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME)
 @patch("time.time", return_value=TIME)
 def test_disable_compiler_config(
-    time, name_from_base, sagemaker_session, pytorch_training_compiler_version, pytorch_training_py_version
+    time,
+    name_from_base,
+    sagemaker_session,
+    pytorch_training_compiler_version,
+    pytorch_training_py_version,
 ):
     compiler_config = TrainingCompilerConfig(enabled=False)

@@ -482,14 +519,16 @@ def test_disable_compiler_config(
     expected_train_args = _create_train_job(
         pytorch_training_compiler_version, INSTANCE_TYPE, compiler_config
     )
-    expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
+    expected_train_args["input_config"][0]["DataSource"]["S3DataSource"][
+        "S3Uri"
+    ] = inputs
     expected_train_args["enable_sagemaker_metrics"] = False
-    expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = json.dumps(
-        False
-    )
-    expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps(
-        False
-    )
+    expected_train_args["hyperparameters"][
+        TrainingCompilerConfig.HP_ENABLE_COMPILER
+    ] = json.dumps(False)
+    expected_train_args["hyperparameters"][
+        TrainingCompilerConfig.HP_ENABLE_DEBUG
+    ] = json.dumps(False)

     actual_train_args = sagemaker_session.method_calls[0][2]
     assert (
@@ -508,7 +547,10 @@ def test_attach(sagemaker_session, compiler_enabled, debug_enabled):
         "py38-cu113-ubuntu20.04"
     )
     returned_job_description = {
-        "AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image},
+        "AlgorithmSpecification": {
+            "TrainingInputMode": "File",
+            "TrainingImage": training_image,
+        },
         "HyperParameters": {
             "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"',
             "sagemaker_program": '"iris-dnn-classifier.py"',
@@ -530,14 +572,19 @@ def test_attach(sagemaker_session, compiler_enabled, debug_enabled):
         "TrainingJobName": "trcomp",
         "TrainingJobStatus": "Completed",
         "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/trcomp",
-        "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/trcomp"},
+        "OutputDataConfig": {
+            "KmsKeyId": "",
+            "S3OutputPath": "s3://place/output/trcomp",
+        },
         "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"},
     }
     sagemaker_session.sagemaker_client.describe_training_job = Mock(
         name="describe_training_job", return_value=returned_job_description
     )

-    estimator = PyTorch.attach(training_job_name="trcomp", sagemaker_session=sagemaker_session)
+    estimator = PyTorch.attach(
+        training_job_name="trcomp", sagemaker_session=sagemaker_session
+    )
     assert estimator.latest_training_job.job_name == "trcomp"
     assert estimator.py_version == "py38"
     assert estimator.framework_version == "1.12.0"
@@ -549,12 +596,12 @@ def test_attach(sagemaker_session, compiler_enabled, debug_enabled):
     assert estimator.output_path == "s3://place/output/trcomp"
     assert estimator.output_kms_key == ""
     assert estimator.hyperparameters()["training_steps"] == "100"
-    assert estimator.hyperparameters()[TrainingCompilerConfig.HP_ENABLE_COMPILER] == json.dumps(
-        compiler_enabled
-    )
-    assert estimator.hyperparameters()[TrainingCompilerConfig.HP_ENABLE_DEBUG] == json.dumps(
-        debug_enabled
-    )
+    assert estimator.hyperparameters()[
+        TrainingCompilerConfig.HP_ENABLE_COMPILER
+    ] == json.dumps(compiler_enabled)
+    assert estimator.hyperparameters()[
+        TrainingCompilerConfig.HP_ENABLE_DEBUG
+    ] == json.dumps(debug_enabled)
     assert estimator.source_dir == "s3://some/sourcedir.tar.gz"
     assert estimator.entry_point == "iris-dnn-classifier.py"
