Skip to content

Commit e5e327b

Browse files
committed
black format
1 parent 7e04f33 commit e5e327b

File tree

2 files changed

+52
-14
lines changed

2 files changed

+52
-14
lines changed

tests/unit/test_chainer.py

+26-7
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,9 @@ def test_additional_hyperparameters(sagemaker_session, chainer_version, chainer_
179179
)
180180

181181

182-
def test_attach_with_additional_hyperparameters(sagemaker_session, chainer_version, chainer_py_version):
182+
def test_attach_with_additional_hyperparameters(
183+
sagemaker_session, chainer_version, chainer_py_version
184+
):
183185
training_image = "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-chainer:{}-cpu-{}".format(
184186
chainer_version, chainer_py_version
185187
)
@@ -388,7 +390,9 @@ def test_model(sagemaker_session, chainer_version, chainer_py_version):
388390

389391

390392
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
391-
def test_model_prepare_container_def_accelerator_error(sagemaker_session, chainer_version, chainer_py_version):
393+
def test_model_prepare_container_def_accelerator_error(
394+
sagemaker_session, chainer_version, chainer_py_version
395+
):
392396
model = ChainerModel(
393397
MODEL_DATA,
394398
role=ROLE,
@@ -433,29 +437,44 @@ def test_train_image_default(sagemaker_session, chainer_version, chainer_py_vers
433437

434438
def test_train_image_cpu_instances(sagemaker_session, chainer_version, chainer_py_version):
435439
chainer = _chainer_estimator(
436-
sagemaker_session, framework_version=chainer_version, py_version=chainer_py_version, train_instance_type="ml.c2.2xlarge"
440+
sagemaker_session,
441+
framework_version=chainer_version,
442+
py_version=chainer_py_version,
443+
train_instance_type="ml.c2.2xlarge",
437444
)
438445
assert chainer.train_image() == _get_full_cpu_image_uri(chainer_version, chainer_py_version)
439446

440447
chainer = _chainer_estimator(
441-
sagemaker_session, framework_version=chainer_version, py_version=chainer_py_version, train_instance_type="ml.c4.2xlarge"
448+
sagemaker_session,
449+
framework_version=chainer_version,
450+
py_version=chainer_py_version,
451+
train_instance_type="ml.c4.2xlarge",
442452
)
443453
assert chainer.train_image() == _get_full_cpu_image_uri(chainer_version, chainer_py_version)
444454

445455
chainer = _chainer_estimator(
446-
sagemaker_session, framework_version=chainer_version, py_version=chainer_py_version, train_instance_type="ml.m16"
456+
sagemaker_session,
457+
framework_version=chainer_version,
458+
py_version=chainer_py_version,
459+
train_instance_type="ml.m16",
447460
)
448461
assert chainer.train_image() == _get_full_cpu_image_uri(chainer_version, chainer_py_version)
449462

450463

451464
def test_train_image_gpu_instances(sagemaker_session, chainer_version, chainer_py_version):
452465
chainer = _chainer_estimator(
453-
sagemaker_session, framework_version=chainer_version, py_version=chainer_py_version, train_instance_type="ml.g2.2xlarge"
466+
sagemaker_session,
467+
framework_version=chainer_version,
468+
py_version=chainer_py_version,
469+
train_instance_type="ml.g2.2xlarge",
454470
)
455471
assert chainer.train_image() == _get_full_gpu_image_uri(chainer_version, chainer_py_version)
456472

457473
chainer = _chainer_estimator(
458-
sagemaker_session, framework_version=chainer_version, py_version=chainer_py_version, train_instance_type="ml.p2.2xlarge"
474+
sagemaker_session,
475+
framework_version=chainer_version,
476+
py_version=chainer_py_version,
477+
train_instance_type="ml.p2.2xlarge",
459478
)
460479
assert chainer.train_image() == _get_full_gpu_image_uri(chainer_version, chainer_py_version)
461480

tests/unit/test_pytorch.py

+26-7
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,12 @@ def _get_full_cpu_image_uri_with_ei(version, py_version):
9191

9292

9393
def _pytorch_estimator(
94-
sagemaker_session, framework_version, py_version, train_instance_type=None, base_job_name=None, **kwargs
94+
sagemaker_session,
95+
framework_version,
96+
py_version,
97+
train_instance_type=None,
98+
base_job_name=None,
99+
**kwargs
95100
):
96101
return PyTorch(
97102
entry_point=SCRIPT_PATH,
@@ -409,7 +414,9 @@ def test_train_image_cpu_instances(sagemaker_session, pytorch_version, pytorch_p
409414
)
410415
assert pytorch.train_image() == _get_full_cpu_image_uri(pytorch_version, pytorch_py_version)
411416

412-
pytorch = _pytorch_estimator(sagemaker_session, pytorch_version, pytorch_py_version, train_instance_type="ml.m16")
417+
pytorch = _pytorch_estimator(
418+
sagemaker_session, pytorch_version, pytorch_py_version, train_instance_type="ml.m16"
419+
)
413420
assert pytorch.train_image() == _get_full_cpu_image_uri(pytorch_version, pytorch_py_version)
414421

415422

@@ -580,33 +587,45 @@ def test_model_py2_warning(warning, sagemaker_session, pytorch_version):
580587

581588
def test_pt_enable_sm_metrics(sagemaker_session, pytorch_full_version):
582589
pytorch = _pytorch_estimator(
583-
sagemaker_session, framework_version=pytorch_full_version, py_version="py3", enable_sagemaker_metrics=True
590+
sagemaker_session,
591+
framework_version=pytorch_full_version,
592+
py_version="py3",
593+
enable_sagemaker_metrics=True,
584594
)
585595
assert pytorch.enable_sagemaker_metrics
586596

587597

588598
def test_pt_disable_sm_metrics(sagemaker_session, pytorch_full_version):
589599
pytorch = _pytorch_estimator(
590-
sagemaker_session, framework_version=pytorch_full_version, py_version="py3", enable_sagemaker_metrics=False
600+
sagemaker_session,
601+
framework_version=pytorch_full_version,
602+
py_version="py3",
603+
enable_sagemaker_metrics=False,
591604
)
592605
assert not pytorch.enable_sagemaker_metrics
593606

594607

595608
def test_pt_disable_sm_metrics_if_pt_ver_is_less_than_1_15(sagemaker_session):
596609
for fw_version in ["1.1", "1.2"]:
597-
pytorch = _pytorch_estimator(sagemaker_session, framework_version=fw_version, py_version="py3")
610+
pytorch = _pytorch_estimator(
611+
sagemaker_session, framework_version=fw_version, py_version="py3"
612+
)
598613
assert pytorch.enable_sagemaker_metrics is None
599614

600615

601616
def test_pt_enable_sm_metrics_if_fw_ver_is_at_least_1_15(sagemaker_session):
602617
for fw_version in ["1.3", "1.4", "2.0", "2.1"]:
603-
pytorch = _pytorch_estimator(sagemaker_session, framework_version=fw_version, py_version="py3")
618+
pytorch = _pytorch_estimator(
619+
sagemaker_session, framework_version=fw_version, py_version="py3"
620+
)
604621
assert pytorch.enable_sagemaker_metrics
605622

606623

607624
def test_custom_image_estimator_deploy(sagemaker_session, pytorch_full_version):
608625
custom_image = "mycustomimage:latest"
609-
pytorch = _pytorch_estimator(sagemaker_session, framework_version=pytorch_full_version, py_version="py3")
626+
pytorch = _pytorch_estimator(
627+
sagemaker_session, framework_version=pytorch_full_version, py_version="py3"
628+
)
610629
pytorch.fit(inputs="s3://mybucket/train", job_name="new_name")
611630
model = pytorch.create_model(image=custom_image)
612631
assert model.image == custom_image

0 commit comments

Comments
 (0)