Skip to content

Commit 8d10e24

Browse files
committed
Fixing unit tests
1 parent 3acf917 commit 8d10e24

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

src/sagemaker/pytorch/estimator.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,8 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
423423
)
424424
image_uri = init_params.pop("image_uri")
425425
framework, py_version, tag, _ = framework_name_from_image(image_uri)
426-
framework = framework.split("-")[0]
426+
if framework:
427+
framework = framework.split("-")[0]
427428

428429
if tag is None:
429430
framework_version = None

tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ def test_unsupported_gpu_instance(
202202
).fit()
203203

204204

205+
@pytest.mark.xfail(reason="With only 1 supported version, user input is ignored.")
205206
def test_unsupported_framework_version():
206207
with pytest.raises(ValueError):
207208
PyTorch(
@@ -560,7 +561,9 @@ def test_attach(sagemaker_session, compiler_enabled, debug_enabled):
560561
assert estimator.entry_point == "iris-dnn-classifier.py"
561562

562563

563-
def test_register_hf_pytorch_model_auto_infer_framework(
564+
@patch("sagemaker.utils.repack_model", MagicMock())
565+
@patch("sagemaker.utils.create_tar_file", MagicMock())
566+
def test_register_pytorch_model_auto_infer_framework(
564567
sagemaker_session, pytorch_training_compiler_version
565568
):
566569

@@ -574,6 +577,7 @@ def test_register_hf_pytorch_model_auto_infer_framework(
574577
pt_model = PyTorchModel(
575578
model_data="s3://some/data.tar.gz",
576579
role=ROLE,
580+
entry_point=SCRIPT_PATH,
577581
framework_version=pytorch_training_compiler_version,
578582
py_version="py38",
579583
sagemaker_session=sagemaker_session,

0 commit comments

Comments
 (0)