Skip to content

Commit 44e0f18

Browse files
author
Shibo Xing
committed
fix: format with black
1 parent d3a7822 commit 44e0f18

File tree

3 files changed

+40
-12
lines changed

3 files changed

+40
-12
lines changed

tests/conftest.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888

8989
PYTORCH_RENEWED_GPU = "ml.g4dn.xlarge"
9090

91+
9192
def pytest_addoption(parser):
9293
parser.addoption("--sagemaker-client-config", action="store", default=None)
9394
parser.addoption("--sagemaker-runtime-config", action="store", default=None)
@@ -257,9 +258,13 @@ def huggingface_pytorch_training_py_version(huggingface_pytorch_training_version
257258

258259

259260
@pytest.fixture(scope="module")
260-
def huggingface_training_compiler_pytorch_version(huggingface_training_compiler_version):
261+
def huggingface_training_compiler_pytorch_version(
262+
huggingface_training_compiler_version,
263+
):
261264
versions = _huggingface_base_fm_version(
262-
huggingface_training_compiler_version, "pytorch", "huggingface_training_compiler"
265+
huggingface_training_compiler_version,
266+
"pytorch",
267+
"huggingface_training_compiler",
263268
)
264269
if not versions:
265270
pytest.skip(
@@ -270,9 +275,13 @@ def huggingface_training_compiler_pytorch_version(huggingface_training_compiler_
270275

271276

272277
@pytest.fixture(scope="module")
273-
def huggingface_training_compiler_tensorflow_version(huggingface_training_compiler_version):
278+
def huggingface_training_compiler_tensorflow_version(
279+
huggingface_training_compiler_version,
280+
):
274281
versions = _huggingface_base_fm_version(
275-
huggingface_training_compiler_version, "tensorflow", "huggingface_training_compiler"
282+
huggingface_training_compiler_version,
283+
"tensorflow",
284+
"huggingface_training_compiler",
276285
)
277286
if not versions:
278287
pytest.skip(
@@ -294,19 +303,25 @@ def huggingface_training_compiler_tensorflow_py_version(
294303

295304

296305
@pytest.fixture(scope="module")
297-
def huggingface_training_compiler_pytorch_py_version(huggingface_training_compiler_pytorch_version):
306+
def huggingface_training_compiler_pytorch_py_version(
307+
huggingface_training_compiler_pytorch_version,
308+
):
298309
return "py38"
299310

300311

301312
@pytest.fixture(scope="module")
302-
def huggingface_pytorch_latest_training_py_version(huggingface_training_pytorch_latest_version):
313+
def huggingface_pytorch_latest_training_py_version(
314+
huggingface_training_pytorch_latest_version,
315+
):
303316
return (
304317
"py38" if Version(huggingface_training_pytorch_latest_version) >= Version("1.9") else "py36"
305318
)
306319

307320

308321
@pytest.fixture(scope="module")
309-
def huggingface_pytorch_latest_inference_py_version(huggingface_inference_pytorch_latest_version):
322+
def huggingface_pytorch_latest_inference_py_version(
323+
huggingface_inference_pytorch_latest_version,
324+
):
310325
return (
311326
"py38"
312327
if Version(huggingface_inference_pytorch_latest_version) >= Version("1.9")
@@ -482,7 +497,8 @@ def pytorch_ddp_py_version():
482497

483498

484499
@pytest.fixture(
485-
scope="module", params=["1.10", "1.10.0", "1.10.2", "1.11", "1.11.0", "1.12", "1.12.0"]
500+
scope="module",
501+
params=["1.10", "1.10.0", "1.10.2", "1.11", "1.11.0", "1.12", "1.12.0"],
486502
)
487503
def pytorch_ddp_framework_version(request):
488504
return request.param
@@ -515,6 +531,7 @@ def gpu_instance_type(sagemaker_session, request):
515531
else:
516532
return "ml.p3.2xlarge"
517533

534+
518535
@pytest.fixture()
519536
def gpu_pytorch_instance_type(sagemaker_session, request):
520537
if "pytorch_inference_version" in request.fixturenames:
@@ -531,6 +548,7 @@ def gpu_pytorch_instance_type(sagemaker_session, request):
531548
else:
532549
return "ml.p3.2xlarge"
533550

551+
534552
@pytest.fixture(scope="session")
535553
def gpu_instance_type_list(sagemaker_session, request):
536554
region = sagemaker_session.boto_session.region_name

tests/unit/sagemaker/image_uris/test_dlc_frameworks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def _test_image_uris(
7373

7474
TYPES_AND_PROCESSORS = INSTANCE_TYPES_AND_PROCESSORS
7575
if framework == "pytorch" and Version(fw_version) >= Version("1.13"):
76-
'''Handle P2 deprecation'''
76+
"""Handle P2 deprecation"""
7777
TYPES_AND_PROCESSORS = RENEWED_PYTORCH_INSTANCE_TYPES_AND_PROCESSORS
7878

7979
for instance_type, processor in TYPES_AND_PROCESSORS:

tests/unit/test_pytorch.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,12 @@ def test_create_model_with_custom_image(name_from_base, sagemaker_session):
302302
@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME)
303303
@patch("time.time", return_value=TIME)
304304
def test_pytorch(
305-
time, name_from_base, sagemaker_session, pytorch_inference_version, pytorch_inference_py_version, gpu_pytorch_instance_type
305+
time,
306+
name_from_base,
307+
sagemaker_session,
308+
pytorch_inference_version,
309+
pytorch_inference_py_version,
310+
gpu_pytorch_instance_type,
306311
):
307312
pytorch = PyTorch(
308313
entry_point=SCRIPT_PATH,
@@ -356,7 +361,12 @@ def test_pytorch(
356361

357362
@patch("sagemaker.utils.repack_model", MagicMock())
358363
@patch("sagemaker.utils.create_tar_file", MagicMock())
359-
def test_model(sagemaker_session, pytorch_inference_version, pytorch_inference_py_version, gpu_pytorch_instance_type):
364+
def test_model(
365+
sagemaker_session,
366+
pytorch_inference_version,
367+
pytorch_inference_py_version,
368+
gpu_pytorch_instance_type,
369+
):
360370
model = PyTorchModel(
361371
MODEL_DATA,
362372
role=ROLE,
@@ -429,7 +439,7 @@ def test_model_custom_serialization(
429439
sagemaker_session,
430440
pytorch_inference_version,
431441
pytorch_inference_py_version,
432-
gpu_pytorch_instance_type
442+
gpu_pytorch_instance_type,
433443
):
434444
model = PyTorchModel(
435445
MODEL_DATA,

0 commit comments

Comments
 (0)