Skip to content

Commit d3a7822

Browse files
author
Shibo Xing
committed
fix: p2 error in unit test with a fixture
1 parent 32f37d1 commit d3a7822

File tree

2 files changed

+27
-9
lines changed

2 files changed

+27
-9
lines changed

tests/conftest.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
"huggingface_training_compiler",
8787
)
8888

89+
PYTORCH_RENEWED_GPU = "ml.g4dn.xlarge"
8990

9091
def pytest_addoption(parser):
9192
parser.addoption("--sagemaker-client-config", action="store", default=None)
@@ -514,6 +515,21 @@ def gpu_instance_type(sagemaker_session, request):
514515
else:
515516
return "ml.p3.2xlarge"
516517

518+
@pytest.fixture()
519+
def gpu_pytorch_instance_type(sagemaker_session, request):
520+
if "pytorch_inference_version" in request.fixturenames:
521+
fw_version = request.getfixturevalue("pytorch_inference_version")
522+
else:
523+
fw_version = request.param
524+
525+
region = sagemaker_session.boto_session.region_name
526+
if region in NO_P3_REGIONS:
527+
if Version(fw_version) >= Version("1.13"):
528+
return PYTORCH_RENEWED_GPU
529+
else:
530+
return "ml.p2.xlarge"
531+
else:
532+
return "ml.p3.2xlarge"
517533

518534
@pytest.fixture(scope="session")
519535
def gpu_instance_type_list(sagemaker_session, request):

tests/unit/test_pytorch.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def test_create_model_with_custom_image(name_from_base, sagemaker_session):
302302
@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME)
303303
@patch("time.time", return_value=TIME)
304304
def test_pytorch(
305-
time, name_from_base, sagemaker_session, pytorch_inference_version, pytorch_inference_py_version
305+
time, name_from_base, sagemaker_session, pytorch_inference_version, pytorch_inference_py_version, gpu_pytorch_instance_type
306306
):
307307
pytorch = PyTorch(
308308
entry_point=SCRIPT_PATH,
@@ -339,24 +339,24 @@ def test_pytorch(
339339
REGION,
340340
version=pytorch_inference_version,
341341
py_version=pytorch_inference_py_version,
342-
instance_type=GPU,
342+
instance_type=gpu_pytorch_instance_type,
343343
image_scope="inference",
344344
)
345345

346-
actual_environment = model.prepare_container_def(GPU)
346+
actual_environment = model.prepare_container_def(gpu_pytorch_instance_type)
347347
submit_directory = actual_environment["Environment"]["SAGEMAKER_SUBMIT_DIRECTORY"]
348348
model_url = actual_environment["ModelDataUrl"]
349349
expected_environment = _get_environment(submit_directory, model_url, expected_image_uri)
350350
assert actual_environment == expected_environment
351351

352352
assert "cpu" in model.prepare_container_def(CPU)["Image"]
353-
predictor = pytorch.deploy(1, GPU)
353+
predictor = pytorch.deploy(1, gpu_pytorch_instance_type)
354354
assert isinstance(predictor, PyTorchPredictor)
355355

356356

357357
@patch("sagemaker.utils.repack_model", MagicMock())
358358
@patch("sagemaker.utils.create_tar_file", MagicMock())
359-
def test_model(sagemaker_session, pytorch_inference_version, pytorch_inference_py_version):
359+
def test_model(sagemaker_session, pytorch_inference_version, pytorch_inference_py_version, gpu_pytorch_instance_type):
360360
model = PyTorchModel(
361361
MODEL_DATA,
362362
role=ROLE,
@@ -365,21 +365,22 @@ def test_model(sagemaker_session, pytorch_inference_version, pytorch_inference_p
365365
py_version=pytorch_inference_py_version,
366366
sagemaker_session=sagemaker_session,
367367
)
368-
predictor = model.deploy(1, GPU)
368+
predictor = model.deploy(1, gpu_pytorch_instance_type)
369369
assert isinstance(predictor, PyTorchPredictor)
370370

371371

372372
@patch("sagemaker.utils.create_tar_file", MagicMock())
373373
@patch("sagemaker.utils.repack_model")
374-
def test_mms_model(repack_model, sagemaker_session):
374+
@pytest.mark.parametrize("gpu_pytorch_instance_type", ["1.2"], indirect=True)
375+
def test_mms_model(repack_model, sagemaker_session, gpu_pytorch_instance_type):
375376
PyTorchModel(
376377
MODEL_DATA,
377378
role=ROLE,
378379
entry_point=SCRIPT_PATH,
379380
sagemaker_session=sagemaker_session,
380381
framework_version="1.2",
381382
py_version="py3",
382-
).deploy(1, GPU)
383+
).deploy(1, gpu_pytorch_instance_type)
383384

384385
repack_model.assert_called_with(
385386
dependencies=[],
@@ -428,6 +429,7 @@ def test_model_custom_serialization(
428429
sagemaker_session,
429430
pytorch_inference_version,
430431
pytorch_inference_py_version,
432+
gpu_pytorch_instance_type
431433
):
432434
model = PyTorchModel(
433435
MODEL_DATA,
@@ -441,7 +443,7 @@ def test_model_custom_serialization(
441443
custom_deserializer = Mock()
442444
predictor = model.deploy(
443445
1,
444-
GPU,
446+
gpu_pytorch_instance_type,
445447
serializer=custom_serializer,
446448
deserializer=custom_deserializer,
447449
)

0 commit comments

Comments
 (0)